In [None]:
from typing import Final
import numpy as np
import plotly.graph_objects as go

# Fix the global NumPy seed so every draw below is reproducible.
np.random.seed(42)

# Sample size shared by every array generated in this cell.
n: Final[int] = 300

# Two feature arrays of standard-normal draws (mean 0, std 1).
x1: Final[np.ndarray] = np.random.normal(loc=0, scale=1, size=n)
x2: Final[np.ndarray] = np.random.normal(loc=0, scale=1, size=n)

# Class label per point, drawn from {0, 1, 2}.
y: Final[np.ndarray] = np.random.choice(3, size=n)

# Latent variable: the label shifted to -4 / 0 / +4, plus Gaussian noise (std 0.5).
lat: Final[np.ndarray] = 4 * (y - 1) + np.random.normal(loc=0, scale=0.5, size=n)

# Marker appearance: colour each point by its class label y.
marker_style = {
    "size": 5,
    "color": y,
    "colorscale": "Viridis",
    "opacity": 0.8,
}

# One 3D scatter trace: the two features on x/y, the latent variable on z.
points_trace = go.Scatter3d(
    x=x1,
    y=x2,
    z=lat,
    mode="markers",
    marker=marker_style,
    name="Data points",
)

fig = go.Figure()
fig.add_trace(points_trace)

# Title, axis labels and a square 800 x 800 canvas.
fig.update_layout(
    title="3D Classification with Decision Boundary",
    scene={
        "xaxis_title": "x1",
        "yaxis_title": "x2",
        "zaxis_title": "lat",
    },
    width=800,
    height=800,
)

fig.show()

In [None]:
from itertools import product


# Pairwise relation matrix: r[i, j] is 1.0 when samples i and j share a class
# label and 0.0 otherwise. Built with one broadcast comparison instead of a
# Python loop over all n * n pairs; the cast keeps the float64 dtype that the
# np.zeros-based version produced.
r = (y[:, None] == y[None, :]).astype(np.float64)

# Feature matrix with one row per sample and two columns (x1, x2).
x: Final[np.ndarray] = np.stack((x1, x2)).T

# The first n_train samples form the training set, the remainder the test set.
n_train: Final[int] = 200
x_train, x_test = x[:n_train], x[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

# Blocks of the relation matrix matching that split.
r_train = r[:n_train, :n_train]        # train rows x train columns
r_test_intra = r[n_train:, n_train:]   # test rows x test columns
r_test_inter = r[:n_train, n_train:]   # train rows x test columns


In [None]:
import logging
from pathlib import Path
from tabrel.sklearn_interface import TabRelClassifier
from tabrel.utils.config import ProjectConfig, ClassifierConfig, TrainingConfig

def run(X_train: np.ndarray, X_test: np.ndarray, use_rel: bool) -> None:
    """Fit a TabRelClassifier on X_train, evaluate it on X_test, and print the metrics.

    Labels and relation matrices are not passed in: the function reads the
    module-level ``y``, ``y_train``, ``y_test``, ``r_train``, ``r_test_inter``
    and ``r_test_intra`` defined in earlier cells, so those cells must have
    been executed first.

    Args:
        X_train: Training feature matrix; its second dimension sets ``n_features``.
        X_test: Test feature matrix passed to ``evaluate``.
        use_rel: Forwarded as ``rel`` to ``ClassifierConfig`` and echoed in the
            printed line.
    """
    training_cfg = TrainingConfig(
        batch_size=20,
        query_size=20,
        n_batches=4,
        lr=1e-4,
        n_epochs=50,
        log_dir=Path("out/logs"),
        log_level=logging.DEBUG,
        print_logs_to_console=True,
        checkpoints_dir=Path("out/checkpoints"),
        allow_dirs_exist=True,
        random_seed=42,
    )
    model_cfg = ClassifierConfig(
        n_features=X_train.shape[1],
        d_embedding=20,
        d_model=8,
        nhead=1,
        dim_feedforward=1,
        num_layers=1,
        # Number of classes is taken from the full label array, not just y_train.
        num_classes=len(np.unique(y)),
        activation="relu",
        rel=use_rel,
        dropout=0.0,
    )

    classifier = TabRelClassifier(ProjectConfig(training=training_cfg, model=model_cfg))
    classifier.fit(X=X_train, y=y_train, r=r_train)
    scores = classifier.evaluate(
        X=X_test,
        r_inter=r_test_inter,
        r_intra=r_test_intra,
        y=y_test,
    )
    print(scores, f"rel: {use_rel}")

# The two observed features only: first with the relation matrix enabled, then without.
for rel_enabled in (True, False):
    run(x_train, x_test, use_rel=rel_enabled)

# Comparison run without relations, with the latent variable added as a third feature column.
x_full = np.stack([x1, x2, lat], axis=0).T
x_full_train, x_full_test = x_full[:n_train], x_full[n_train:]
run(x_full_train, x_full_test, use_rel=False)