In [None]:
%load_ext autoreload
%autoreload 2
%load_ext tensorboard

import numpy as np
import torch
import torch.nn as nn
from sklearn.datasets import fetch_california_housing

class Mlp(nn.Module):
    """
    Small fully connected network with two hidden layers.

    Layout: ``in_dim -> hidden_dim -> hidden_dim -> out_dim``. Each hidden
    linear layer is followed by ReLU and dropout; the output layer is a plain
    linear projection without activation.

    Args:
        in_dim: size of the last dimension of the input.
        hidden_dim: width of both hidden layers.
        out_dim: size of the last dimension of the output.
        dropout: dropout probability applied after each hidden activation.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int = 64,
        out_dim: int = 8,
        dropout: float = .1,
    ) -> None:
        super().__init__()
        n_hidden_layers = 2
        layers = []
        prev_dim = in_dim
        for _ in range(n_hidden_layers):
            # One hidden stage: affine map, non-linearity, regularisation.
            layers += [nn.Linear(prev_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the result."""
        return self.net(x)

def to_torch(arr: np.ndarray) -> torch.Tensor:
    """Return a new ``float32`` tensor holding a copy of the values in ``arr``."""
    as_float32 = torch.tensor(arr, dtype=torch.float32)
    return as_float32


def test_mlp() -> None:
    """
    Smoke-test ``Mlp`` on the California housing feature matrix.

    Loads the dataset through ``fetch_california_housing``, runs a single
    forward pass over all rows, prints the output shape and asserts that it is
    ``(n_rows, out_features)`` where ``out_features`` is read from the last
    layer of the network.

    Raises:
        AssertionError: if the output shape differs from the expected one.
    """
    housing = fetch_california_housing()
    x = to_torch(housing.data)
    mlp = Mlp(in_dim=x.shape[1])
    # Only the shape is inspected, so no autograd graph is needed.
    with torch.no_grad():
        out = mlp(x)
    print(out.shape)
    expected_shape = (x.shape[0], mlp.net[-1].out_features)
    assert tuple(out.shape) == expected_shape, (
        f"unexpected output shape {tuple(out.shape)}, expected {expected_shape}"
    )


test_mlp()

In [None]:
class RelNwMlpRegr(nn.Module):
    """
    Relationship-aware Nadaraya-Watson kernel regression on MLP embeddings.

    Background and query features are embedded by one shared ``Mlp``. The
    prediction for query ``i`` is a weighted average of the background targets
    with weights

        w_ij = softmax_j( -||f(x_query_i) - f(x_backgnd_j)||_2 / sigma + r_scale * r_ij )

    where ``f`` is the embedding MLP and ``sigma`` / ``r_scale`` are learnable
    scalars.
    """

    def __init__(self, init_sigma: float, init_r_scale: float, input_dim: int, dropout: float) -> None:
        """
        Args:
            init_sigma: initial value of the learnable bandwidth ``sigma``.
            init_r_scale: initial value of the learnable multiplier applied to ``r``.
            input_dim: number of input features passed to the embedding MLP.
            dropout: dropout probability used inside the embedding MLP.
        """
        super().__init__()
        self.sigma = nn.Parameter(torch.tensor([float(init_sigma)]))
        self.r_scale = nn.Parameter(torch.tensor([float(init_r_scale)]))
        # The attribute name (spelling included) is kept unchanged because it is
        # part of the module's parameter names and therefore of its state_dict keys.
        self.x_tranform = Mlp(in_dim=input_dim, dropout=dropout)

    def forward(
        self,
        x_backgnd: torch.Tensor,  # (n_backgnd, n_features)
        y_backgnd: torch.Tensor,  # (n_backgnd,)
        x_query: torch.Tensor,  # (n_query, n_features)
        r: torch.Tensor,  # (n_query, n_backgnd)
    ) -> torch.Tensor:
        """
        Predict targets for the query samples.

        Args:
            x_backgnd: background features, (n_backgnd, n_features).
            y_backgnd: background targets, (n_backgnd,).
            x_query: query features, (n_query, n_features).
            r: relationship values between query and background samples,
                (n_query, n_backgnd); added to the kernel logits after scaling
                by ``r_scale``.

        Returns:
            Predicted y: (n_query,)
        """
        # Embed background first, then query (same call order as before, which
        # matters for dropout randomness in training mode).
        emb_backgnd = self.x_tranform(x_backgnd)
        emb_query = self.x_tranform(x_query)

        # Pairwise differences via broadcasting: (n_query, n_backgnd, emb_dim).
        diffs = emb_query.unsqueeze(1) - emb_backgnd.unsqueeze(0)
        dists = torch.linalg.vector_norm(diffs, dim=2)  # (n_query, n_backgnd)

        # Kernel logits: (n_query, n_backgnd)
        logits = -dists / self.sigma + self.r_scale * r

        # Row-wise softmax equals exp(logits) / sum(exp(logits)) but is computed
        # stably: a direct exp() overflows to inf (-> NaN weights) for large
        # logits and underflows to all-zero weights for very negative ones.
        # Rows always sum to one, so no epsilon is needed in the denominator.
        weights = torch.softmax(logits, dim=1)

        # Weighted sum of y_backgnd: (n_query,)
        return torch.matmul(weights, y_backgnd)


In [None]:
from typing import Final

import numpy as np

from tabrel.benchmark.nw_regr import generate_multidim_noisy_data, make_random_r, train_nw_arbitrary, NwModelConfig

# Synthetic data configuration; these constants are reused by later cells.
n_samples: Final[int] = 1000
seed: Final[int] = 42
x_dim: Final[int] = 8
# Third return value `c` is presumably the per-sample cluster assignment
# (n_clusters=3) -- TODO confirm against generate_multidim_noisy_data.
x_np, y_np, c = generate_multidim_noisy_data(n_samples, n_clusters=3, x_dim=x_dim, seed=seed)

# Contiguous index split by position:
#   [0, n_back)                 -> background samples
#   [n_back, n_back + n_query)  -> query samples
#   [n_train, n_samples)        -> validation samples
# Background and query together form the training set (train_ids).
n_query = n_back = n_samples // 3
n_train = n_query + n_back
n_val = n_samples - n_train  # receives the remainder of the integer division
back_ids = np.arange(n_back)
query_ids = np.arange(n_query) + n_back
train_ids = np.arange(n_train)
val_ids = np.arange(n_val) + n_train

# `r` is indexed below as a 2-D array by (row sample ids, column sample ids);
# presumably a pairwise relationship matrix derived from `c` -- TODO confirm.
r = make_random_r(seed, c)

# Number of training epochs; also reused by the manual training loop in a later cell.
n_epochs: Final[int] = 5000
# Train the project's NW model on the split defined above. Training uses
# query-vs-background relationships; for validation, the relationship block
# passed is validation rows vs. all training rows (background + query).
# Only the first two of the four returned values are kept; they are bound to
# `mse` and `r2` (presumably validation metrics -- TODO confirm against
# train_nw_arbitrary).
mse, r2, _, _ = train_nw_arbitrary(
    x_backgnd=x_np[back_ids],
    y_backgnd=y_np[back_ids],
    x_query=x_np[query_ids],
    y_query=y_np[query_ids],
    x_val=x_np[val_ids],
    y_val=y_np[val_ids],
    r_query_backgnd=r[query_ids][:, back_ids],
    r_val_nonval=r[val_ids][:, train_ids],
    cfg=NwModelConfig(input_dim=x_dim, trainable_weights_matrix=True,),
    lr=1e-3,
    n_epochs=n_epochs
)

print(mse, r2)

In [None]:
from datetime import datetime
# TensorBoard log directory: the "tb_logs" prefix immediately followed by the
# current local time in ISO format, with every ":" replaced by "_"
# (presumably to keep the directory name portable across filesystems).
log_dir: Final[str] = "tb_logs" + datetime.now().isoformat().replace(":", "_")
print(log_dir)

In [None]:
%tensorboard --logdir {log_dir}
from sklearn.metrics import mean_squared_error, r2_score
from tqdm import tqdm
from torch.utils.tensorboard.writer import SummaryWriter



# --- Data preparation --------------------------------------------------------
x_torch, y_torch = to_torch(x_np), to_torch(y_np)
x_train, y_train = x_torch[train_ids], y_torch[train_ids]
# Standardise features using statistics of the training rows only
# (background + query); validation rows do not contribute to mean/std.
x_mean, x_std = x_train.mean(dim=0), x_train.std(dim=0)
x_norm = (x_torch - x_mean) / (x_std + 1e-8)  # epsilon guards against zero std

x_b, y_b = x_norm[back_ids], y_torch[back_ids]
x_q, y_q = x_norm[query_ids], y_torch[query_ids]
x_v, y_v = x_norm[val_ids], y_torch[val_ids]
# Normalised training features: background set used when predicting validation rows.
x_train_norm = x_norm[train_ids]

# Convert the relationship matrix once and slice both blocks from it.
r_torch = to_torch(r)
r_q_b = r_torch[query_ids][:, back_ids]  # query rows vs. background columns
r_val_nval = r_torch[val_ids][:, train_ids]  # validation rows vs. training columns

# Validation targets never change, so convert them to numpy once.
y_v_np = y_v.numpy()

# --- Model and optimisation setup --------------------------------------------
torch.random.manual_seed(seed)
model = RelNwMlpRegr(
    init_sigma=1.,
    init_r_scale=1.,
    input_dim=x_dim,
    dropout=.2,
)
optimizer = torch.optim.AdamW(model.parameters(), weight_decay=1e-4)
loss_fn = torch.nn.MSELoss()
writer = SummaryWriter(log_dir=log_dir)

# --- Training loop -----------------------------------------------------------
# Full-batch training: each epoch predicts all query targets from the
# background set. Every 20 epochs the validation rows are predicted using the
# whole training set as background, and MSE / R2 are logged to TensorBoard.
model.train()
try:
    for epoch in tqdm(range(n_epochs)):
        optimizer.zero_grad()
        y_pred = model(x_b, y_b, x_q, r_q_b)
        loss = loss_fn(y_pred, y_q)
        loss.backward()
        optimizer.step()
        writer.add_scalar("train/loss", loss.item(), epoch)

        if epoch % 20 == 0:
            model.eval()  # disable dropout for evaluation

            with torch.no_grad():
                y_v_pred = model(
                    x_train_norm,
                    y_train,
                    x_v,
                    r_val_nval,
                )
                y_v_pred_np = y_v_pred.numpy()

            writer.add_scalar("val/mse", mean_squared_error(y_v_np, y_v_pred_np), epoch)
            writer.add_scalar("val/r2", r2_score(y_v_np, y_v_pred_np), epoch)

            model.train()
finally:
    # Flush pending events and release the event file, even if the loop is
    # interrupted (e.g. the cell is stopped manually).
    writer.close()


In [None]:
# TODO: tune dropout, weight decay and hidden layer size with Optuna; try other activations?
# TODO: transform y (with another MLP?)
# TODO: try another synthetic dataset?