In [None]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


In [None]:
from tools import obj_dic, show_heatmap_contours, show_heatmap

# Default seed for the synthetic data set, so the point cloud is reproducible.
SEED = 1412

def gen_data(N, seed=SEED):
    """Sample N noisy 2-D points along a one-parameter curve.

    Args:
        N: number of points to draw.
        seed: seed for the NumPy ``default_rng`` generator.

    Returns:
        A pair ``(p, info)``: ``p`` is the ``(N, 2)`` array of noisy points and
        ``info`` is ``obj_dic(locals())``, i.e. every local of this function
        (``N``, ``seed``, ``r``, ``u``, ``u2``, ``m``, ``p``) passed through the
        project-local ``obj_dic`` helper.
        NOTE(review): ``obj_dic`` presumably exposes the dict keys as
        attributes -- confirm in ``tools``.
    """
    # NOTE: the local names below are captured by locals() and are therefore
    # part of the returned value; renaming them changes the output.
    r = np.random.default_rng(seed)
    # Curve parameter, uniform on [0, 1), shape (N,).
    u = r.uniform(0, 1, (N,))
    # Same values with shape (N, 1) so they broadcast against 2-vectors.
    u2 = u[...,None]
    # Straight segment: x = -10 + 20u, y = -5 + 20u.
    m = np.array([-10, -5]) + 20*u2
    # Bend the y coordinate only: subtract 10*(2u-1)^2 ...
    m -= 10*np.array([0, 1])*(2*u2-1)**2
    # ... and add 10*sin(10u).
    m += 10*np.array([0, 1])*np.sin(u2*10)
    # Observed points: curve plus isotropic Gaussian noise with std 0.1.
    p = m + r.normal(0, .1, (N, 2))
    return p, obj_dic(locals())

# Draw 2500 points with the default seed; `gt` keeps gen_data's second return
# value (the captured locals of the generator).
data, gt = gen_data(2500)
print(data.shape)

# Raw point cloud; the low alpha makes dense regions stand out.
plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)


In [None]:
class AE1(nn.Module):
    """Purely linear auto-encoder: 2-D input -> 1-D code -> 2-D reconstruction.

    Despite its name, ``linear_relu_stack`` contains no activation: it is two
    ``nn.Linear`` layers applied back to back.
    """

    def __init__(self):
        super().__init__()
        squeeze = nn.Linear(2, 1)   # data space -> scalar code
        expand = nn.Linear(1, 2)    # scalar code -> data space
        self.linear_relu_stack = nn.Sequential(squeeze, expand)

    def forward(self, x):
        """Return the reconstruction of ``x`` (trailing dimension of size 2)."""
        return self.linear_relu_stack(x)


m1 = AE1()

In [None]:
# Untrained AE network: print one raw sample, then display its reconstruction.
# The module is invoked through __call__ rather than .forward() so that any
# registered hooks run, as the PyTorch nn.Module documentation advises.
print(data[0,:])
m1(torch.Tensor(data[0,:]))

In [None]:
def plot_model(m):
    """Scatter the global ``data`` and, on top of it, the model's reconstructions.

    ``m`` is called once on the whole data set converted to a tensor; both
    point clouds are drawn on the current matplotlib axes with the same style.
    """
    point_style = dict(marker='.', alpha=0.1)

    # Original samples first, so the reconstructions are drawn over them.
    plt.scatter(data[:, 0], data[:, 1], **point_style)

    rebuilt = m(torch.Tensor(data)).detach().numpy()
    plt.scatter(rebuilt[:, 0], rebuilt[:, 1], **point_style)

In [None]:
# Reconstructions from the still-untrained m1, overlaid on the data.
plot_model(m1)

In [None]:
# Auto-encoder training set: every sample serves as both input and target.
# Batches of 128; shuffle is left at the DataLoader default (off).
train_dataloader = DataLoader(
    TensorDataset(torch.Tensor(data),
                  torch.Tensor(data)),
    batch_size=128)

In [None]:
def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass (one epoch) over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len()``.
        model: callable mapping ``X`` to a prediction.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimiser holding the model's parameters.

    After every 100th batch a progress line is printed with that batch's loss
    and the number of samples consumed so far.
    """
    size = len(dataloader.dataset)
    # Running sample count. Multiplying the batch index by len(X) is wrong
    # whenever the current batch is shorter than the others (typically the
    # last one) and also leaves out the batch that was just processed.
    seen = 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if (batch + 1) % 100 == 0:
            print(f"loss: {loss.item():>7f}  [{seen:>5d}/{size:>5d}]")

In [None]:
def train(model, epochs = 100, lr=1e-2):
    """Fit ``model`` as a plain auto-encoder with an MSE reconstruction loss.

    Args:
        model: ``nn.Module`` optimised in place.
        epochs: number of passes over the global ``train_dataloader``.
        lr: Adam learning rate; the default ``1e-2`` is the value that used to
            be hard-coded, so existing calls behave as before.
    """
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for _ in range(epochs):
        train_loop(train_dataloader, model, loss_fn, optimizer)
    print("Done!")

In [None]:
# Train the linear auto-encoder with train()'s default settings.
train(m1)

In [None]:
# Same overlay as before training, now with the fitted m1.
plot_model(m1)

In [None]:
class AE2(nn.Module):
    """Non-linear auto-encoder with a one-dimensional bottleneck.

    Encoder and decoder each have a single hidden layer of 100 ReLU units and
    are exposed as ``self.encoder`` and ``self.decoder``.
    """

    def __init__(self):
        super().__init__()
        x_dim, z_dim, hidden = 2, 1, 100  # data space, latent space, hidden width

        encoder_layers = [nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, z_dim)]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [nn.Linear(z_dim, hidden), nn.ReLU(), nn.Linear(hidden, x_dim)]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` into the latent space and decode it back."""
        code = self.encoder(x)
        return self.decoder(code)


m2 = AE2()

In [None]:
# Fit the non-linear auto-encoder for 100 epochs.
train(m2, epochs=100)

In [None]:
# Data versus the trained m2's reconstructions.
plot_model(m2)

In [None]:
def plot_latent_hist(m):
    """Histogram (100 bins) of the first encoder output over the global ``data``.

    ``m`` must expose an ``encoder`` attribute that maps the data tensor to a
    2-D tensor; only column 0 of that output is plotted.
    """
    codes = m.encoder(torch.Tensor(data))
    first_component = codes[:, 0].detach().numpy()
    plt.hist(first_component, bins=100)


plot_latent_hist(m2)

In [None]:
class VAE2(nn.Module):
    """Variational auto-encoder with a one-dimensional Gaussian latent.

    The encoder emits ``2 * L`` numbers per sample: the first ``L`` are the
    mean of ``q(z|x)`` and the last ``L`` its log-variance (predicting the log
    keeps the variance ``exp(logvar)`` positive). ``forward`` draws one ``z``
    with the reparameterisation trick and returns its decoding.
    """

    def __init__(self):
        super(VAE2, self).__init__()
        D = 2  # dim of the X space
        L = 1  # dim of the latent space
        self.encoder = nn.Sequential(
            nn.Linear(D, 100),
            nn.ReLU(),
            nn.Linear(100, L + L),  # a mean on z, and a logvar on z
        )
        self.decoder = nn.Sequential(
            nn.Linear(L, 100),  # input width is the latent dimension
            nn.ReLU(),
            nn.Linear(100, D),
        )

    def forward(self, x):
        """Return a stochastic reconstruction of ``x``.

        Args:
            x: tensor whose trailing dimension is the data dimension; both
               batched ``(B, D)`` and unbatched ``(D,)`` inputs are accepted.

        Returns:
            Tensor shaped like ``x``: the decoder applied to
            ``z = mu + std * eps`` with ``eps ~ N(0, 1)``.
        """
        # Split the trailing axis into its two halves: mean and log-variance.
        mu, logvar = self.encoder(x).chunk(2, dim=-1)
        std = torch.exp(logvar / 2)

        # randn_like draws the noise on the same device and with the same
        # dtype as the encoder output, so the model also works off-CPU.
        eps = torch.randn_like(std)
        z = mu + std * eps
        return self.decoder(z)


vm2 = VAE2()

In [None]:
def train_loop_with_KL(dataloader, model, loss_fn, optimizer, showloss):
    """Run one epoch of VAE training: reconstruction loss plus a KL penalty.

    Args:
        dataloader: iterable yielding ``(X, y)`` batches; ``dataloader.dataset``
            must support ``len()``.
        model: module whose ``forward`` returns a reconstruction and whose
            ``encoder`` returns, per sample, ``[mu, logvar]`` in columns 0 and 1.
        loss_fn: reconstruction loss ``loss_fn(pred, y)`` returning a scalar.
        optimizer: optimiser holding the model's parameters.
        showloss: when true, print the total loss and the mean KL term for
            batches 0, 100, 200, ... together with the samples consumed so far.

    The penalty is the closed-form ``KL(N(mu, exp(logvar)) || N(0, 1))``
    ``= 0.5 * (mu^2 + exp(logvar) - logvar - 1)``, averaged over the batch.
    """
    size = len(dataloader.dataset)
    seen = 0  # running sample count; stays correct when a batch is short
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        mu_logvar = model.encoder(X)
        mu, logvar = mu_logvar[:, 0], mu_logvar[:, 1]
        # Use logvar directly instead of log(exp(logvar)): the round trip
        # underflows to log(0) for very negative logvar and yields inf/NaN.
        # The "- 1" makes the term a true KL divergence (zero at the prior);
        # being a constant it does not affect the gradients.
        kl = 0.5 * (mu**2 + torch.exp(logvar) - logvar - 1)
        kl_mean = kl.mean()
        loss = loss_fn(pred, y) + kl_mean

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if batch % 100 == 0 and showloss:
            print(f"loss: {loss.item():>7f} {kl_mean.item():>7f}  [{seen:>5d}/{size:>5d}]")

def train_with_KL(model, epochs = 100, showloss=False, lr=2e-2):
    """Optimise ``model`` with Adam on MSE reconstruction loss plus KL penalty.

    Args:
        model: module passed to ``train_loop_with_KL``; optimised in place.
        epochs: number of passes over the global ``train_dataloader``.
        showloss: forwarded to ``train_loop_with_KL`` to enable loss printing.
        lr: Adam learning rate.
    """
    reconstruction_loss = nn.MSELoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for _epoch in range(epochs):
        train_loop_with_KL(
            train_dataloader,
            model,
            reconstruction_loss,
            adam,
            showloss,
        )
    print("Done!")


In [None]:
# Train the VAE with train_with_KL's defaults (no loss printing).
train_with_KL(vm2)

In [None]:
# Data versus VAE reconstructions; vm2's forward samples z, so the
# reconstructions include the latent sampling noise.
plot_model(vm2)

In [None]:
# Histogram of encoder output column 0, which for the VAE is the posterior
# mean of z (not a sampled z).
plot_latent_hist(vm2)

In [None]:
# Unfruitful attempt to improve this simple case


class ResnetLinear(nn.Module):
    """Residual block computing ``x + Linear(SiLU(Linear(x)))``.

    ``INOUT`` is the width of both input and output; ``MID`` is the hidden
    width inside the residual branch.
    """

    def __init__(self, INOUT, MID):
        super().__init__()
        branch = [nn.Linear(INOUT, MID), nn.SiLU(), nn.Linear(MID, INOUT)]
        self.residual = nn.Sequential(*branch)

    def forward(self, x):
        """Add the residual branch's output to ``x``."""
        correction = self.residual(x)
        return x + correction

class VAE3(nn.Module):
    """VAE variant with residual SiLU blocks and a one-dimensional latent.

    The encoder emits ``2 * L`` numbers per sample: the first ``L`` are the
    mean of ``q(z|x)`` and the last ``L`` its log-variance. ``forward`` draws
    one ``z`` with the reparameterisation trick and returns its decoding.
    """

    def __init__(self):
        super(VAE3, self).__init__()
        D = 2   # dim of the X space
        L = 1   # dim of the latent space
        H = 10  # hidden width of the residual blocks
        self.encoder = nn.Sequential(
            nn.Linear(D, H),
            ResnetLinear(H, H),
            ResnetLinear(H, H),
            nn.SiLU(),
            nn.Linear(H, L + L),
        )
        self.decoder = nn.Sequential(
            nn.Linear(L, H),
            ResnetLinear(H, H),
            ResnetLinear(H, H),
            nn.SiLU(),
            nn.Linear(H, D),
        )

    def forward(self, x):
        """Return a stochastic reconstruction of ``x``.

        Args:
            x: tensor whose trailing dimension is the data dimension; both
               batched ``(B, D)`` and unbatched ``(D,)`` inputs are accepted.

        Returns:
            Tensor shaped like ``x``: the decoder applied to
            ``z = mu + std * eps`` with ``eps ~ N(0, 1)``.
        """
        # Split the trailing axis into its two halves: mean and log-variance.
        mu, logvar = self.encoder(x).chunk(2, dim=-1)
        std = torch.exp(logvar / 2)

        # The noise must be scaled by the predicted std; adding unit-variance
        # noise regardless of logvar would disconnect the sampled z from the
        # variance that the KL term regularises.
        eps = torch.randn_like(std)  # same device/dtype as the encoder output
        z = mu + std * eps
        return self.decoder(z)

def init_weights(m):
    """Initialiser meant for ``Module.apply``: Xavier-normal weights, bias 0.01.

    Only ``nn.Linear`` modules are touched; every other module is left as is.
    Linear layers created with ``bias=False`` (whose ``bias`` is ``None``) get
    their weights initialised and are otherwise skipped instead of raising.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    if m.bias is not None:
        # nn.init.constant_ fills in place under no_grad, replacing the legacy
        # ``.data`` access.
        nn.init.constant_(m.bias, 0.01)
        
# Fresh VAE3 with PyTorch's default parameter initialisation; the custom
# init_weights initialiser is left disabled in the two lines below.
vm3 = VAE3()
#vm3.encoder.apply(init_weights)
#vm3.decoder.apply(init_weights)

In [None]:
# Train VAE3 with the KL-regularised loop, printing losses (showloss=True);
# the plain MSE-only alternative is kept commented out.
train_with_KL(vm3, epochs=100, showloss=True)
#train(vm3, epochs=100)

In [None]:
# Data versus VAE3 reconstructions (stochastic: forward samples z).
plot_model(vm3)

In [None]:
# Histogram of encoder output column 0, i.e. VAE3's posterior means of z.
plot_latent_hist(vm3)