In [None]:
import torch
import torch.nn as nn

from torch.optim import Adam

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from torch.distributions import Normal  # using torch distributions

import matplotlib.pyplot as plt

In [None]:
# Global tensor settings and reproducibility
torch.set_default_dtype(torch.float32)
torch.manual_seed(42)

# Prefer the GPU when one is visible to PyTorch
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# Model architecture
input_dim = 28 * 28          # flattened MNIST image size
hidden_dim = 1000            # width of the hidden layers
num_coupling_layers = 5      # how many coupling layers are stacked
num_layers = 6               # linear layers inside each coupling layer

# Optimisation settings
epochs, batch_size, lr = 10, 128, 1e-3

## Masks

In [None]:
# Binary float mask with ones at the even positions (0, 2, 4, ...)
mask = torch.zeros(20).float()
mask[0::2] = 1.0

print(mask)

In [None]:
# Random vector of the same length as the mask
x = torch.randn((20,))

print(x)

In [None]:
print(mask * x)  # keeps the entries where mask == 1, zeroes the others

In [None]:
print((1 - mask) * x)  # complementary selection: keeps the entries where mask == 0

## Torch Distributions

In [None]:
distribution = Normal(loc=0, scale=1)  # normal distribution with mean 0 and standard deviation 1

In [None]:
# Draw 20 samples from the distribution
z = distribution.sample(sample_shape=(20,))

print(z)

In [None]:
distribution.log_prob(value=z)  # log-density of each sample (log-likelihoods)

## Additive Coupling

Recap of additive coupling.

We split the input $x$ into two equal parts, $x_1$ and $x_2$.

A coupling layer transforms only $x_2$ based on $x_1$ (or vice versa).

\begin{equation*}
    y_1 = x_1, \ \ \ \ y_2 = x_2 + m_{\theta}(x_1), \ \ \ \ y = \text{concat}(y_1, y_2)
\end{equation*}

In [None]:
class CouplingLayer(nn.Module):
  """Additive coupling layer.

  Entries selected by the mask (mask == 1) pass through unchanged; the
  remaining entries (mask == 0) are shifted by the output of the network `m`,
  which only sees the pass-through entries:

      y = mask * x + (1 - mask) * (x + m(mask * x))

  Args:
    input_dim: size of the last dimension of the input.
    hidden_dim: width of the hidden linear layers of `m`.
    mask: tensor broadcastable against the input, expected to contain 0/1
      entries. It is registered as a non-persistent buffer, so it follows the
      module through `.to(...)`, `.cuda()` and `.double()` but is not written
      to the state dict.
    num_layers: number of linear layers in `m`. The first and the last linear
      layer are always created, so values below 2 still give two layers.
  """

  def __init__(self, input_dim, hidden_dim, mask, num_layers=4):
    super().__init__()

    # A buffer (instead of a plain attribute) keeps the mask on the same
    # device/dtype as the parameters; persistent=False leaves the state dict
    # keys exactly as they were when the mask was a plain attribute.
    self.register_buffer('mask', mask, persistent=False)

    modules = [nn.Linear(input_dim, hidden_dim),
               nn.LeakyReLU(0.2)]

    for _ in range(num_layers - 2):
      modules.append(nn.Linear(hidden_dim, hidden_dim))
      modules.append(nn.LeakyReLU(0.2))
    # no activation after the last layer: the shift is unconstrained
    modules.append(nn.Linear(hidden_dim, input_dim))

    self.m = nn.Sequential(*modules)

  def _shift(self, passthrough):
    """Shift computed from the pass-through part, zeroed where mask == 1."""
    return self.m(passthrough) * (1 - self.mask)

  def forward(self, x):
    """Map x to y; entries with mask == 1 are copied, the others are shifted."""
    x1 = self.mask * x
    x2 = (1 - self.mask) * x
    y1 = x1
    y2 = x2 + self._shift(x1)
    return y1 + y2

  # inverse mapping
  def inverse(self, x):
    """Undo `forward`: subtract the shift recomputed from the unchanged part."""
    y1 = self.mask * x
    y2 = (1 - self.mask) * x
    x1 = y1
    x2 = y2 - self._shift(y1)
    return x1 + x2

## Not-So-NICE (Gaussian prior + No scaling)

This architecture is simply a sequence of coupling layers with alternating masks:

- consecutive layers alternate between a mask covering the even indices and a mask covering the odd indices

In [None]:
class NotSoNICE(nn.Module):
  """Stack of additive coupling layers with a standard normal prior.

  Args:
    input_dim: size of the flattened input.
    hidden_dim: width of the hidden layers inside each coupling layer.
    num_coupling_layers: number of coupling layers in the stack.
    num_layers: number of linear layers for each coupling layer.
    device: device on which the coupling masks are created.
  """

  def __init__(self, input_dim, hidden_dim=1000, num_coupling_layers=3, num_layers=6, device='cpu'):
    super().__init__()

    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_coupling_layers = num_coupling_layers
    self.num_layers = num_layers  # number of linear layers for each coupling layer

    # alternating mask orientations for consecutive coupling layers
    masks = [self._get_mask(input_dim, orientation=(i % 2 == 0)).to(device)
             for i in range(num_coupling_layers)]

    self.coupling_layers = nn.ModuleList([
        CouplingLayer(input_dim=input_dim, hidden_dim=hidden_dim,
                      mask=layer_mask, num_layers=num_layers)
        for layer_mask in masks])

    self.prior = Normal(0, 1)
    self.device = device

  def forward(self, x):
    """Map x to the latent z and return (z, log_likelihood).

    The log-likelihood is the prior log-density of z summed over dim 1. No
    Jacobian term is added: an additive coupling layer only shifts its input,
    so its Jacobian determinant is 1.
    """
    z = x
    for layer in self.coupling_layers:  # pass through each coupling layer
      z = layer(z)

    log_likelihood = torch.sum(self.prior.log_prob(z), dim=1)
    return z, log_likelihood

  # we don't call this during training, but we use it for inference
  def inverse(self, z):
    """Map a latent z back to input space (coupling layers in reversed order)."""
    x = z
    for layer in reversed(list(self.coupling_layers)):
      x = layer.inverse(x)
    return x

  def sample(self, num_samples):
    """Draw `num_samples` latents from the prior and map them to input space."""
    z = self.prior.sample((num_samples, self.input_dim))

    # The prior is built from Python scalars, so its samples are created on the
    # CPU in the default dtype. Align them with the model parameters, which
    # reflect where the module actually lives after `.to(...)`.
    reference = next(self.parameters(), None)
    if reference is not None:
      z = z.to(device=reference.device, dtype=reference.dtype)
    else:
      z = z.to(self.device)  # no parameters (e.g. zero coupling layers)
    return self.inverse(z)

  def _get_mask(self, dim, orientation=True):
    """Return a 0/1 float mask of length `dim`.

    orientation=True gives ones at the odd indices, orientation=False gives
    ones at the even indices.
    """
    mask = torch.zeros(dim)
    mask[::2] = 1
    if orientation:
      mask = 1 - mask  # flip mask if orientation is True
    return mask.float()

In [None]:
# Define the dataset and data loader.
# download=True fetches MNIST into ./data when it is not already there, so the
# notebook also runs from a fresh checkout / fresh kernel.
train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Define the NICE model.
# Passing `device` to the constructor creates the coupling masks directly on the
# target device, so they do not have to be moved layer by layer afterwards;
# `.to(device)` then moves the parameters.
model = NotSoNICE(input_dim=input_dim,
                  hidden_dim=hidden_dim,
                  num_coupling_layers=num_coupling_layers,
                  num_layers=num_layers,
                  device=device).to(device)

# Train the model
model.train()

optimizer = Adam(model.parameters(), lr=lr)

## Training

- Architecture: NotSoNICE
- Optimizer: Adam
- Loss: Negative log-likelihood (<b>Why?</b>)

In [None]:
for epoch in range(epochs):
  # Running sum of the per-batch mean log-likelihood, kept as a plain Python
  # float so that no autograd history is carried from one batch to the next.
  mean_likelihood = 0.0
  batch_counter = 0

  for batch_id, (x, _) in enumerate(train_loader):

    model.zero_grad()

    x = x.to(device)
    x = x.view(-1, 28*28)  # flatten

    z, likelihood = model(x)
    loss = -torch.mean(likelihood)  # Negative log-likelihood

    loss.backward()
    optimizer.step()

    # .item() extracts the value; accumulating `loss` itself would keep a
    # graph-attached tensor alive for every batch of the epoch.
    mean_likelihood -= loss.item()
    batch_counter += 1

  mean_likelihood /= batch_counter  # normalize w.r.t. the batches
  print(f'Epoch {epoch+1:d} completed. Log Likelihood: {mean_likelihood:.2f}')

In [None]:
# import os
# if not os.path.isdir("saved_models"):
#     os.makedirs("saved_models")

# torch.save(model.state_dict(), "saved_models/NotSoNICE.pt")