# Imports and algorithm

In [164]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
from typing import *
import wandb
from mw import log
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

torch.set_printoptions(sci_mode=False)
np.set_printoptions(suppress=True)
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x1fd80562cc0>

In [115]:
class ManifoldWorms(nn.Module):
    """Recurrent network whose weights are similarities of learned positions.

    Each of the ``input_size + hidden_size`` source units owns a "tail"
    position and each of the ``output_size + hidden_size`` target units owns a
    "head" position in an ``env_dim``-dimensional space. Positions are
    L2-normalised on every forward pass, so the weight from a tail to a head
    is the dot product of two unit vectors. The hidden part of the activation
    is fed back for up to ``max_loops`` iterations while the output part is
    accumulated into the returned prediction.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        hidden_size: int,
        env_dim: int,
    ):
        """Create the positions, bias and bookkeeping tensors.

        Args:
            input_size: Number of input units (rows of ``x``).
            output_size: Number of output units.
            hidden_size: Number of recurrent hidden units.
            env_dim: Dimensionality of the space the positions live in.
        """
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size

        self.hidden_state = torch.zeros(input_size + hidden_size, 1).requires_grad_(True)

        # 1 for rows that belong to output units, 0 for hidden units.
        # Registered as a non-persistent buffer so it follows ``.to(device)``
        # without adding a new key to ``state_dict``.
        outputs_mask = torch.zeros(output_size + hidden_size, 1)
        outputs_mask[:output_size] = 1
        self.register_buffer("outputs_mask", outputs_mask, persistent=False)

        self.bias = nn.Parameter(torch.zeros(output_size + hidden_size, 1))
        self.positions = nn.ParameterDict(
            {
                "tails": nn.Parameter(
                    torch.randn(input_size + hidden_size, env_dim), requires_grad=True
                ),
                "heads": nn.Parameter(
                    torch.randn(output_size + hidden_size, env_dim), requires_grad=True
                ),
            }
        )

    def forward(
        self,
        x: torch.Tensor,
        max_loops: int = 20,
        with_empty_hidden_state: bool = True,
        eps: float = 1e-4,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the recurrent propagation and collect the outputs.

        Args:
            x: Input whose second-to-last dimension equals ``input_size``,
                e.g. ``(input_size, 1)`` or ``(batch, input_size, 1)``.
            max_loops: Maximum number of recurrent iterations.
            with_empty_hidden_state: Reset the hidden state to zeros before
                injecting ``x``.
            eps: Stop early once the norm of the fed-back state drops below
                this value.

        Returns:
            A pair ``(y, similarities)``: ``y`` is the sum over iterations of
            the output units' activations (``output_size`` rows on the
            second-to-last axis) and ``similarities`` is the
            ``heads @ tails.T`` weight matrix, returned so callers can
            regularise it.

        Raises:
            ValueError: If ``x`` does not have ``input_size`` rows.

        Side effects: normalises ``self.positions`` in place and leaves the
        last fed-back state in ``self.hidden_state``.
        """
        if x.dim() < 2 or x.shape[-2] != self.input_size:
            raise ValueError(
                f"expected x with shape (..., {self.input_size}, k), "
                f"got {tuple(x.shape)}"
            )

        if with_empty_hidden_state:
            self.clear_hidden_state()

        device, dtype = self.bias.device, self.bias.dtype

        # Extend the input with zero rows for the hidden units and inject it.
        x = F.pad(x, (0, 0, 0, self.hidden_size))
        self.hidden_state = self.hidden_state.to(device) + x

        y = torch.zeros(
            self.output_size + self.hidden_size, 1, device=device, dtype=dtype
        )

        self.normalize_positions()
        similarities = self.positions["heads"] @ self.positions["tails"].T

        for _ in range(max_loops):
            # core transformations
            activated = torch.tanh(similarities @ self.hidden_state + self.bias)

            # move output head's inputs out of the loop
            y = y + activated * self.outputs_mask
            recurrent = activated * (1 - self.outputs_mask)

            # reshapes the outputs as a new input: after padding, the rows
            # that line up with the input units are zero
            recurrent = F.pad(
                recurrent, (0, 0, self.input_size - self.output_size, 0)
            )
            self.hidden_state = recurrent

            if recurrent.norm() < eps:
                break

        # Slice the unit axis (second-to-last) rather than axis 0 so batched
        # inputs keep their batch dimension.
        # Outputting similarities for L1 Regularization
        return y[..., : self.output_size, :], similarities

    def clear_hidden_state(self):
        """Reset the hidden state to an unbatched column of zeros.

        The shape is rebuilt from the layer sizes instead of copied from the
        previous state, so a batched call cannot leak its batch shape into
        the next call.
        """
        self.hidden_state = torch.zeros(
            self.input_size + self.hidden_size,
            1,
            device=self.bias.device,
            dtype=self.bias.dtype,
        ).requires_grad_(True)

    def normalize_positions(self):
        """Project every position onto the unit sphere (row-wise L2 norm), in place."""
        with torch.no_grad():
            for param in self.positions.values():
                param.copy_(F.normalize(param, p=2, dim=1))

# Datasets

### Base

In [28]:
class BaseDataset(Dataset):
    """Toy dataset: random features with a constant first feature and label 1.

    Each sample is a column vector of shape ``(n_features, 1)`` whose first
    entry is fixed to 1; the remaining entries are Gaussian noise scaled so
    that the largest absolute value drawn is 1. Every label is a ``(1, 1)``
    tensor holding 1.
    """

    def __init__(self, n_features: int = 4, n_samples: int = 256):
        """Sample the dataset.

        Args:
            n_features: Number of rows in each sample (must be >= 1).
            n_samples: Number of samples to generate (must be >= 1).

        Raises:
            ValueError: If either size is smaller than 1.
        """
        if n_features < 1 or n_samples < 1:
            raise ValueError("n_features and n_samples must both be >= 1")

        self.data = torch.randn(n_samples, n_features, 1)
        # Scale by the global maximum so all values fall inside [-1, 1].
        self.data /= self.data.abs().max()
        self.data[:, 0] = 1
        self.label = torch.ones(n_samples, 1, 1)

    def __len__(self):
        """Return the number of samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.label[idx]

In [29]:
# Two independently sampled toy splits, built with BaseDataset's defaults.
train_dataset, test_dataset = BaseDataset(), BaseDataset()

### Sklearn

In [195]:
class SklearnDataset(Dataset):
    """Wrap array-style ``(X, y)`` data as standardised column-vector tensors.

    Features and targets are standardised along axis 0 using the statistics
    of the arrays passed in. Samples are returned with shape
    ``(n_features, 1)`` and labels with shape ``(1, 1)``.
    """

    def __init__(self, X, y):
        """Standardise ``X`` and ``y`` and store them as tensors.

        Args:
            X: Array-like of shape ``(n_samples, n_features)``.
            y: Array-like of shape ``(n_samples,)``.
        """
        self.data = self._to_tensor(self._standardize(X)).unsqueeze(-1)
        self.label = self._to_tensor(self._standardize(y)).unsqueeze(-1).unsqueeze(-1)

    @staticmethod
    def _standardize(values):
        """Return ``values`` with zero mean and unit variance along axis 0.

        Columns with zero variance are only centred (divided by 1), so they
        become zeros instead of NaN/inf.
        """
        values = np.asarray(values)
        std = values.std(0)
        std = np.where(std == 0, 1.0, std)
        return (values - values.mean(0)) / std

    @staticmethod
    def _to_tensor(values):
        """Convert an array to a tensor of the default floating dtype."""
        return torch.as_tensor(values, dtype=torch.get_default_dtype())

    def __len__(self):
        """Return the number of samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.label[idx]

##### Synthetic

In [106]:
# Reproducible synthetic regression problem (fixed random_state).
X, y = datasets.make_regression(
    n_samples=1_000,
    n_features=12,
    noise=10,
    random_state=42,
)

##### California Housing (Requires deep NN)

In [196]:
# Running this cell replaces any X, y produced by the synthetic cell.
data = datasets.fetch_california_housing()
X = data.data
y = data.target

#### Make dataset

In [197]:
# Hold out 30% of the samples for evaluation; the seed keeps the split stable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

# NOTE(review): each SklearnDataset standardises with the statistics of the
# arrays it receives, so train and test are scaled independently.
train_dataset = SklearnDataset(X=X_train, y=y_train)
test_dataset = SklearnDataset(X=X_test, y=y_test)

# Training runs

In [213]:
# One sample per step: the training loop below only reads index 0 of a batch.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=True,
)

In [214]:
# --- Run configuration ---
EPOCHS = 50
USE_WANDB = True       # log scalars and the position video to Weights & Biases
GRADIENT_NORM = True   # rescale each parameter's gradient to unit norm

# --- Model dimensions ---
# Number of rows in a single sample of the training set.
n_features = train_dataloader.dataset[0][0].size(0)
hidden_size = 10
env_dims = 3
l1_scale = 0.0         # weight of the L1 penalty on the similarity matrix

model = ManifoldWorms(
    input_size=n_features,
    output_size=1,
    hidden_size=hidden_size,
    env_dim=env_dims,
)
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [None]:
if USE_WANDB:
    run = wandb.init(project="manifold_worms")

# Keys containing "train"/"test" hold the per-sample losses of the current
# epoch; "state" accumulates one visualisation frame per epoch.
logs = defaultdict(list)
for epoch in range(EPOCHS):

    # Start every epoch with empty loss lists (the "state" frames are kept).
    for k in logs:
        if "test" in k or "train" in k:
            logs[k].clear()

    model.train()
    for X, y in train_dataloader:

        # Only the first sample of each batch is fed to the model.
        X = X[0].requires_grad_(True)
        y = y[0]
        y_pred, cos_sim = model(X)

        mse_loss = F.mse_loss(y_pred, y)
        l1_loss = l1_scale * cos_sim.abs().sum()
        # Penalise whatever activation is left in the hidden state after the
        # forward pass.
        garbage_loss = model.hidden_state.abs().sum()
        loss = mse_loss + l1_loss + garbage_loss

        logs["train_mse_loss"].append(mse_loss.item())
        logs["train_l1_loss"].append(l1_loss.item())
        logs["train_garbage_loss"].append(garbage_loss.item())

        optimizer.zero_grad()
        loss.backward()

        if GRADIENT_NORM:
            # Rescale each parameter's gradient to unit norm; the clamp
            # avoids dividing by a (near-)zero norm.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.div_(param.grad.norm().clamp(min=1e-6))

        optimizer.step()

    model.eval()
    # Evaluation needs no gradients: skip building autograd graphs.
    with torch.no_grad():
        for X, y in test_dataloader:
            X = X[0]
            y = y[0]
            y_pred, cos_sim = model(X)

            mse_loss = F.mse_loss(y_pred, y)
            l1_loss = l1_scale * cos_sim.abs().sum()
            garbage_loss = model.hidden_state.abs().sum()

            logs["test_mse_loss"].append(mse_loss.item())
            logs["test_l1_loss"].append(l1_loss.item())
            logs["test_garbage_loss"].append(garbage_loss.item())

    logs["state"].append(
        log.visualize(
            model.positions["tails"].data,
            model.positions["heads"].data
        )
    )

    if USE_WANDB:
        # Epoch means of every loss list; empty lists are skipped so an
        # empty loader cannot trigger a division by zero.
        scalars = {
            key: sum(values) / len(values)
            for key, values in logs.items()
            if key != "state" and values
        }
        # Gradient statistics of the last training step of the epoch.
        for name, param in model.named_parameters():
            if param.grad is not None:
                scalars[f"grad_{name}_mean"] = param.grad.mean().item()
                scalars[f"grad_{name}_std"] = param.grad.std().item()
        run.log(scalars)

if USE_WANDB:
    # Stack the per-epoch frames and move the last axis to position 1
    # before handing them to wandb.Video.
    run.log(
        {
            "video": wandb.Video(
                np.stack(logs["state"]).transpose(0, 3, 1, 2),
                fps=15,
                format="gif"
            )
        }
    )

In [170]:
# Close the Weights & Biases run opened in the training cell.
if USE_WANDB:
    run.finish()

0,1
grad_bias_mean,▇▅▆▇▁▇█▂▇▂█▂▆▃▆▂▆▅▅▁▃▆▇▆▄▄▄▃▄▅▅▅▃▂▄▅▂▄▄▄
grad_bias_std,▆█▆▄▃▅▁▅▄▅▁▆▇█▇▆▆█▇▄█▆█▅▆██▇█▇▇▇▅█▆█▇███
grad_positions.heads_mean,▃▇▂▂█▁█▂█▃▇▃▅▅▄▆█▅▄▆▄▅▅▆▄▅▆▇▅▄▅▅▄▆▆▆▆▅▅▄
grad_positions.heads_std,▅█▇▃▃▅▁▅▄▅▁▆▇██▇▄█████████▇█████████▇███
grad_positions.tails_mean,▄▅▄▄▅▄▅▃▅▆▆▆█▁█▂▁▆█▄▆▃█▃▄▆▃▅▆█▆▆▆▇▁▂▂▂█▅
grad_positions.tails_std,█████████▇▇▅▅▂▄▃▇▂█▄▇▁▇█▅▇█▇▃▇▇▆▃▄█▆▆▁▇█
test_garbage_loss,██▇▇▆▆▆▆▆▆▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_l1_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_mse_loss,███▇▆▅▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_garbage_loss,███▇▇▇▆▆▆▆▆▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
grad_bias_mean,-0.0063
grad_bias_std,0.31616
grad_positions.heads_mean,-0.01503
grad_positions.heads_std,0.17612
grad_positions.tails_mean,0.00746
grad_positions.tails_std,0.12381
test_garbage_loss,0.00189
test_l1_loss,0.0
test_mse_loss,0.00855
train_garbage_loss,0.00069


In [None]:
# Compare a training label with the prediction for that SAME sample.
# Previously the sampled input was discarded and the model was run on the
# stale X left over from the evaluation loop.
X, y = next(iter(train_dataloader))
y[0, 0], model(X[0])[0][0, 0]

(tensor([0.2428]), tensor(-0.2976, grad_fn=<SelectBackward0>))