In [1]:
import mw
import torch
import torch.optim as optim
from sklearn import datasets
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import torch.nn.functional as F
import wandb

In [2]:
# Load scikit-learn's California housing regression dataset (fetched and
# cached locally by scikit-learn on first use). `X` holds the feature matrix,
# `y` the regression target for each row.
data = datasets.fetch_california_housing()
X = data.data
y = data.target

# Data preprocessing

In [3]:
# Min-max normalize features and target to [0, 1] using statistics from the
# TRAINING rows only, so held-out test rows do not influence preprocessing
# (test values may therefore fall slightly outside [0, 1]).
# `train_test_split` shuffles row indices based only on the number of rows,
# `test_size` and `random_state`, so `train_rows` are exactly the rows the
# split cell assigns to X_train/y_train. Keep these two arguments in sync
# with that cell.
train_rows, _ = train_test_split(list(range(len(X))), test_size=0.3, random_state=42)
X_max, X_min = X[train_rows].max(0), X[train_rows].min(0)
y_max, y_min = y[train_rows].max(0), y[train_rows].min(0)
X = (X - X_min) / (X_max - X_min)
y = (y - y_min) / (y_max - y_min)

In [4]:
# Hold out 30% of the rows for evaluation; the fixed random_state makes the split reproducible.
(X_train, X_test,
 y_train, y_test) = train_test_split(X, y, random_state=42, test_size=0.3)

# Initialization (datasets, dataloaders, models)

In [5]:
class SklearnDataset(Dataset):
    """Map-style dataset over in-memory feature and target arrays.

    Both inputs are converted with ``torch.Tensor`` (float32). Features get one
    trailing singleton axis, so a row of k features has shape (k, 1). Targets
    get two trailing singleton axes, so a scalar target has shape (1, 1).
    """

    def __init__(self, X, y):
        # `[..., None]` appends a length-1 axis (same as `.unsqueeze(-1)`).
        self.data = torch.Tensor(X)[..., None]
        self.label = torch.Tensor(y)[..., None, None]

    def __len__(self):
        # Number of samples = size of the leading axis.
        return len(self.data)

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.label[idx]
        return sample, target

        
# Wrap each split in a Dataset so it can be batched by a DataLoader.
train_dataset, test_dataset = (
    SklearnDataset(X_train, y_train),
    SklearnDataset(X_test, y_test),
)

In [6]:
BATCH_SIZE = 24

# Shuffle only the training data. Evaluation does not need shuffling, and a
# fixed batch order keeps batch-averaged test metrics comparable across
# epochs (reshuffling would change every test batch's composition each epoch).
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [20]:
INPUT_SIZE = 8
HIDDEN_SIZE = 8
N_HIDDEN_LAYERS = 2
OUTPUT_SIZE = 1
DIMS = 6
SEED = 42

# Seed torch's global RNG so weight initialization and DataLoader shuffling
# are reproducible across Restart & Run All.
torch.manual_seed(SEED)

model_base = mw.models.Regression(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    n_hidden_layers=N_HIDDEN_LAYERS,
    output_size=OUTPUT_SIZE,
)
# NOTE(review): hidden_size is set to the baseline's total hidden width
# (HIDDEN_SIZE * N_HIDDEN_LAYERS), presumably to keep the two models of
# comparable capacity. Confirm against mw.models.ManifoldWorms.
model_mw = mw.models.ManifoldWorms(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE * N_HIDDEN_LAYERS,
    output_size=OUTPUT_SIZE,
    d=DIMS,
)

optim_base = optim.AdamW(model_base.parameters(), lr=1e-4, weight_decay=1e-5)
optim_mw = optim.AdamW(model_mw.parameters(), lr=1e-4, weight_decay=1e-5)
# Call model_mw.post_step after every optim_mw.step(). Keep the handle so the
# hook can be removed later (and so the cell doesn't echo its repr).
mw_post_step_handle = optim_mw.register_step_post_hook(model_mw.post_step)

<torch.utils.hooks.RemovableHandle at 0x1a289246240>

In [19]:
# Log per-epoch metrics to Weights & Biases. Set to False to run the training
# cell without initializing a wandb run.
USE_WANDB: bool = True

# Training and evaluation

In [None]:
# Train ManifoldWorms and the baseline side by side on identical batches,
# then log per-epoch metrics: mean batch RMSE-based loss, R² over the whole
# split, and gradient mean/std per parameter.
N_EPOCHS = 100
MW_MAX_STEPS = 100  # upper bound on argument-less model_mw() refinement calls per batch
MW_TOL = 1e-4       # stop refining once an increment's norm drops below this


def _mw_predict(model, X_batch, max_steps=MW_MAX_STEPS, tol=MW_TOL):
    """Return ManifoldWorms' accumulated prediction for one batch.

    ``model(X_batch)`` gives the initial prediction. Each argument-less call
    ``model()`` returns an increment that is added to it. Refinement stops
    after ``max_steps`` calls, or as soon as an increment's norm falls below
    ``tol``; that final small increment is still added. Training and
    evaluation share this helper so both use the same stopping rule.
    """
    model.clear_state()
    y_pred = model(X_batch)
    for _ in range(max_steps):
        increment = model()
        y_pred = y_pred + increment
        if increment.norm() < tol:
            break
    return y_pred


def _mw_loss(model, y_pred, y_true):
    """RMSE plus the summed absolute batch-mean of ``model.state`` (the "garbage" penalty)."""
    rmse_loss = F.mse_loss(y_pred, y_true).sqrt()
    garbage_loss = model.state.mean(0).abs().sum()
    return rmse_loss + garbage_loss


def _record(logs, prefix, loss, y_true, y_pred):
    """Append one batch's loss and its flattened targets/predictions under ``prefix``."""
    logs[f"{prefix}_loss"].append(loss.item())
    logs[f"{prefix}_y_true"].extend(y_true.detach().flatten().tolist())
    logs[f"{prefix}_y_pred"].extend(y_pred.detach().flatten().tolist())


def _epoch_scalars(logs):
    """Mean batch loss and whole-split R² for every prefix recorded in ``logs``.

    R² is computed once over all of the epoch's targets/predictions rather
    than averaged over 24-sample batches. Per-batch R² is biased, and it blows
    up on batches whose targets happen to have little variance.
    """
    scalars = {}
    for key, batch_losses in list(logs.items()):
        if not key.endswith("_loss"):
            continue
        prefix = key[: -len("_loss")]
        scalars[f"{prefix}_loss"] = sum(batch_losses) / len(batch_losses)
        scalars[f"{prefix}_r2"] = r2_score(logs[f"{prefix}_y_true"], logs[f"{prefix}_y_pred"])
    return scalars


def _grad_scalars(named_models):
    """Mean/std of each parameter's current gradient, keyed by short model name."""
    scalars = {}
    for model_name, net in named_models:
        for param_name, param in net.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad
            scalars[f"{model_name}_grad_{param_name}_mean"] = grad.mean().item()
            # The (unbiased) std of a single element is NaN; log 0.0 instead.
            scalars[f"{model_name}_grad_{param_name}_std"] = (
                grad.std().item() if grad.numel() > 1 else 0.0
            )
    return scalars


if USE_WANDB:
    run = wandb.init(project="manifold_worms")

for epoch in range(N_EPOCHS):
    # Fresh accumulators every epoch. Batches are bound to X_batch/y_batch so
    # the loop does not overwrite the notebook-level X / y arrays.
    logs = defaultdict(list)

    model_mw.train()
    model_base.train()
    for X_batch, y_batch in train_dataloader:
        # mw training
        y_pred_mw = _mw_predict(model_mw, X_batch)
        loss_mw = _mw_loss(model_mw, y_pred_mw, y_batch)
        _record(logs, "mw_train", loss_mw, y_batch, y_pred_mw)
        optim_mw.zero_grad()
        loss_mw.backward()
        model_mw.normalize_grads()
        optim_mw.step()

        # baseline training: the baseline takes the features and target
        # without their trailing singleton axis.
        y_pred_base = model_base(X_batch[..., 0])
        loss_base = F.mse_loss(y_pred_base, y_batch[..., 0]).sqrt()
        _record(logs, "base_train", loss_base, y_batch, y_pred_base)
        optim_base.zero_grad()
        loss_base.backward()
        optim_base.step()

    model_mw.eval()
    model_base.eval()
    # No backward pass during evaluation, so skip building autograd graphs.
    # NOTE(review): assumes ManifoldWorms' forward does not itself rely on
    # autograd. Confirm if evaluation starts failing.
    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            # mw eval
            y_pred_mw = _mw_predict(model_mw, X_batch)
            loss_mw = _mw_loss(model_mw, y_pred_mw, y_batch)
            _record(logs, "mw_test", loss_mw, y_batch, y_pred_mw)

            # baseline eval
            y_pred_base = model_base(X_batch[..., 0])
            loss_base = F.mse_loss(y_pred_base, y_batch[..., 0]).sqrt()
            _record(logs, "base_test", loss_base, y_batch, y_pred_base)

    if USE_WANDB:
        scalars = _epoch_scalars(logs)
        # Gradients are those left by the last training batch of this epoch.
        scalars.update(_grad_scalars([("mw", model_mw), ("base", model_base)]))
        run.log(scalars)

if USE_WANDB:
    run.finish()

[34m[1mwandb[0m: Currently logged in as: [33mrubn[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
