# Imports and algorithm

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from typing import *
import wandb
from mw import log
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.model_selection import train_test_split

# Show tensors and arrays in fixed-point rather than scientific notation.
np.set_printoptions(suppress=True)
torch.set_printoptions(sci_mode=False)

In [27]:
class ManifoldWorms(nn.Module):
    """Two sets of unit-norm position vectors whose pairwise dot products act as weights.

    Each input owns an "input tail" row and each output owns an "exit head" row in an
    ``env_dim``-dimensional space. Rows are re-projected onto the unit sphere (row-wise
    L2 normalisation) at construction and at the start of every forward pass, so an
    optimizer step may move them off the sphere and the next call projects them back.

    Args:
        input_size: Number of inputs (rows of ``positions["input_tails"]``).
        output_size: Number of outputs (rows of ``positions["exit_heads"]``).
        env_dim: Dimensionality of the space the positions live in.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        env_dim: int,
    ):
        super().__init__()
        # nn.Parameter already defaults to requires_grad=True.
        self.positions = nn.ParameterDict(
            {
                "input_tails": nn.Parameter(torch.randn(input_size, env_dim)),
                "exit_heads": nn.Parameter(torch.randn(output_size, env_dim)),
            }
        )
        self.normalize_positions()

    def forward(self) -> torch.Tensor:
        """Return the ``(output_size, input_size)`` similarity matrix.

        Entry ``[o, i]`` is the dot product of exit head ``o`` with input tail ``i``.
        Both rows are unit vectors after the projection, so entries lie in [-1, 1]
        up to float rounding.
        """
        self.normalize_positions()
        return self.positions["exit_heads"] @ self.positions["input_tails"].T

    @torch.no_grad()
    def normalize_positions(self) -> None:
        """Project every position row back onto the unit sphere, in place.

        Runs under ``torch.no_grad()``. Autograd therefore does not record the
        projection, and no throw-away graph is built for it. Gradients reach the
        positions only through the dot products in ``forward``.
        """
        for param in self.positions.values():
            param.copy_(F.normalize(param, p=2, dim=1))


# Datasets

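### Base

Sanity-check dataset: random features scaled into [-1, 1], with feature 0 fixed to 1 (acting like a bias input) and a constant label of 1.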

In [3]:
class BaseDataset(Dataset):
    """Sanity-check dataset: random features with a constant first feature and constant label.

    Each sample is an ``(n_features, 1)`` tensor. Entries are drawn from a standard normal
    and divided by the global absolute maximum, so they lie in [-1, 1]. Feature 0 is then
    overwritten with 1. Each label is a ``(1, 1)`` tensor of ones.

    Args:
        n_features: Features per sample, including the constant feature 0 (must be >= 1).
        n_samples: Number of samples; defaults to 256, the previously hard-coded size.
    """

    def __init__(self, n_features: int = 4, n_samples: int = 256):
        features = torch.randn(n_samples, n_features, 1)
        # max() of an empty tensor raises, so only rescale when there is data.
        if features.numel():
            features /= features.abs().max()
        features[:, 0] = 1
        self.data = features
        self.label = torch.ones(n_samples, 1, 1)

    def __len__(self):
        # Derived from the data so it can never disagree with n_samples.
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

In [4]:
# Two independently drawn sanity-check datasets (sampled from torch's global RNG).
train_dataset, test_dataset = BaseDataset(), BaseDataset()

### Sklearn

In [16]:
class SklearnDataset(Dataset):
    """Wrap an ``(X, y)`` array pair as column-vector samples with max-abs scaling.

    Features are stored as ``(..., 1)`` column vectors and targets as ``(1, 1)`` tensors
    per sample. Both are divided by a scale. By default (unchanged behaviour) the scale
    is the global absolute maximum of the given array. Pass ``data_scale`` and
    ``label_scale`` to override it. Building a test set with the training set's
    ``data_scale`` / ``label_scale`` puts both splits on the same scale, instead of each
    being normalised by its own maximum.

    Args:
        X: Feature array; presumably ``(n_samples, n_features)``.
        y: Target array; presumably ``(n_samples,)``.
        data_scale: Optional divisor for the features. It may be a scalar, or anything
            broadcastable against the ``(n_samples, n_features, 1)`` feature tensor.
        label_scale: Optional divisor for the targets.

    Attributes:
        data_scale, label_scale: The divisors actually applied, as tensors. Zero
            scales are replaced by 1 to avoid NaNs.
    """

    def __init__(self, X, y, data_scale=None, label_scale=None):
        # Same dtype torch.Tensor(...) would have produced.
        dtype = torch.get_default_dtype()
        data = torch.as_tensor(X, dtype=dtype).unsqueeze(-1)
        label = torch.as_tensor(y, dtype=dtype).unsqueeze(-1).unsqueeze(-1)
        self.data_scale = self._resolve_scale(data, data_scale)
        self.label_scale = self._resolve_scale(label, label_scale)
        # Out-of-place division: never mutates caller-owned tensors.
        self.data = data / self.data_scale
        self.label = label / self.label_scale

    @staticmethod
    def _resolve_scale(values, scale):
        """Return ``scale`` as a tensor, defaulting to ``max |values|``; zeros become 1."""
        if scale is None:
            if values.numel():
                scale = values.abs().max()
            else:
                scale = torch.ones((), dtype=values.dtype)
        scale = torch.as_tensor(scale, dtype=values.dtype)
        # An all-zero array would otherwise divide to NaN.
        return torch.where(scale == 0, torch.ones_like(scale), scale)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

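##### Synthetic

Run either this cell or the California Housing cells below — both assign `X` and `y`, so whichever runs last wins.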

In [17]:
# Synthetic linear-regression problem: 1,000 samples, 12 features, Gaussian output noise with std 10.
X, y = datasets.make_regression(
    n_samples=1_000,
    n_features=12,
    noise=10,
    random_state=42,
)

##### California Housing
###### Note: requires a deep neural network

In [34]:
# Downloaded on first use, then loaded from scikit-learn's local data cache.
data = datasets.fetch_california_housing()
X = data.data    # feature matrix
y = data.target  # regression target

In [35]:
# 70/30 train/test split with a fixed seed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# NOTE(review): each SklearnDataset scales by its own max-abs by default, so the two
# splits may end up on slightly different scales.
train_dataset, test_dataset = (
    SklearnDataset(X_train, y_train),
    SklearnDataset(X_test, y_test),
)

# Training runs

In [43]:
# batch_size must stay 1: the training loop keeps only element 0 of each batch.
# Only the training data is shuffled; evaluation order has no effect on the averaged
# test loss, and a fixed order makes evaluation deterministic.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [44]:
# Hyperparameters for the training run below.
SEED = 42            # seeds torch so model initialisation (and loader shuffling) is reproducible
EPOCHS = 100
USE_WANDB = True     # log to Weights & Biases; otherwise print the losses each epoch
n_features = train_dataloader.dataset[0][0].shape[0]  # features per sample, read from the first training item
env_dims = 16        # dimensionality of the space the worm positions live in
l1_scale = 0.2       # weight of the L1 penalty on the model's weight matrix in the loss
weight_radius = 0.2  # TODO(review): not referenced in the visible notebook — confirm use or remove

torch.manual_seed(SEED)
model = ManifoldWorms(n_features, 1, env_dims)  # single regression output
optimizer = optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [45]:
def _sample_loss(
    model: nn.Module,
    features: torch.Tensor,
    target: torch.Tensor,
    l1_scale: float,
) -> torch.Tensor:
    """Loss for one sample.

    The loss is the MSE of ``model() @ features`` against ``target``, plus ``l1_scale``
    times the sum of absolute values of the weight matrix returned by ``model()``.
    """
    weights = model()
    prediction = weights @ features
    return F.mse_loss(prediction, target) + l1_scale * weights.abs().sum()


def _mean(values: Sequence[float]) -> float:
    """Arithmetic mean, or NaN for an empty sequence (e.g. an empty dataloader)."""
    return sum(values) / len(values) if values else float("nan")


if USE_WANDB:
    run = wandb.init(project="manifold_worms")

logs = defaultdict(list)
for epoch in range(EPOCHS):
    # --- train: one SGD step per sample -------------------------------------------
    # Loop variables are deliberately not named X / y. Those globals hold the raw
    # dataset, and overwriting them would break a later re-run of the split cell.
    model.train()
    logs["train"].clear()
    for x_batch, y_batch in train_dataloader:
        # [0] drops the size-1 batch dimension. Inputs need no gradient; only the
        # model's positions are optimised.
        loss = _sample_loss(model, x_batch[0], y_batch[0], l1_scale)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        logs["train"].append(loss.item())

    # --- evaluate: same objective, without building autograd graphs ---------------
    model.eval()
    logs["test"].clear()
    with torch.no_grad():
        for x_batch, y_batch in test_dataloader:
            loss = _sample_loss(model, x_batch[0], y_batch[0], l1_scale)
            logs["test"].append(loss.item())

    train_loss = _mean(logs["train"])
    test_loss = _mean(logs["test"])
    if USE_WANDB:
        run.log({"train_loss": train_loss, "test_loss": test_loss})
    else:
        print("train_loss", train_loss, "\ntest loss", test_loss)

    # Snapshot the positions each epoch for the animation logged after training.
    # NOTE(review): the transpose below assumes `log.visualize` returns an
    # (height, width, channels) image array — confirm in `mw.log`.
    logs["state"].append(
        log.visualize(
            model.positions["input_tails"].data,
            model.positions["exit_heads"].data,
        )
    )

if USE_WANDB:
    # Stack the per-epoch frames, then move the channel axis to position 1
    # (time, channels, height, width) as wandb.Video expects.
    frames = np.stack(logs["state"]).transpose(0, 3, 1, 2)
    run.log({"video": wandb.Video(frames, fps=15, format="gif")})
    run.finish()

0,1
test_loss,▃▄▇▅▇▄▅▄▆▆▄▂▄▅▂▄▄▄▅▃▂▄▃▃█▄▃▆▄▂▅▃▃▄▃▁▆▆▇▂
train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_loss,0.22344
train_loss,0.2249
