In [None]:
from typing import Iterator

import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader

import torch
from torch import nn
from torch.nn import functional as F, Parameter
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl

In [None]:
import utils
from experiments import *

In [None]:
import wandb
# Start one wandb run for the whole notebook; the wandb.log / wandb.alert calls
# in later cells report to it, and the final cell closes it with wandb.finish().
wandb.init(project="Autoencoders Algebra")

In [None]:
class Algebraic(nn.Module):
    """Parameter-free elementwise feature expansion.

    Concatenates, along dim 1, the input itself followed by several elementwise
    transforms of it, so a width-``d`` input becomes ``output_dim * d`` wide.
    The last transform first wraps values into [-1, 1) with a period-2 sawtooth,
    ``((x + 1) % 2) - 1``, so that ``arcsin`` is defined for any finite input.
    """

    # Applied in this order; the order fixes the column layout of the output:
    # identity, sin, cos, arcsin of the sawtooth-wrapped input.
    _transforms = (
        lambda t: t,
        torch.sin,
        torch.cos,
        lambda t: torch.arcsin(((t + 1) % 2) - 1),
    )

    # Number of stacked copies in the output. Derived from `_transforms` so it
    # cannot drift out of sync with `forward`.
    output_dim = len(_transforms)

    def __init__(self, input_dim):
        # `input_dim` is accepted for a uniform layer-constructor signature; the
        # layer has no learnable parameters, so it is not needed.
        super().__init__()

    def forward(self, x):
        # assumes features live on dim 1, e.g. x is (batch, features) — TODO confirm
        return torch.concat([transform(x) for transform in self._transforms], dim=1)

In [None]:
class DefaultAE(nn.Module):
    """Plain fully-connected autoencoder with tanh activations.

    Each side has seven linear layers. The encoder maps
    ``input_dim -> intermediate_dim -> ... -> intermediate_dim -> hidden_dim``
    with a tanh after every layer, so the embedding lies in [-1, 1]. The decoder
    mirrors this back to ``input_dim`` and leaves its final layer linear.
    """
    name = "Default AE"

    def __init__(self, input_dim: int, hidden_dim: int, intermediate_dim: int):
        super().__init__()
        # Six interior widths -> seven Linear layers per side.
        interior = [intermediate_dim] * 6
        # Encoder is built before the decoder so parameter initialisation consumes
        # the RNG in the same order as an explicit layer listing would.
        self.encoder = self._tanh_stack([input_dim, *interior, hidden_dim], activate_last=True)
        self.decoder = self._tanh_stack([hidden_dim, *interior, input_dim], activate_last=False)

    @staticmethod
    def _tanh_stack(widths, activate_last):
        """Linear layers between consecutive `widths`, each followed by Tanh
        (the final layer only when `activate_last` is true)."""
        modules = []
        final_idx = len(widths) - 2
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            modules.append(nn.Linear(fan_in, fan_out))
            if idx < final_idx or activate_last:
                modules.append(nn.Tanh())
        return nn.Sequential(*modules)

    def forward(self, x):
        return self.decoder(self.encoder(x))

In [None]:
class AlgebraAE(nn.Module):
    """Tanh autoencoder with `Algebraic` feature expansion around the bottleneck.

    Same tanh MLP layout as `DefaultAE`, with two differences:
      * the encoder's ``intermediate_dim -> hidden_dim`` layer has no tanh and is
        followed by `Algebraic` plus a linear projection back to ``hidden_dim``;
        that projection's output is the embedding;
      * the decoder re-applies `Algebraic` to the embedding before its first
        linear layer.
    """
    name = "Algebra AE"

    def __init__(self, input_dim: int, hidden_dim: int, intermediate_dim: int):
        super(AlgebraAE, self).__init__()
        # Width produced by `Algebraic` for a `hidden_dim`-wide input.
        expanded_dim = Algebraic.output_dim * hidden_dim
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, hidden_dim),
            # Algebra begin
            Algebraic(hidden_dim),
            nn.Linear(expanded_dim, hidden_dim),
        )
        self.decoder = nn.Sequential(
            Algebraic(hidden_dim),
            # Algebra end
            nn.Linear(expanded_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            nn.Linear(intermediate_dim, intermediate_dim),
            nn.Tanh(),
            # Must reconstruct the full input width so the MSE reconstruction loss
            # compares tensors of the same shape (this was hard-coded to 2).
            nn.Linear(intermediate_dim, input_dim),
        )

    def forward(self, x):
        embedding = self.encoder(x)
        return self.decoder(embedding)

## Training

In [None]:
def train_model(train_dataloader, valid_dataloader, input_dim: int, experiment_name: str, hidden_layer_dim: int, ae, epochs=1000, intermediate_dim=20, batch_size=10, patience=50):
    """Train an autoencoder of class `ae` on reconstruction (MSE) loss and return it.

    Args:
        train_dataloader: iterable of input batches used for gradient updates.
        valid_dataloader: iterable of input batches used only for evaluation.
        input_dim: width of each input sample.
        experiment_name: prefix for wandb metric keys and alerts.
        hidden_layer_dim: bottleneck width passed to `ae`; also used in metric keys.
        ae: autoencoder class, built as `ae(input_dim, hidden_layer_dim, intermediate_dim)`;
            must expose a `name` attribute (used in metric keys).
        epochs: maximum number of training epochs.
        intermediate_dim: width of the inner layers passed to `ae`.
        batch_size: unused here (batching is decided by the dataloaders); kept so
            existing positional calls keep working.
        patience: training stops once the validation RMSE has stayed below
            `decision_threshold / 2` for more than this many consecutive epochs.

    Returns:
        The trained model (left in eval mode).

    Relies on the notebook globals `device` and `decision_threshold`, which are
    defined in a later cell and must exist by the time this function is called.
    """
    model = ae(input_dim, hidden_layer_dim, intermediate_dim).to(device)
    criterion = nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(model.parameters())

    counter = 0  # consecutive epochs with validation RMSE under the target
    for epoch in tqdm(range(epochs)):
        train_losses = []
        valid_losses = []

        model.train()
        for batch_pts in train_dataloader:
            inp = batch_pts.float().to(device)
            output = model(inp)
            loss = criterion(output, inp)
            train_losses.append(loss.item())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        model.eval()
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for batch_pts in valid_dataloader:
                inp = batch_pts.float().to(device)
                output = model(inp)
                valid_losses.append(criterion(output, inp).item())

        # RMSE approximated from the mean of per-batch MSE values.
        train_loss = np.sqrt(np.average(train_losses))
        valid_loss = np.sqrt(np.average(valid_losses))

        # One call so both metrics share a wandb step: every wandb.log call
        # advances the step by default, which would offset train/val curves.
        wandb.log({
            f"{experiment_name}_{hidden_layer_dim}_train_loss_{ae.name}": train_loss,
            f"{experiment_name}_{hidden_layer_dim}_val_loss_{ae.name}": valid_loss,
        })

        if valid_loss < decision_threshold / 2:
            counter += 1
        else:
            counter = 0
        if counter > patience:
            wandb.alert(title="Early stopping", text=f"Early stopping for {experiment_name}{hidden_layer_dim} on epoch #{epoch}/{epochs}")
            print("Early stopping")
            break

    return model

In [None]:
# Run configuration.
device = "cpu"  # torch device for models and batches
# Early-stopping target used by `train_model`: validation RMSE must stay below
# half of this value for many consecutive epochs before training stops.
decision_threshold = 0.02

# Seed torch's and numpy's global RNGs so that data splits, shuffling and
# weight initialisation are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
def train_n_eff(exp: PhysExperiment, ae, epochs=5000, intermediate_dim=20, batch_size=10):
    """Train `ae` with a bottleneck of width `exp.n_eff` on one trajectory of `exp`
    and log reconstructions and (for n_eff <= 3) the learned embedding to wandb.

    Args:
        exp: experiment exposing `n_eff`, `single_trajectory(seed)`,
            `experiment_name` and `column_names`.
        ae: autoencoder class (e.g. `DefaultAE`, `AlgebraAE`) with `encoder`
            submodule and a `name` attribute.
        epochs: maximum training epochs (early stopping may end training sooner).
        intermediate_dim: width of the autoencoder's inner layers.
        batch_size: batch size of the train/validation dataloaders.
    """
    n_eff = exp.n_eff
    traj = exp.single_trajectory(42)
    # NOTE(review): MaxAbsScaler is not imported in the visible cells; presumably
    # it comes from `from experiments import *` (sklearn's scaler) — confirm.
    traj_scaled = MaxAbsScaler().fit_transform(traj)  # scale
    # Fractional split lengths need torch >= 1.13. The test split is held out
    # but not used yet.
    traj_train, traj_val, traj_test = random_split(traj_scaled, [0.8, 0.1, 0.1])

    train_dataloader = DataLoader(traj_train, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(traj_val, batch_size=batch_size, shuffle=True)

    model = train_model(train_dataloader, valid_dataloader, traj.shape[1], exp.experiment_name, n_eff, ae, epochs, intermediate_dim, batch_size)

    # Visualise on the whole scaled trajectory, not just one split.
    full_traj = torch.Tensor(traj_scaled).to(device)
    with torch.no_grad():
        embedding = model.encoder(full_traj).cpu().numpy()
        transformed = model(full_traj).cpu().numpy()

    # Stack original and reconstructed points, flagged 0 / 1 in a "transformed" column.
    all_trajs = np.concatenate((traj_scaled, transformed))
    color = np.concatenate((np.zeros(shape=(traj.shape[0], 1)), np.ones(shape=(transformed.shape[0], 1))))
    exported = np.append(all_trajs, color, axis=1)
    table = wandb.Table(columns=exp.column_names + ["transformed"], data=exported)
    wandb.log({f"{exp.experiment_name} before/after ({ae.name})": table})

    if n_eff == 1:
        # coloring: each point coloured by its 1-D embedding value
        traj_with_color = np.append(traj_scaled, embedding, axis=1)
        wandb.log({f"{exp.experiment_name} coloring for n_eff=1 embedding ({ae.name})": wandb.Table(exp.column_names + ["color"], data=traj_with_color)})
    elif n_eff == 2:
        # 2d embedding
        wandb.log({f"{exp.experiment_name} 2d n_eff embedding ({ae.name})": wandb.Table(["projection1", "projection2"], embedding)})
    elif n_eff == 3:
        # 3d embedding
        wandb.log({f"{exp.experiment_name} 3d n_eff embedding ({ae.name})": wandb.Object3D(embedding)})
    else:
        # wandb.alert requires both `title` and `text`; the former one-argument
        # call raised TypeError instead of alerting.
        wandb.alert(
            title="No visualization",
            text=f"no visual representation for n_eff={n_eff} with experiment {exp.experiment_name} ({ae.name})",
        )

In [None]:
# Train every autoencoder variant on every physical experiment. The outer progress
# bar tracks experiments; the inner one (cleared when done) tracks variants.
# NOTE(review): Pendulum / HarmonicOscillator appear to be passed as classes, not
# instances — confirm that `n_eff`, `single_trajectory`, `experiment_name` and
# `column_names` (used by train_n_eff) work at class level.
exps = [Pendulum, HarmonicOscillator]
aes = [DefaultAE, AlgebraAE]

for exp in tqdm(exps, position=0):
    for ae in tqdm(aes, position=1, leave=False):
        train_n_eff(exp, ae, epochs=1000, intermediate_dim=10)

In [None]:
# Close the wandb run opened by wandb.init in the setup cell.
wandb.finish()