In [None]:
from tqdm import tqdm 

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

from torch.distributions import Uniform, Distribution

In [None]:
from utils import CouplingLayer, ScalingLayer

In [None]:
# Seed PyTorch's global RNG for reproducibility and pin float32 as the
# default floating-point tensor type.
torch.manual_seed(42)
torch.set_default_dtype(torch.float32)

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
  device = 'cuda'

In [None]:
# Architecture hyperparameters
input_dim = 28 * 28        # flattened MNIST image: 28x28 pixels -> 784 features
hidden_dim = 1000          # output size of the hidden layers inside each coupling layer
num_coupling_layers = 5    # how many coupling layers the flow stacks
num_layers = 6             # linear layers per coupling layer (forwarded to CouplingLayer)

# Optimisation hyperparameters
epochs, batch_size, lr = 10, 128, 1e-3

In [None]:
class LogisticDistribution(Distribution):
  """Standard logistic distribution (location 0, scale 1), used as the flow prior.

  ``log_prob`` is evaluated elementwise with no reduction; callers sum over the
  feature dimension themselves.
  """

  def __init__(self):
    super().__init__()

  def log_prob(self, x):
    """Elementwise log-density: log p(x) = -(softplus(x) + softplus(-x)).

    Using softplus avoids overflowing ``exp`` for large |x|, which the naive
    ``-x - 2 * log(1 + exp(-x))`` form would hit.
    """
    return -(F.softplus(x) + F.softplus(-x))

  def sample(self, size):
    """Draw logistic samples by inverse-CDF sampling (logit of a uniform draw).

    Args:
      size: sample shape (list / torch.Size). The underlying uniform has a
        length-1 parameter, so the result has shape ``(*size, 1)``.

    Returns:
      CPU float tensor of finite logistic samples.
    """
    z = Uniform(torch.FloatTensor([0.]), torch.FloatTensor([1.])).sample(size)
    # Uniform samples lie in [0, 1), so z can be exactly 0 and log(0) = -inf would
    # propagate through the inverse flow. Clamp away from 0; 1 - z is already > 0.
    z = z.clamp_min(torch.finfo(z.dtype).tiny)

    return torch.log(z) - torch.log(1. - z)
  
class NICE(nn.Module):
  """NICE normalizing flow: a stack of coupling layers followed by a scaling layer.

  ``forward`` maps data to latent space and returns the per-example
  log-likelihood under a logistic prior; ``inverse`` maps latents back to data.

  Args:
    input_dim: number of features per (flattened) example.
    hidden_dim: forwarded to every ``CouplingLayer`` as ``hidden_dim``.
    num_coupling_layers: how many ``CouplingLayer`` instances are stacked.
    num_layers: forwarded to ``CouplingLayer`` as ``num_layers`` (described by
      the author as the number of linear layers per coupling layer).
    device: device the coupling masks are moved to, and where ``sample`` places
      prior draws.
  """

  def __init__(self, input_dim, hidden_dim=1000, num_coupling_layers=3, num_layers=6, device='cpu'):
    super().__init__()

    # Hyperparameters kept on the instance for introspection.
    self.input_dim = input_dim
    self.hidden_dim = hidden_dim
    self.num_coupling_layers = num_coupling_layers
    self.num_layers = num_layers

    # Consecutive coupling layers alternate mask orientation: even-indexed layers
    # use orientation=True, odd-indexed layers orientation=False.
    masks = []
    for idx in range(num_coupling_layers):
      masks.append(self._get_mask(input_dim, orientation=(idx % 2 == 0)).to(device))

    layers = []
    for mask in masks:
      layers.append(CouplingLayer(input_dim=input_dim, hidden_dim=hidden_dim,
                                  mask=mask, num_layers=num_layers))
    self.coupling_layers = nn.ModuleList(layers)

    self.scaling_layer = ScalingLayer(input_dim=input_dim)

    self.prior = LogisticDistribution()
    self.device = device

  def forward(self, x):
    """Map data ``x`` to latent ``z`` and score it.

    Returns:
      (z, log_likelihood): log_likelihood is the prior log-density of ``z``
      summed over dim 1, plus the log-determinant returned by the scaling layer.
      NOTE(review): coupling layers contribute no log-det term here, presumably
      because they are volume-preserving; confirm in utils.CouplingLayer.
    """
    h = x
    for layer in self.coupling_layers:
      h = layer(h)
    z, log_det_jacobian = self.scaling_layer(h)

    prior_ll = self.prior.log_prob(z).sum(dim=1)
    return z, prior_ll + log_det_jacobian

  def inverse(self, z):
    """Map latent ``z`` back to data space, undoing the layers in reverse order."""
    h = self.scaling_layer.inverse(z)
    for layer in reversed(self.coupling_layers):
      h = layer.inverse(h)
    return h

  def sample(self, num_samples):
    """Draw ``num_samples`` prior latents on ``self.device`` and invert them to data space."""
    shape = [num_samples, self.input_dim]
    noise = self.prior.sample(shape).view(*shape).to(self.device)
    return self.inverse(noise)

  def _get_mask(self, dim, orientation=True):
    """Return a float mask of length ``dim`` with 1s at even indices, or at odd indices when ``orientation`` is True."""
    positions = torch.arange(dim)
    even = positions % 2 == 0
    chosen = ~even if orientation else even
    return chosen.float()

In [None]:
# Define the dataset and data loader.
# NOTE(review): per torchvision docs, ToTensor() yields float pixels in [0, 1].
train_dataset = MNIST(root='./data', train=True, download=True, transform=ToTensor())
# Use the batch_size from the config cell instead of a hard-coded 128, so the
# configuration cell actually controls the loader.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [None]:
# Define the NICE model. Every architecture hyperparameter from the config cell is
# passed explicitly; hidden_dim in particular must be forwarded or editing it has no effect.
model = NICE(input_dim=input_dim, hidden_dim=hidden_dim,
             num_coupling_layers=num_coupling_layers, num_layers=num_layers,
             device=device).to(device)

# Switch the model to training mode (the optimisation loop itself is in the next cell).
model.train()

optimizer = Adam(model.parameters(), lr=lr)

In [None]:
# Maximum-likelihood training: minimise the mean negative log-likelihood per batch.
for epoch in range(epochs):
  tot_log_likelihood = 0.
  batch_counter = 0

  # Iterating the loader directly lets tqdm read len(train_loader) and show a real progress bar.
  for x, _ in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'):
    optimizer.zero_grad()

    x = x.to(device)
    x = x.view(-1, input_dim)  # flatten to (batch, input_dim)
    # Uniform dequantization (NICE paper, sec. 5.1): pixels are k/255 for integer k,
    # so spread each onto [k/256, (k+1)/256). Without this, a continuous density
    # fitted to discrete values can report unboundedly large likelihoods.
    x = (x * 255. + torch.rand_like(x)) / 256.

    z, log_likelihood = model(x)
    loss = -torch.mean(log_likelihood)  # NLL

    loss.backward()
    optimizer.step()

    # .item() extracts a Python float; accumulating the tensor itself would
    # chain every batch's autograd history into one object for the whole epoch.
    tot_log_likelihood -= loss.item()
    batch_counter += 1

  mean_log_likelihood = tot_log_likelihood / batch_counter  # normalize w.r.t. the batches
  print(f'Epoch {epoch+1:d} completed. Log Likelihood: {mean_log_likelihood:.4f}')

In [None]:
# import os
# if not os.path.isdir("saved_models"):
#     os.makedirs("saved_models")

# torch.save(model.state_dict(), "saved_models/NICE.pt")