# Imports and algorithm

In [63]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from typing import *
import wandb
from mw import log
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.model_selection import train_test_split

torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1fd61e7aa50>

In [65]:
class ManifoldWorms(nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        env_dim: int,
    ):
        super(ManifoldWorms, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.positions = nn.ParameterDict(
            {
                "tails": nn.Parameter(
                    torch.randn(input_size + hidden_size, env_dim), requires_grad=True
                ),
                "heads": nn.Parameter(
                    torch.randn(output_size + hidden_size, env_dim), requires_grad=True
                ),
            }
        )
        with torch.no_grad():
            self.normalize_positions()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_data = torch.zeros(self.input_size + self.hidden_size, 1)
        in_data[: x.shape[0]] = x

        out_data = torch.zeros(self.output_size, 1)

        self.normalize_positions()
        similarities = self.positions["heads"] @ self.positions["tails"].T

        for _ in range(20):
            x = similarities @ in_data
            out_data = out_data + x[: self.output_size]
            in_data[: self.output_size] = x[: self.output_size] * 0
            if in_data.norm() < 1e-4:
                break
        
        return out_data, similarities

    def normalize_positions(self):
        for name in self.positions:
            self.positions[name].data.copy_(
                F.normalize(self.positions[name], p=2, dim=1).data
            )

# Datasets

### Base

In [3]:
class BaseDataset(Dataset):
    def __init__(self, n_features: int = 4):
        self.data = torch.randn(256, n_features, 1)
        self.data /= self.data.abs().max()
        self.data[:, 0] = 1
        self.label = torch.ones(256, 1, 1)

    def __len__(self):
        return 256
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

In [4]:
train_dataset = BaseDataset()
test_dataset = BaseDataset()

### Sklearn

In [52]:
class SklearnDataset(Dataset):
    def __init__(self, X, y):
        self.data = torch.Tensor(X).unsqueeze(-1)
        self.data /= self.data.abs().max()
        self.label = torch.Tensor(y).unsqueeze(-1).unsqueeze(-1)
        self.label /= self.label.abs().max()
        
    def __len__(self):
        return self.data.shape[0]
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]

##### Synthetic

In [17]:
X, y = datasets.make_regression(n_samples=1_000, n_features=12, noise=10, random_state=42)

##### California Housing
###### * Requires deep NN

In [53]:
data = datasets.fetch_california_housing()
X, y = data.data, data.target

In [54]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

train_dataset = SklearnDataset(X_train, y_train)
test_dataset = SklearnDataset(X_test, y_test)

# Training runs

In [55]:
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True)

In [66]:
EPOCHS = 1_000
USE_WANDB = True
n_features = train_dataloader.dataset[0][0].shape[0]
hidden_size = 20
env_dims = 16
l1_scale = 0.2
weight_radius = 0.2

model = ManifoldWorms(n_features, 1, hidden_size, env_dims)
optimizer = optim.SGD(model.parameters(), lr=1e-4, weight_decay=1e-4)

In [67]:
if USE_WANDB:
    run = wandb.init(project="manifold_worms")

logs = defaultdict(list)
for epoch in range(EPOCHS):

    model.train()
    logs["train"].clear()
    for X, y in train_dataloader:
        X = X[0].requires_grad_(True)
        y = y[0]
        y_pred, cos_sim = model(X)
        loss = F.mse_loss(y_pred, y) + l1_scale * cos_sim.abs().sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        logs["train"].append(loss.item())

    model.eval()
    logs["test"].clear()
    for X, y in test_dataloader:
        X = X[0].requires_grad_(True)
        y = y[0]
        y_pred, cos_sim = model(X)
        loss = F.mse_loss(y_pred, y) + l1_scale * cos_sim.abs().sum()
        logs["test"].append(loss.item())

    if USE_WANDB:
        run.log(
            {
                "train_loss" : sum(logs["train"]) / len(logs["train"]),
                "test_loss" : sum(logs["test"]) / len(logs["test"])
            }
        )
    else:
        print("train_loss", sum(logs["train"]) / len(logs["train"]),
              "\ntest loss", sum(logs["test"]) / len(logs["test"]))
    
    logs["state"].append(
        log.visualize(
            model.positions["input_tails"].data,
            model.positions["exit_heads"].data
        )
    )

if USE_WANDB:
    run.log(
        {
        "video" : wandb.Video(
            np.stack(logs["state"]).transpose(0, 3, 1, 2),
            fps=15,
            format="gif"
        )
        }
    )
    run.finish()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (21x28 and 21x1)