In [None]:
from typing import Final
import numpy as np
import plotly.graph_objects as go

# Seed the global NumPy RNG so every run of this cell produces the same data.
# NOTE: all draws below consume that global RNG in sequence, so reordering the
# statements in this cell changes every generated value.
np.random.seed(42)

# Parameters
n: Final[int] = 70  # number of points

# 1. Two observed features, each an independent standard-normal sample of length n
x1: Final[np.ndarray] = np.random.normal(0, 1, n)
x2: Final[np.ndarray] = np.random.normal(0, 1, n)

# 2. Class labels drawn from {0, 1, 2}; generated independently of x1 and x2
y: Final[np.ndarray] = np.random.choice(3, n)

# 3. Latent variable: class-dependent center (-4, 0 or +4) plus Gaussian noise with std 0.5
lat: Final[np.ndarray] = (y - 1) * 4 + np.random.normal(0, .5, n) # latent variable

# 3D scatter of the samples: the two observed features on the x/y axes, the
# latent variable on z, and the class label y encoded as marker color.
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=x1,
            y=x2,
            z=lat,
            mode='markers',
            marker={'size': 5, 'color': y, 'colorscale': 'Viridis', 'opacity': 0.8},
            name='Data points',
        )
    ]
)

# Title, axis captions and a fixed 800x800 canvas.
fig.update_layout(
    title='3D Classification with Decision Boundary',
    scene={'xaxis_title': 'x1', 'yaxis_title': 'x2', 'zaxis_title': 'lat'},
    width=800,
    height=800,
)

fig.show()

In [None]:
from itertools import product


# Pairwise "same class" relation matrix: r[i, j] is 1.0 when samples i and j
# carry the same label and 0.0 otherwise. A broadcast comparison replaces the
# element-by-element Python loop; the float cast keeps the float64 dtype that
# the loop-filled np.zeros matrix had.
r = (y[:, None] == y[None, :]).astype(float)

# Feature matrix with one row per sample and the columns (x1, x2).
x: Final[np.ndarray] = np.stack((x1, x2)).T

# Sequential split: the first n_train samples train the model, the rest test it.
n_train: Final[int] = 50
x_train, x_test = x[:n_train], x[n_train:]
y_train, y_test = y[:n_train], y[n_train:]

# Matching blocks of the relation matrix.
r_train = r[:n_train, :n_train]        # train rows x train columns
r_test_intra = r[n_train:, n_train:]   # test rows  x test columns
r_test_inter = r[:n_train, n_train:]   # train rows x test columns


In [None]:
import logging
from pathlib import Path
from tabrel.sklearn_interface import TabRelClassifier
from tabrel.utils.config import ProjectConfig, ClassifierConfig, TrainingConfig

def run(X_train: np.ndarray, X_test: np.ndarray, use_rel: bool) -> TabRelClassifier:
    """Fit a TabRelClassifier on the notebook's split and print its test metrics.

    Only the feature matrices are parameters. Labels and relation matrices are
    read from the notebook's global scope (``y``, ``y_train``, ``y_test``,
    ``r_train``, ``r_test_inter``, ``r_test_intra``), so the cell that creates
    the split has to be executed before this function is called.

    Args:
        X_train: Training features; ``X_train.shape[1]`` is used as ``n_features``.
        X_test: Features of the samples passed to ``evaluate``.
        use_rel: Value forwarded as ``rel`` in the model configuration.

    Returns:
        The fitted classifier.
    """
    training_cfg = TrainingConfig(
        batch_size=20,
        query_size=10,
        n_batches=3,
        lr=1e-4,
        n_epochs=50,
        log_dir=Path("out/logs"),
        log_level=logging.DEBUG,
        print_logs_to_console=True,
        checkpoints_dir=Path("out/checkpoints"),
        allow_dirs_exist=True,
        random_seed=42,
    )
    model_cfg = ClassifierConfig(
        n_features=X_train.shape[1],
        d_embedding=20,
        d_model=8,
        nhead=1,
        dim_feedforward=1,
        num_layers=1,
        # Class count comes from the full label vector, not just the training part.
        num_classes=len(np.unique(y)),
        activation="relu",
        rel=use_rel,
        dropout=0.,
    )

    classifier = TabRelClassifier(ProjectConfig(training=training_cfg, model=model_cfg))
    classifier.fit(X=X_train, y=y_train, r=r_train)

    scores = classifier.evaluate(
        X=X_test,
        r_inter=r_test_inter,
        r_intra=r_test_intra,
        y=y_test,
    )
    print(scores, f"rel: {use_rel}")
    return classifier

# Two observed features only: once with the relation flag on, once with it off.
clf_rel = run(X_train=x_train, X_test=x_test, use_rel=True)
clf_norel = run(X_train=x_train, X_test=x_test, use_rel=False)

# Same model without relations, but with the latent variable added as a third feature.
x_full = np.stack((x1, x2, lat)).T
clf_full_norel = run(X_train=x_full[:n_train], X_test=x_full[n_train:], use_rel=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_matrix_with_y_track(matrix, y, title, xlabel, ylabel,
                             cmap='viridis',
                             square: bool = True,
                             ):
    """Draw ``matrix`` as a heatmap with a one-row strip of labels underneath.

    The figure has two stacked panels that share their x-axis: the heatmap on
    top and, below it, ``y`` rendered as a single-row heatmap so each matrix
    column can be read against the label of the corresponding sample.

    Args:
        matrix: 2-D array to plot; columns are the samples.
        y: One value per matrix column (list or array). Because the panels
            share the x-axis, a length different from ``matrix.shape[1]``
            would misalign the strip.
        title: Title of the heatmap panel.
        xlabel: Label of the shared x-axis (drawn under the label strip).
        ylabel: Label of the heatmap's y-axis.
        cmap: Colormap of the main heatmap (the label strip always uses 'viridis').
        square: If True, draw square cells in a fixed 8x9 figure and put the
            custom ticks on the y-axis as well. If False, the figure height
            grows with the number of matrix rows.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n_samples = matrix.shape[1]

    # Height of the heatmap panel relative to the 1-unit label strip. The floor
    # division is 0 for matrices with fewer than 3 rows, which would collapse
    # the heatmap panel, so it is clamped to at least 1.
    panel_units = 10 if square else max(1, len(matrix) // 3)

    fig, (ax0, ax1) = plt.subplots(
        nrows=2, ncols=1,
        figsize=(8, 9 if square else panel_units + 1),
        gridspec_kw={"hspace": 0., "height_ratios": [panel_units, 1]},
        sharex=True,
    )
    # Dedicated axes for the colorbar, placed to the right of both panels.
    cbar_ax = fig.add_axes([0.92, 0.2, 0.03, 0.65])
    sns.heatmap(
        matrix,
        ax=ax0,
        cmap=cmap,
        square=square,
        cbar_ax=cbar_ax
    )
    ax0.set_title(title)
    ax0.set_ylabel(ylabel)

    # Y track
    y_bar = np.asarray(y).reshape(1, -1)
    sns.heatmap(
        y_bar,
        ax=ax1,
        cmap='viridis',
        cbar=False,
    )
    ax1.set_yticks([])
    ax1.set_ylabel("y")
    ax1.set_xlabel(xlabel)

    # Custom ticks go on AFTER both heatmaps: the panels share one x-axis and
    # each sns.heatmap call installs its own ticks on it, so ticks set before
    # the second call would be replaced. Heatmap cell i spans [i, i + 1], hence
    # the +0.5 that puts each label on the center of its cell.
    ticks = np.arange(0, n_samples, max(1, n_samples // 10))
    ax1.set_xticks(ticks + 0.5)
    ax1.set_xticklabels(ticks)
    if square:
        ax0.set_yticks(ticks + 0.5)
        ax0.set_yticklabels(ticks)
    # The upper panel shows no x tick marks; labels live under the strip.
    ax0.tick_params(axis='x', which='both', bottom=False, top=False)
    plt.show()

In [None]:
import torch

# Capture buffers filled by the forward hooks: one entry is appended to each
# list per hooked attention forward pass. Four independent list objects.
attention_maps, q_matrices, k_matrices, v_matrices = ([] for _ in range(4))

def patch_attention_forward(model):
    """Install a recording ``forward`` on every RelationalMultiheadAttention in ``model``.

    Modules are matched by class name. For each match that has not been patched
    yet, the current ``forward`` is kept as ``_original_forward`` and replaced
    by a version that, on every call, stores the softmax attention weights and
    the per-head Q/K/V tensors (detached, on CPU) as ``_last_attention``,
    ``_last_q``, ``_last_k`` and ``_last_v`` on the module. Calling this
    function again on the same model leaves already-patched modules alone.

    NOTE(review): the replacement re-implements the attention computation here;
    it is presumably meant to mirror the library's own forward - confirm after
    upgrading the library.
    """

    def recording_forward(self, s, attn_mask):
        """Attention pass over ``s.x`` (and ``s.r`` when ``self.rel``) that records its internals."""
        n_samples, _ = s.x.shape

        def to_heads(projection):
            # Project the inputs, then split the feature axis into heads:
            # result is laid out as (num_heads, n_samples, head_dim).
            projected = projection(s.x)
            return projected.reshape(n_samples, self.num_heads, self.head_dim).transpose(0, 1)

        q = to_heads(self.q_proj)
        k = to_heads(self.k_proj)
        v = to_heads(self.v_proj)

        scores = (q @ k.transpose(-2, -1)) * self.scaling_factor
        if self.rel:
            # Relation matrix enters as a scaled, shifted additive term shared by all heads.
            scores += s.r.unsqueeze(0) * self.r_scale + self.r_bias
        # Non-zero mask entries are excluded from the softmax.
        scores = scores.masked_fill(attn_mask != 0, -torch.inf)
        weights = torch.softmax(scores, dim=-1)

        # Record before dropout so the stored map is the plain softmax output.
        self._last_attention = weights.detach().cpu()
        self._last_q, self._last_k, self._last_v = (
            tensor.detach().cpu() for tensor in (q, k, v)
        )

        mixed = self.dropout(weights) @ v
        # Back to one row per sample with the heads concatenated.
        return self.out_proj(mixed.transpose(0, 1).flatten(1))

    for module in model.modules():
        if module.__class__.__name__ != "RelationalMultiheadAttention":
            continue
        if hasattr(module, "_original_forward"):
            continue  # already patched
        module._original_forward = module.forward
        # Bind to this instance so only the instance (not the class) is altered.
        module.forward = recording_forward.__get__(module, module.__class__)

def register_attention_hooks(model):
    """Attach a collecting forward hook to every RelationalMultiheadAttention in ``model``.

    After each forward pass of a matching module (matched by class name) the
    hook appends the module's ``_last_attention``, ``_last_q``, ``_last_k`` and
    ``_last_v`` to the notebook-level lists ``attention_maps``, ``q_matrices``,
    ``k_matrices`` and ``v_matrices``. Those attributes must already be set by
    the module's forward when the hook fires. A ``_hook_registered`` flag on
    the module keeps repeated calls from attaching the hook twice.
    """

    def collect(attn_module, _inputs, _output):
        # Only the tensors stashed on the module are used; the hook's inputs
        # and output are ignored.
        attention_maps.append(attn_module._last_attention)
        q_matrices.append(attn_module._last_q)
        k_matrices.append(attn_module._last_k)
        v_matrices.append(attn_module._last_v)

    for module in model.modules():
        is_attention = module.__class__.__name__ == "RelationalMultiheadAttention"
        if not is_attention or getattr(module, "_hook_registered", False):
            continue
        module.register_forward_hook(collect)
        module._hook_registered = True

# Imported once, ahead of the loop, instead of on every iteration.
from tabrel.utils.linalg import mirror_triu

# For each fitted classifier: run one forward pass over (training context +
# query samples) with the recording hooks installed, then plot the captured
# attention map and the per-head Q/K/V matrices.
for title, clf, X, r_inter, r_intra in (
    # No-relation model trained on all three features: relation blocks are zeroed.
    ("full X", clf_full_norel, x_full[n_train:], np.zeros_like(r_test_inter), np.zeros_like(r_test_intra)),
    ("X + rel", clf_rel, x_test, r_test_inter, r_test_intra),
):
    model = clf.fit_data_.model.eval()
    patch_attention_forward(model)
    register_attention_hooks(model)

    xb = clf.fit_data_.x_train
    yb = clf.fit_data_.y_train
    xq = torch.tensor(X, dtype=torch.float32)

    # Local names on purpose: the original code reused `n_train` and `r` here,
    # which rebound the notebook-level split constant (declared Final) and
    # replaced the NumPy relation matrix with a torch tensor.
    n_context, n_query = len(xb), len(xq)

    # Joint relation matrix over context + query samples, assembled block-wise
    # on top of an identity matrix.
    r_joint = torch.eye(n_context + n_query)
    r_joint[:n_context, :n_context] = clf.fit_data_.r_train
    r_joint[n_context:, n_context:] = torch.tensor(r_intra)
    r_joint[:n_context, n_context:] = torch.tensor(r_inter)
    # Only the upper-right (context x query) block was written above;
    # mirror_triu presumably copies the upper triangle onto the lower one - TODO confirm.
    r_joint = mirror_triu(r_joint)

    # Run inference; start from empty buffers so they hold this model's pass only.
    attention_maps.clear()
    q_matrices.clear()
    k_matrices.clear()
    v_matrices.clear()
    with torch.no_grad():
        _ = model(xb, yb, xq, r_joint)

    # Visualize
    if not attention_maps:
        print("No attention maps captured.")
        continue

    # One buffer entry per hooked forward call (labelled "Layer" in the titles);
    # the head count is taken from the first entry.
    for i_layer, i_head in product(range(len(attention_maps)), range(len(attention_maps[0]))):
        where = f"({title}, Head {i_head}, Layer {i_layer})"
        plot_matrix_with_y_track(
            attention_maps[i_layer][i_head].numpy(),
            y,
            f"Attention Map {where}",
            "Key Index",
            "Query Index",
        )
        for name, captured in (("Q", q_matrices), ("K", k_matrices), ("V", v_matrices)):
            plot_matrix_with_y_track(
                # Transposed to (head_dim, n_samples) so samples run along the x-axis.
                captured[i_layer][i_head].numpy().T,
                y,
                f"{name} Matrix {where}",
                "Sample Index",
                f"{name} dim",
                square=False,
            )

In [None]:
# Display the attention weights still held in the capture buffer. The buffer is
# cleared before every model call in the loop above, so this shows only what
# the last processed classifier produced.
attention_maps