In [None]:
# %cd ../src

In [None]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2
from datetime import datetime
from typing import Final

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import torch

from tabrel.benchmark.nw_regr import make_r, make_random_r, generate_toy_regr_data
from tabrel.train import train_relnet

n_samples: Final[int] = 300
n_clusters: Final[int] = 3
seed: Final[int] = 42
x, y, c = generate_toy_regr_data(n_samples=n_samples, n_clusters=n_clusters, seed=seed, distr="uniform", y_func="square")

# Visual sanity check of the generated toy data, colored by cluster label `c`.
_, ax = plt.subplots()
ax.scatter(x, y, c=c)
ax.set(xlabel="x", ylabel="y", title="Toy regression data (colored by cluster)")

# Relation matrices derived from the cluster labels: one from make_r, one from
# make_random_r (seeded). Their exact construction lives in tabrel.benchmark.nw_regr.
c_np = c.numpy()
r_deterministic = make_r(c_np)
r_random = make_random_r(seed=seed, clusters=c_np)

In [None]:
logdir: Final[str] = "tb_logs"  # TensorBoard log directory passed to train_relnet as tb_logdir

In [None]:
# %tensorboard --logdir {logdir}
# Contiguous three-way split of the sample indices:
#   [0, n_back)           -> background set
#   [n_back, n_train)     -> query set
#   [n_train, n_samples)  -> validation set
# NOTE(review): a contiguous split is only unbiased if generate_toy_regr_data returns
# samples in random (not cluster-sorted) order — confirm.
n_back: Final[int] = n_samples // 3
n_query: Final[int] = n_back
n_train: Final[int] = n_back + n_query
n_val: Final[int] = n_samples - n_train

back_inds: Final[np.ndarray] = np.arange(0, n_back)
query_inds: Final[np.ndarray] = np.arange(n_back, n_train)
val_inds: Final[np.ndarray] = np.arange(n_train, n_samples)

x_np = x.numpy()
y_np = y.numpy()

def train_relnet_shorthand(r: np.ndarray, seed: int, plot: bool,
                           lr: float = 1e-2, lr_decay: float = 0.9,
                           n_epochs: int = 2000, lr_decay_step: int = 100) -> float:
    """Run ``train_relnet`` on this notebook's fixed split and return the validation MSE.

    Uses the module-level ``x_np``/``y_np`` data, the ``back_inds``/``query_inds``/
    ``val_inds`` index arrays and the TensorBoard ``logdir`` defined above.

    Args:
        r: Relation matrix forwarded to ``train_relnet`` (e.g. ``r_deterministic``
            or ``r_random``).
        seed: Seed passed to ``torch.random.manual_seed`` before training.
        plot: If True, draw a new figure with true (dots) and predicted (stars)
            validation targets, colored by cluster label.
        lr: Learning rate forwarded to ``train_relnet``.
        lr_decay: Forwarded as ``lr_decay`` (presumably a multiplicative factor
            applied every ``lr_decay_step`` epochs — confirm in ``tabrel.train``).
        n_epochs: Number of training epochs forwarded to ``train_relnet``.
        lr_decay_step: Forwarded as ``lr_decay_step``.

    Returns:
        The validation MSE as returned by ``train_relnet``.
    """
    # Only torch's global RNG is seeded here.
    # NOTE(review): if train_relnet also draws from numpy's RNG, seed that too.
    torch.random.manual_seed(seed)
    mse, _r2, _, y_val_pred, y_val_true = train_relnet(
        x_np,
        y_np,
        r=r,
        backgnd_indices=back_inds,
        query_indices=query_inds,
        val_indices=val_inds,
        lr=lr,
        n_epochs=n_epochs,
        progress_bar=False,
        print_loss=False,
        lr_decay=lr_decay,
        lr_decay_step=lr_decay_step,
        tb_logdir=logdir,
    )
    if plot:
        x_val = x_np[val_inds]
        colors = c_np[val_inds]
        # Fresh figure so the plot never lands on some unrelated "current" axes.
        _, ax = plt.subplots()
        ax.scatter(x_val, y_val_true, c=colors, label="true")
        ax.scatter(x_val, y_val_pred, c=colors, marker="*", label="predicted")
        # float() accepts Python floats, numpy scalars and single-element tensors alike.
        ax.set(xlabel="x", ylabel="y", title=f"Validation predictions (MSE = {float(mse):.3f})")
        ax.legend()
    return mse

def objective(trial: optuna.Trial) -> float:
    """Optuna objective: validation MSE on ``r_deterministic`` for sampled hyperparameters.

    Search space:
        lr_decay: uniform in [0.85, 0.95].
        lr: log-uniform in [1e-4, 1e-1]. Learning rates span orders of magnitude;
            a linear-uniform draw would put ~99% of trials above 1e-3.
    """
    lr_decay = trial.suggest_float("lr_decay", .85, .95)
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    return train_relnet_shorthand(r_deterministic, seed=seed, plot=False, lr=lr, lr_decay=lr_decay)

In [None]:
# Baseline run on r_deterministic with the default hyperparameters.
train_relnet_shorthand(r_deterministic, seed=seed, plot=True)

In [None]:
# Each run creates a new timestamped study in the local SQLite DB. The TPE sampler is
# seeded so the sequence of suggested hyperparameters is reproducible across runs
# (given that the objective itself is deterministic for a fixed seed).
study = optuna.create_study(
    direction="minimize",
    study_name=f"r_deterministic_{datetime.now()}",
    storage="sqlite:///db.sqlite3",
    sampler=optuna.samplers.TPESampler(seed=seed),
)
study.optimize(objective, n_trials=100)

In [None]:
best_trial = study.best_trial  # best run of the search: its params, objective value and trial number
best_trial

In [None]:
# Retrain on r_deterministic with the best hyperparameters found by the search and plot the fit.
best_params = study.best_trial.params
train_relnet_shorthand(
    r_deterministic,
    seed=seed,
    plot=True,
    lr=best_params["lr"],
    lr_decay=best_params["lr_decay"],
)

In [None]:
# Same baseline as above, but with the relation matrix built by make_random_r, for comparison.
train_relnet_shorthand(r_random, seed=seed, plot=True)