In [None]:
from typing import Final
import numpy as np
import plotly.graph_objects as go

# Seed NumPy's legacy global generator so every run draws the same sample.
np.random.seed(42)

# Sample size shared by all generated arrays.
n: Final[int] = 1000

# Three independent standard-normal draws (mean 0, std 1), taken in this order
# so the random stream is consumed deterministically: two features, then the
# latent variable.
x1: Final[np.ndarray] = np.random.normal(loc=0, scale=1, size=n)
x2: Final[np.ndarray] = np.random.normal(loc=0, scale=1, size=n)
lat: Final[np.ndarray] = np.random.normal(loc=0, scale=1, size=n)

# Binary label: 1 where x1 + x2 + lat is strictly positive, 0 otherwise.
y: Final[np.ndarray] = (x1 + x2 + lat > 0).astype(int)

# Scatter trace: every sample is placed at (x1, x2, lat) and coloured by its
# label y through the Viridis colorscale.
points_trace = go.Scatter3d(
    x=x1,
    y=x2,
    z=lat,
    mode='markers',
    marker={
        'size': 5,
        'color': y,
        'colorscale': 'Viridis',
        'opacity': 0.8,
    },
    name='Data points',
)

# Decision boundary x + y + z = 0, sampled on a 20 x 20 grid spanning
# [-plane_size, plane_size] along both horizontal axes.
plane_size = 3
x_plane, y_plane = np.meshgrid(
    np.linspace(-plane_size, plane_size, 20),
    np.linspace(-plane_size, plane_size, 20),
)
z_plane = -x_plane - y_plane  # solve x + y + z = 0 for z

# Semi-transparent grey surface so the points on both sides stay visible.
boundary_trace = go.Surface(
    x=x_plane,
    y=y_plane,
    z=z_plane,
    colorscale='Greys',
    opacity=0.5,
    showscale=False,
    name='Decision boundary',
)

# Assemble the figure: points first, then the plane.
fig = go.Figure(data=[points_trace, boundary_trace])

# Title, axis labels (named after the plotted arrays) and a square canvas.
fig.update_layout(
    title='3D Classification with Decision Boundary',
    scene={
        'xaxis_title': 'x1',
        'yaxis_title': 'x2',
        'zaxis_title': 'lat',
    },
    width=800,
    height=800,
)

fig.show()

In [None]:
from itertools import product


# Pairwise relation matrix: r[i, j] = |lat[i] - lat[j]|.
# Broadcasting a column view of `lat` against a row view builds all n * n
# entries in a single vectorized operation, replacing a Python-level double
# loop of n * n scalar assignments while producing the same values.
r = np.abs(lat[:, np.newaxis] - lat[np.newaxis, :])
assert r.shape == (n, n)

# Feature matrix: one row per sample, columns are (x1, x2).
x: Final[np.ndarray] = np.stack((x1, x2)).T
print(f"x shape: {x.shape}, y shape: {y.shape}, r shape: {r.shape}")

# Sequential split: the first n_train samples are used for training and the
# remaining ones for testing.
n_train: Final[int] = 800
x_train, y_train = x[:n_train], y[:n_train]
x_test, y_test = x[n_train:], y[n_train:]

# Sub-blocks of r that correspond to the split.
r_train = r[:n_train, :n_train]       # train x train
r_test_intra = r[n_train:, n_train:]  # test x test
r_test_inter = r[:n_train, n_train:]  # rows: train samples, columns: test samples
print(f"r_test_intra shape: {r_test_intra.shape}, r_test_inter shape: {r_test_inter.shape}")

In [None]:
from dataclasses import replace
from tabrel.sklearn_interface import TabRelClassifier
from tabrel.utils.config import ProjectConfig

# Fit and evaluate the classifier twice on the same split, once with the
# model's `rel` flag set to True and once with it set to False, printing the
# metrics of each run. Both runs start from the same default configuration
# with n_features=2.
base_config = ProjectConfig.default()
for use_rel in (True, False):
    config = replace(
        base_config,
        model=replace(base_config.model, n_features=2, rel=use_rel),
    )
    model = TabRelClassifier(config)
    model.fit(X=x_train, y=y_train, r=r_train)
    metrics = model.evaluate(
        X=x_test,
        y=y_test,
        r_inter=r_test_inter,
        r_intra=r_test_intra,
    )
    print(f"use_rel={use_rel}")
    print(metrics)