In [None]:
from typing import Final
from itertools import product

import numpy as np
import plotly.graph_objects as go

# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n: Final[int] = 900  # number of points

# 1. Two standard-normal features; they are drawn independently of the labels
x1: Final[np.ndarray] = np.random.normal(0, 1, n)
x2: Final[np.ndarray] = np.random.normal(0, 1, n)

# 2. Class labels drawn from {0, 1, 2}
y: Final[np.ndarray] = np.random.choice(3, n)

# 3. Latent variable: class-dependent mean (-4, 0, 4) plus N(0, 0.5) noise
lat: Final[np.ndarray] = (y - 1) * 4 + np.random.normal(0, .5, n) # latent variable

# 4. Relation matrix: r[i, j] == 1.0 iff samples i and j share a label.
#    Broadcasting replaces the original n*n Python loop; values and dtype
#    (float64 zeros/ones) are the same.
r = (y[:, None] == y[None, :]).astype(np.float64)
# Keep a NumPy copy under a separate name for the cells that need the full matrix
r_numpy = r.copy()

# Feature matrix of shape (n, 2)
x: Final[np.ndarray] = np.stack((x1, x2)).T

# First n_train samples are used for training, the remainder for testing
n_train: Final[int] = 200
x_train, x_test = x[:n_train], x[n_train:]
y_train, y_test = y[:n_train], y[n_train:]
r_train = r[:n_train][:, :n_train]        # train x train relations
r_test_intra = r[n_train:][:, n_train:]   # test x test relations
r_test_inter = r[:n_train][:, n_train:]   # train x test relations

# Interactive 3-D scatter of the two features against the latent variable,
# with one marker per sample coloured by its class label.
points = go.Scatter3d(
    x=x1,
    y=x2,
    z=lat,
    mode='markers',
    marker={
        'size': 5,
        'color': y,
        'colorscale': 'Viridis',
        'opacity': 0.8,
    },
    name='Data points',
)

fig = go.Figure()
fig.add_trace(points)

# Axis titles and a fixed square canvas
fig.update_layout(
    title='3D Classification with Decision Boundary',
    scene={
        'xaxis_title': 'x1',
        'yaxis_title': 'x2',
        'zaxis_title': 'lat',
    },
    width=800,
    height=800,
)

fig.show()

In [None]:
import logging
from pathlib import Path
from tabrel.sklearn_interface import TabRelClassifier
from tabrel.utils.config import ProjectConfig, ClassifierConfig, TrainingConfig

def run(X_train: np.ndarray, X_test: np.ndarray, use_rel: bool) -> TabRelClassifier:
    """Fit a TabRelClassifier on ``X_train``, evaluate it on ``X_test`` and return it.

    Only the feature matrices are parameters. Labels and relation matrices are
    read from the notebook globals ``y``, ``y_train``, ``y_test``, ``r_train``,
    ``r_test_inter`` and ``r_test_intra``, so those must already be defined.

    Args:
        X_train: training features; the number of columns sets ``n_features``.
        X_test: test features passed to ``evaluate``.
        use_rel: forwarded to ``ClassifierConfig(rel=...)``.

    Returns:
        The fitted classifier. The evaluation metrics are printed, not returned.
    """
    training_cfg = TrainingConfig(
        backgnd_size=100,
        query_size=25,
        n_batches=4,
        lr=1e-4,
        n_epochs=50,
        log_dir=Path("out/logs"),
        log_level=logging.DEBUG,
        print_logs_to_console=True,
        checkpoints_dir=Path("out/checkpoints"),
        allow_dirs_exist=True,
        random_seed=42,
    )
    model_cfg = ClassifierConfig(
        n_features=X_train.shape[1],
        d_embedding=20,
        d_model=8,
        nhead=1,
        dim_feedforward=1,
        num_layers=1,
        num_classes=len(np.unique(y)),
        activation="relu",
        rel=use_rel,
        dropout=0.,
    )

    classifier = TabRelClassifier(ProjectConfig(training=training_cfg, model=model_cfg))
    classifier.fit(X=X_train, y=y_train, r=r_train)

    scores = classifier.evaluate(
        X=X_test,
        r_inter=r_test_inter,
        r_intra=r_test_intra,
        y=y_test,
    )
    print(scores, f"rel: {use_rel}")
    return classifier

# Same two features (x1, x2), once with and once without the relation flag
clf_rel = run(x_train, x_test, use_rel=True)
clf_norel = run(x_train, x_test, use_rel=False)

# Relation-free baseline that additionally receives the latent variable as a third feature
x_full = np.stack((x1, x2, lat)).T
clf_full_norel = run(x_full[:n_train], x_full[n_train:], use_rel=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_matrix_with_y_track(matrix, y, title, xlabel, ylabel,
                             cmap='viridis',
                             square: bool = True,
                             ):
    """Draw ``matrix`` as a heatmap with a one-row heatmap of ``y`` underneath.

    The two panels share the x axis, so ``y`` is expected to have one entry per
    column of ``matrix``.

    Args:
        matrix: 2-D array to display (its ``shape[1]`` sets the tick range).
        y: array of per-column values; reshaped to a single row for the track.
        title: title of the matrix panel.
        xlabel: x-axis label, placed under the ``y`` track.
        ylabel: y-axis label of the matrix panel.
        cmap: colormap of the matrix panel (the track always uses 'viridis').
        square: if True, draw square cells in a fixed 10:1 layout; otherwise the
            matrix panel height scales with the number of rows.
    """
    # Height of the matrix panel in the non-square layout. Clamped to at least 1:
    # a matrix with fewer than three rows would otherwise get a zero height ratio.
    rows_height = max(1, len(matrix) // 3)

    gridspec_kw = {"hspace": 0.}
    if square:
        gridspec_kw["height_ratios"] = [10, 1]
    else:
        gridspec_kw["height_ratios"] = [rows_height, 1]
    fig, (ax0, ax1) = plt.subplots(
        nrows=2, ncols=1,
        figsize=(8, 9 if square else rows_height + 1),
        gridspec_kw=gridspec_kw,
        sharex=True,
    )
    # Dedicated axes for the colorbar so it does not shrink the matrix panel
    cbar_ax = fig.add_axes([0.92, 0.2, 0.03, 0.65])
    sns.heatmap(
        matrix,
        ax=ax0,
        cmap=cmap,
        square=square,
        cbar_ax=cbar_ax
    )
    ax0.set_title(title)
    ax0.set_ylabel(ylabel)
    # Roughly ten evenly spaced index ticks along the sample axis
    n_samples = matrix.shape[1]
    ticks = np.arange(0, n_samples, max(1, n_samples // 10))
    ax0.set_xticks(ticks)
    ax0.set_xticklabels(ticks)
    if square:
        ax0.set_yticks(ticks)
        ax0.set_yticklabels(ticks)
    ax0.tick_params(axis='x', which='both', bottom=False, top=False)
    # Y track
    y_bar = y.reshape(1, -1)
    sns.heatmap(
        y_bar,
        ax=ax1,
        cmap='viridis',
        cbar=False,
    )
    ax1.set_yticks([])
    ax1.set_ylabel("y")
    ax1.set_xlabel(xlabel)
    plt.show()

In [None]:
import torch

# Module-level buffers filled by the forward hook registered below:
# one entry is appended to each list per hooked attention module per forward pass.
attention_maps = []
q_matrices = []
k_matrices = []
v_matrices = []

def patch_attention_forward(model):
    """Replace ``forward`` on every ``RelationalMultiheadAttention`` submodule of ``model``.

    The replacement computes attention itself and stores the softmax weights and
    the per-head q/k/v tensors on the module (``_last_attention``, ``_last_q``,
    ``_last_k``, ``_last_v``) so they can be read back after a forward pass.
    Modules are matched by class name; a module that already has
    ``_original_forward`` is left alone, so calling this twice is harmless.

    NOTE(review): the replacement presumably mirrors the library's own forward —
    confirm against the tabrel implementation.
    """
    for module in model.modules():
        if module.__class__.__name__ == "RelationalMultiheadAttention":
            # `_original_forward` doubles as the "already patched" marker
            if not hasattr(module, "_original_forward"):
                def new_forward(self, s, attn_mask):
                    # s.x is unpacked as a 2-D tensor (n_samples, d_model)
                    n_samples, d_model = s.x.shape
                    # Project, split into heads, and move the head axis first:
                    # each of q, k, v is (num_heads, n_samples, head_dim)
                    q, k, v = (
                        (
                            proj(s.x)
                            .reshape(n_samples, self.num_heads, self.head_dim)
                            .transpose(0, 1)
                        )
                        for proj in (self.q_proj, self.k_proj, self.v_proj)
                    )
                    # Scaled dot-product scores per head
                    attn_scores = (q @ k.transpose(-2, -1)) * self.scaling_factor
                    if self.rel:
                        # Relation matrix, scaled and shifted, added to every head's scores
                        attn_scores += s.r.unsqueeze(0) * self.r_scale + self.r_bias
                    # Non-zero mask entries are excluded from the softmax
                    attn_scores = attn_scores.masked_fill(attn_mask != 0, -torch.inf)
                    weights = torch.softmax(attn_scores, dim=-1)
                    # Stash detached CPU copies (taken before dropout) for later inspection
                    self._last_attention = weights.detach().cpu()
                    self._last_q = q.detach().cpu()
                    self._last_k = k.detach().cpu()
                    self._last_v = v.detach().cpu()
                    weights = self.dropout(weights)
                    res = weights @ v
                    # Back to (n_samples, num_heads * head_dim) before the output projection
                    res = res.transpose(0, 1).flatten(1)
                    return self.out_proj(res)
                module._original_forward = module.forward
                # Bind as a method of this instance so `self` resolves to the module
                module.forward = new_forward.__get__(module, module.__class__)

def register_attention_hooks(model):
    """Attach a forward hook to every ``RelationalMultiheadAttention`` in ``model``.

    After each forward pass of such a module, the hook appends the tensors the
    module stored as ``_last_attention``, ``_last_q``, ``_last_k`` and
    ``_last_v`` to the module-level capture lists. Modules are matched by class
    name and flagged with ``_hook_registered`` so each gets at most one hook.
    """
    def _collect(module, input, output):
        # Copy the latest captured tensors into the notebook-level buffers
        attention_maps.append(module._last_attention)
        q_matrices.append(module._last_q)
        k_matrices.append(module._last_k)
        v_matrices.append(module._last_v)

    for layer in model.modules():
        is_attention = layer.__class__.__name__ == "RelationalMultiheadAttention"
        if not is_attention or getattr(layer, "_hook_registered", False):
            continue
        layer.register_forward_hook(_collect)
        layer._hook_registered = True

# Imported once here instead of on every loop iteration
from tabrel.utils.linalg import mirror_triu

# For each fitted classifier: run one forward pass over its stored training
# samples plus the test queries, then plot the captured attention and Q/K/V.
# The "full X" model is given all-zero inter/intra relation blocks.
for title, clf, X, r_inter, r_intra in (
    ("full X", clf_full_norel, x_full[n_train:], np.zeros_like(r_test_inter), np.zeros_like(r_test_intra)),
    ("X + rel", clf_rel, x_test, r_test_inter, r_test_intra),
):
    model = clf.fit_data_.model.eval()
    patch_attention_forward(model)
    register_attention_hooks(model)
    xb = clf.fit_data_.x_train
    yb = clf.fit_data_.y_train
    xq = torch.tensor(X, dtype=torch.float32)

    # Local names on purpose: the notebook-level `n_train` and `r` are used by
    # other cells and must not be overwritten here.
    n_ctx, n_query = len(xb), len(xq)
    r_full = torch.eye(n_ctx + n_query)
    r_full[:n_ctx, :n_ctx] = clf.fit_data_.r_train
    r_full[n_ctx:, n_ctx:] = torch.tensor(r_intra)
    r_full[:n_ctx, n_ctx:] = torch.tensor(r_inter)
    # Only the upper-right inter block was filled; presumably mirror_triu copies
    # the upper triangle onto the lower one — confirm against tabrel.utils.linalg
    r_full = mirror_triu(r_full)

    # Run inference (buffers are cleared so they hold this classifier only)
    attention_maps.clear()
    q_matrices.clear()
    k_matrices.clear()
    v_matrices.clear()
    with torch.no_grad():
        _ = model(xb, yb, xq, r_full)

    # Visualize
    if attention_maps:
        for i_layer, i_head in product(range(len(attention_maps)), range(len(attention_maps[0]))):
            attn = attention_maps[i_layer][i_head].numpy()
            q = q_matrices[i_layer][i_head].numpy().T  # shape: (head_dim, n_samples)
            k = k_matrices[i_layer][i_head].numpy().T
            v = v_matrices[i_layer][i_head].numpy().T
            # NOTE(review): the label track uses the full notebook-level `y`; this
            # assumes the model input order is the training samples followed by
            # the test samples in their original order — confirm.
            y_bar = y
            plot_matrix_with_y_track(attn, y_bar, f"Attention Map ({title}, Head {i_head}, Layer {i_layer})", "Key Index", "Query Index")
            plot_matrix_with_y_track(q, y_bar, f"Q Matrix ({title}, Head {i_head}, Layer {i_layer})", "Sample Index", "Q dim", square=False)
            plot_matrix_with_y_track(k, y_bar, f"K Matrix ({title}, Head {i_head}, Layer {i_layer})", "Sample Index", "K dim", square=False)
            plot_matrix_with_y_track(v, y_bar, f"V Matrix ({title}, Head {i_head}, Layer {i_layer})", "Sample Index", "V dim", square=False)
    else:
        print("No attention maps captured.")

In [None]:
# Display the captured attention tensors; the list is cleared per classifier in
# the loop above, so it holds the maps of the last classifier processed.
attention_maps

# GAT

In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
import torch.nn.functional as F

def np_to_geometric(x_: np.ndarray, y_: np.ndarray, r_: np.ndarray) -> Data:
    """Pack features, labels and a relation matrix into a torch_geometric ``Data``.

    Args:
        x_: node features, converted to a float tensor.
        y_: node labels, converted to a long tensor.
        r_: relation matrix; every entry that is non-zero after casting to
            integer becomes an edge (row index -> column index).

    Returns:
        ``Data`` with ``x``, ``y`` and ``edge_index`` set.
    """
    adjacency = torch.tensor(r_, dtype=torch.long)
    # nonzero() yields one (row, col) pair per edge; transpose to get 2 x n_edges
    edges = adjacency.nonzero().T

    return Data(
        x=torch.tensor(x_, dtype=torch.float),
        y=torch.tensor(y_, dtype=torch.long),
        edge_index=edges,
    )

# Graph over all samples, with edges taken from the full relation matrix
data = np_to_geometric(x, y, r_numpy)
# Graph restricted to the training samples and their mutual relations
data_train = np_to_geometric(x_train, y_train, r_train)

In [None]:
from torch.nn import Module
from torch_geometric.nn import GATv2Conv

class GATv2Net(Module):
    """Two stacked GATv2Conv layers with a ReLU in between.

    The first layer maps ``in_channels`` to ``hidden_channels`` per head using
    ``heads`` heads; the second consumes ``hidden_channels * heads`` inputs and
    produces ``out_channels`` outputs with a single head.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads=1):
        super().__init__()
        second_in = hidden_channels * heads
        self.gat1 = GATv2Conv(in_channels, hidden_channels, heads=heads)
        self.gat2 = GATv2Conv(second_in, out_channels, heads=1)

    def forward(self, x, edge_index):
        """Return the second layer's output for every node (no softmax applied)."""
        hidden = F.relu(self.gat1(x, edge_index))
        return self.gat2(hidden, edge_index)


In [None]:
# Train/test split: the first n_train nodes contribute to the loss, the rest are held out
train_mask = torch.zeros(n, dtype=torch.bool)
train_mask[:n_train] = True
test_mask = ~train_mask

# Seed torch before building the model so the weight initialisation (and hence
# the printed losses) is reproducible, as in the RelMHANet cells.
torch.manual_seed(42)

# Layer sizes are derived from the graph instead of being hard-coded
model = GATv2Net(
    in_channels=data.x.shape[1],
    hidden_channels=8,
    out_channels=int(data.y.unique().numel()),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop: full-graph forward pass, loss on the training nodes only
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = loss_fn(out[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch} - Loss: {loss.item():.4f}")


In [None]:
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Evaluate model: one gradient-free forward pass over the whole graph
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)

# Keep only the held-out nodes
pred = out[test_mask].argmax(dim=1)
true = data.y[test_mask]

# Convert to NumPy
y_pred, y_true = (t.cpu().numpy() for t in (pred, true))

# Compute metrics (average=None gives one value per class)
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1 = (
    metric(y_true, y_pred, average=None)
    for metric in (precision_score, recall_score, f1_score)
)

# Print results
print(f"Test Accuracy:  {accuracy:.4f}")
print(f"Test Precision: {precision}")
print(f"Test Recall:    {recall}")
print(f"Test F1 Score:  {f1}")


# RelMHANet

In [None]:
import torch
import torch.nn as nn
from tabrel.model import RelationalMultiheadAttention, SampleWithRelations

class RelMHANet(nn.Module):
    """Linear input projection, a stack of relational attention layers, linear classifier.

    Args:
        in_dim: number of input columns.
        embed_dim: width of the projected representation passed through the stack.
        num_heads: forwarded to each ``RelationalMultiheadAttention``.
        num_classes: number of output logits.
        num_layers: how many attention layers to stack.
        dropout: forwarded to each attention layer.
        rel: forwarded to each attention layer.
    """

    def __init__(self, in_dim, embed_dim, num_heads, num_classes, num_layers=2, dropout=0.2, rel=True):
        super().__init__()
        self.input_proj = nn.Linear(in_dim, embed_dim)
        layers = []
        for _ in range(num_layers):
            layers.append(RelationalMultiheadAttention(embed_dim, num_heads, dropout, rel))
        self.attn_layers = nn.ModuleList(layers)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, r, attn_mask=None):
        """Return class logits for every row of ``x``.

        ``r`` is handed to every attention layer together with the current
        representation. When ``attn_mask`` is omitted, an all-zero mask shaped
        like ``r`` is used.
        """
        mask = torch.zeros_like(r) if attn_mask is None else attn_mask

        hidden = self.input_proj(x)
        for layer in self.attn_layers:
            hidden = layer(SampleWithRelations(hidden, r), mask)

        return self.classifier(hidden)


# Three-way split by position: [0, n_backgnd) background, [n_backgnd, n_train) probe,
# [n_train, n) test. This rebinds the notebook-level n_train.
n_train: Final[int] = n // 3 * 2
n_backgnd: Final[int] = n // 3

# Prepare tensors
x_tensor = torch.tensor(x, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
r_tensor = torch.tensor(r_numpy, dtype=torch.float32)

train_mask = torch.zeros(n, dtype=torch.bool)
train_mask[:n_train] = True
test_mask = ~train_mask

# Background: training samples whose labels are shown to the model
backgnd_mask = train_mask.clone()
backgnd_mask[n_backgnd:] = False
# Probe: training samples whose labels are hidden and used for the loss
probe_mask = train_mask.clone()
probe_mask[:n_backgnd] = False

# Model input = features plus one label column; labels outside the background are set to 0.
# NOTE(review): 0 is also a real class label, so hidden labels and class 0 look the same.
xy_train = torch.cat(
    [
        x_tensor,
        y_tensor.masked_fill(~backgnd_mask, 0).unsqueeze(1)
    ], 1)

in_dim = xy_train.shape[1]
embed_dim = 8
num_heads = 2
num_classes = len(np.unique(y))

# Seed before constructing the model so its initial weights are reproducible
torch.manual_seed(42)
model = RelMHANet(in_dim=in_dim, embed_dim=embed_dim, num_heads=num_heads, num_classes=num_classes, rel=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Training loop: forward pass over all n samples with the full relation matrix,
# loss computed on the probe rows only
for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    logits = model(xy_train, r_tensor)
    loss = loss_fn(logits[probe_mask], y_tensor[probe_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch} - Loss: {loss.item():.4f}")

# Evaluation: same input and relations, predictions read from the test rows
model.eval()
with torch.no_grad():
    logits = model(xy_train, r_tensor)
    pred = logits[test_mask].argmax(dim=1)
    true = y_tensor[test_mask]

    y_pred = pred.cpu().numpy()
    y_true = true.cpu().numpy()

    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    accuracy = accuracy_score(y_true, y_pred)
    # average=None returns one value per class
    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)

    print(f"Test Accuracy:  {accuracy:.4f}")
    print(f"Test Precision: {precision}")
    print(f"Test Recall:    {recall}")
    print(f"Test F1 Score:  {f1}")

In [None]:
from tabrel.dataset import QueryUniqueBatchDataset

# Split data: the last n_test samples are held out, the rest feed the batch dataset.
# NOTE: this rebinds the notebook-level n_train, n_backgnd, x_train, y_train,
# r_train, x_test and y_test (now torch tensors).
n_test = n // 9
n_backgnd = 2 * n_test
n_train = n - n_test

x_tensor = torch.tensor(x, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.long)
r_tensor = torch.tensor(r_numpy, dtype=torch.float32)

train_mask = torch.zeros(n, dtype=torch.bool)
train_mask[:n_train] = True
test_mask = ~train_mask

x_train, y_train, r_train = x_tensor[train_mask], y_tensor[train_mask], r_tensor[train_mask][:, train_mask]
x_test, y_test = x_tensor[test_mask], y_tensor[test_mask]

# QueryUniqueBatchDataset setup
query_size = n_test
back_size = n_backgnd
n_batches = 6
random_state = 42

train_dataset = QueryUniqueBatchDataset(
    x=x_train,
    y=y_train,
    r=r_train,
    query_size=query_size,
    backgnd_size=back_size,
    n_batches=n_batches,
    random_state=random_state,
)

# One extra input column carries the (possibly hidden) label
in_dim = x_train.shape[1] + 1
embed_dim = 8
num_heads = 2
num_classes = len(np.unique(y))

# Seed before constructing the model so its initial weights are reproducible
torch.manual_seed(random_state)
model = RelMHANet(in_dim=in_dim, embed_dim=embed_dim, num_heads=num_heads, num_classes=num_classes, rel=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

# Training loop
for epoch in range(500):
    model.train()
    total_loss = 0
    n_seen = 0  # batches actually yielded this epoch
    for xb, yb, xq, yq, r in train_dataset:
        # Construct input: background rows carry their label, query rows carry 0
        yb_input = yb.unsqueeze(1)
        yq_input = torch.zeros_like(yq).unsqueeze(1)
        x_input = torch.cat([torch.cat([xb, yb_input], 1), torch.cat([xq, yq_input], 1)], 0)
        optimizer.zero_grad()
        logits = model(x_input, r)
        # Background rows come first, so the query logits are the tail
        loss = loss_fn(logits[len(yb):], yq)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_seen += 1
    if epoch % 20 == 0:
        # Average over the batches iterated rather than the configured constant
        print(f"Epoch {epoch} - Loss: {total_loss / max(n_seen, 1):.4f}")

# Inference: a fixed set of training samples as background, the test set as query
model.eval()
with torch.no_grad():
    # NOTE(review): this re-derives a permutation with the same seed and slices
    # it after the query blocks; it presumably matches how the dataset picks its
    # background samples — confirm against QueryUniqueBatchDataset.
    g = torch.Generator().manual_seed(random_state)
    perm = torch.randperm(len(x_train), generator=g)
    back_idx = perm[query_size * n_batches : query_size * n_batches + back_size]
    xb = x_train[back_idx]
    yb = y_train[back_idx]
    xq = x_test
    yq = torch.zeros_like(y_test)  # query labels are hidden as 0
    # Build r for inference: only the background x background block is filled
    r_inf = torch.zeros(len(xb) + len(xq), len(xb) + len(xq))
    r_inf[:len(xb), :len(xb)] = r_train[back_idx][:, back_idx]
    # No relations between background and test, nor within test
    x_input = torch.cat([torch.cat([xb, yb.unsqueeze(1)], 1), torch.cat([xq, yq.unsqueeze(1)], 1)], 0)
    logits = model(x_input, r_inf)
    pred = logits[len(xb):].argmax(dim=1)
    true = y_test

    y_pred = pred.cpu().numpy()
    y_true = true.cpu().numpy()

    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    accuracy = accuracy_score(y_true, y_pred)
    # average=None returns one value per class
    precision = precision_score(y_true, y_pred, average=None)
    recall = recall_score(y_true, y_pred, average=None)
    f1 = f1_score(y_true, y_pred, average=None)

    print(f"Test Accuracy:  {accuracy:.4f}")
    print(f"Test Precision: {precision}")
    print(f"Test Recall:    {recall}")
    print(f"Test F1 Score:  {f1}")

# RelMHARegressor

In [None]:
from tabrel.train import train_relnet
from tabrel.benchmark.nw_regr import generate_toy_regr_data

# The sample count is tied to n because the index ranges below are built from
# n, n_train and n_backgnd; a separately hard-coded size could drift out of sync.
# NOTE(review): despite the name, the saved output of the next cell shows
# x_numpy as a torch tensor.
x_numpy, y_numpy, clusters = generate_toy_regr_data(n_samples=n, n_clusters=3, seed=42)
# Two samples are related iff they belong to the same cluster.
# NOTE: this rebinds r_numpy, which earlier cells use for the classification data.
r_numpy = (clusters.unsqueeze(1) == clusters.unsqueeze(0)).float().numpy()

# Positional split: background, query and validation ranges reuse the sizes
# set by the preceding cell.
mse, r2, mae = train_relnet(
    x=x_numpy,
    y=y_numpy,
    r=r_numpy,
    backgnd_indices=np.arange(0, n_backgnd),
    query_indices=np.arange(n_backgnd, n_train),
    val_indices=np.arange(n_train, n),
    lr=0.01,
    n_epochs=200,
)
print(f"Val MSE: {mse:.4f} | R²: {r2:.4f} | MAE: {mae:.4f}")


In [13]:
x_numpy

tensor([[ 0.7645],
        [ 0.8300],
        [-0.2343],
        [ 0.9186],
        [-0.2191],
        [ 0.2018],
        [-0.4869],
        [ 0.5873],
        [ 0.8815],
        [-0.7336],
        [ 0.8692],
        [ 0.1872],
        [ 0.7388],
        [ 0.1354],
        [ 0.4822],
        [-0.1412],
        [ 0.7709],
        [ 0.1478],
        [-0.4668],
        [ 0.2549],
        [-0.4607],
        [-0.1173],
        [-0.4062],
        [ 0.6634],
        [-0.7894],
        [-0.4610],
        [-0.2824],
        [-0.6013],
        [ 0.0944],
        [-0.9877],
        [ 0.9031],
        [-0.8495],
        [ 0.7720],
        [ 0.1664],
        [-0.3247],
        [ 0.6179],
        [ 0.1559],
        [ 0.8080],
        [ 0.1093],
        [-0.3154],
        [ 0.2687],
        [-0.2712],
        [ 0.4209],
        [ 0.8928],
        [ 0.5781],
        [-0.4372],
        [ 0.5773],
        [ 0.1789],
        [ 0.5078],
        [-0.6095],
        [-0.9899],
        [-0.3864],
        [-0.