In [None]:
import torch

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader

import numpy as np
from tqdm import tqdm

In [None]:
from utils import CNNModel, generate_samples

In [None]:
# Seed both RNGs used below: NumPy drives the buffer re-initialization mask,
# torch drives the uniform-noise images (and the torch-side randomness elsewhere).
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# ---- training hyperparameters -------------------------------------------
epochs = 20        # passes over the MNIST training set
batch_size = 128   # images per mini-batch (also the number of buffered fake samples)
lr = 1e-4          # Adam learning rate
alpha = 0.1        # weight of the squared-output regularization term
steps = 60         # Langevin steps taken per training iteration
step_size = 10     # size of each Langevin step

In [None]:
x_shape = (1, 28, 28)  # one MNIST image as produced by ToTensor: (channels, height, width)

In [None]:
# Load the MNIST training set.

# Normalize pixel values to [-1, 1], the same range the Langevin samples are
# initialized in (uniform noise in [-1, 1)):
#   pixel 0   -> ToTensor -> 0.0 -> (0.0 - 0.5) / 0.5 = -1
#   pixel 255 -> ToTensor -> 1.0 -> (1.0 - 0.5) / 0.5 =  1
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# download=True lets the notebook run on a fresh machine; torchvision skips the
# download when the files already exist under ./data.
train_dataset = MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

## How to select elements with a certain probability

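Suppose we have an array of $d$ elements and want to select each of them independently with probability $p$. Draw $r_i \sim \mathcal{U}(0, 1)$ for every element and keep index $i$ whenever $r_i < p$. Since $P(r_i < p) = p$, each element is selected with probability exactly $p$.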

In [None]:
r = np.random.rand(100)  # 100 random values, uniform in [0, 1)
p = 0.05  # select each element with probability 5%

# P(r[i] < p) = p for every i, so each index is kept independently with probability p.
# np.argwhere returns a (k, 1) array of indices; flatten it to shape (k,).
selected_idx = np.argwhere(r < p).reshape(-1)

print(selected_idx)

## Training heuristic

- Keep a buffer (dataset) of generated examples across training iterations
- Take a few Langevin steps at each training iteration (e.g., 60)
- At every iteration, re-initialize each buffered sample with uniform noise with 5% probability

In [None]:
# Buffer of generated examples, initialized with uniform noise in [-1, 1).
# Each entry has shape (1,) + x_shape = (1, 1, 28, 28), i.e. one image with a batch dim.
examples = [torch.rand((1,) + x_shape) * 2 - 1 for _ in range(batch_size)]


def generate_sample_train(model, steps=60, step_size=10, reinit_prob=0.05):
    """Draw a batch of fake samples from the buffer and refine it with Langevin steps.

    Each buffer entry is independently replaced by fresh uniform noise in
    [-1, 1) with probability ``reinit_prob``; the others are reused. The batch
    is refined via ``generate_samples`` and written back (on CPU) into
    ``examples`` so the next call continues from it.

    Args:
        model: the energy model, passed through to ``generate_samples``.
        steps: number of Langevin steps.
        step_size: size of each Langevin step.
        reinit_prob: probability of re-initializing each buffer entry.

    Returns:
        The refined batch as returned by ``generate_samples``, with
        ``len(examples)`` images, fresh samples first.
    """
    n = len(examples)

    # One boolean mask and its complement: every entry lands in exactly one group
    # (two strict comparisons against the threshold would drop r == reinit_prob).
    reinit = np.random.rand(n) < reinit_prob
    new_idx = np.flatnonzero(reinit)
    old_idx = np.flatnonzero(~reinit)

    z = torch.rand((len(new_idx),) + x_shape) * 2 - 1  # fresh random examples
    x_old = [examples[i] for i in old_idx]
    # z is always present (possibly with 0 rows), so torch.cat never gets an
    # empty list even when every entry is re-initialized.
    x_new = torch.cat([z, *x_old], dim=0).detach().to(device)

    x_new = generate_samples(model, x_new, steps=steps, step_size=step_size)

    # Store a detached CPU copy so the buffer never keeps an autograd graph alive.
    examples[:] = list(x_new.detach().cpu().chunk(n, dim=0))

    return x_new


In [None]:
# Energy network f_theta, moved to the selected device, and its Adam optimizer.
model = CNNModel()
model = model.to(device)
opt = torch.optim.Adam(params=model.parameters(), lr=lr)

## Loss function: Contrastive divergence

$$
\mathcal{L}(\theta) = \mathbb{E}[f_{\theta}(\text{Langevin}(Z))] - \mathbb{E}[f_{\theta}(X)]
$$

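with
$$
X\sim p_{\text{data}}, \text{ and } Z\sim\mathcal{U}(-1, 1)
$$
In practice, $Z$ is taken from the buffer of generated examples described in the training heuristic above. Its entries start as (and are occasionally reset to) uniform noise. The code also adds a regularization term weighted by $\alpha$, $\alpha\left(\mathbb{E}[f_{\theta}(X)^2] + \mathbb{E}[f_{\theta}(\text{Langevin}(Z))^2]\right)$, which keeps the outputs of $f_{\theta}$ from growing without bound.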

remembering that

$$
E_{\theta}(x) \propto e^{f_{\theta} (x)} \Rightarrow f_{\theta} (x) = \log E_{\theta}(x) + \text{const}
$$

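Intuition: real samples should receive a high value of $f_{\theta}$ ("high energy" in this notebook's convention, where a larger $f_{\theta}$ means a more likely sample), while fake samples should receive a low value. Note that much of the literature uses the opposite sign, $p_{\theta}(x) \propto e^{-E_{\theta}(x)}$, where real samples have *low* energy.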

### IMP: Loss behavior

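During training, the loss does not keep decreasing; instead it stabilizes around a roughly constant value. Two effects balance each other:

1) The model learns to assign higher energy (higher $f_{\theta}$) to the real samples and lower energy to the fake samples, which lowers the loss.
2) At the same time, the generated fake samples improve and become harder to tell apart from real ones, which pushes the loss back up.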

In [None]:
# Contrastive-divergence training: each iteration takes a real batch, refines a
# batch of fake samples from the buffer, and updates the model so that
# f_theta rises on real images and falls on generated ones.
for epoch in range(epochs):
    losses = []
    for x, _ in tqdm(train_loader):
        x = x.to(device)

        # Fake samples from the buffer. Detach defensively so the loss never
        # backpropagates through the Langevin chain (a no-op if generate_samples
        # already returns a tensor without an autograd graph).
        x_hat = generate_sample_train(model, steps=steps, step_size=step_size).detach()

        # Zero gradients only AFTER sampling: if the sampler backpropagates
        # through the model to get input gradients, whatever it leaves in the
        # parameters' .grad must not leak into this update.
        opt.zero_grad()

        # reshape images for Conv network: (batch,) + x_shape
        x = x.view((-1,) + x_shape)
        x_hat = x_hat.view((-1,) + x_shape)

        out_real = model(x)
        out_fake = model(x_hat)

        # Contrastive divergence: lower f_theta on fakes, raise it on reals.
        cd_loss = out_fake.mean() - out_real.mean()
        # Squared-output penalty, weighted by alpha, keeps the outputs bounded.
        reg_loss = (out_real ** 2).mean() + (out_fake ** 2).mean()
        loss = cd_loss + alpha * reg_loss

        loss.backward()
        opt.step()

        losses.append(loss.item())

    avg_loss = np.mean(losses)
    print(f"Epoch {epoch+1:03d}:{epochs:03d}, Loss: {avg_loss:.4f}")

In [None]:
# Optionally persist the trained weights: set SAVE_MODEL = True to write them.
# Kept as a flag instead of commented-out code so the cell stays runnable.
SAVE_MODEL = False
MODEL_PATH = "saved_models/EBM.pt"

if SAVE_MODEL:
    import os

    # exist_ok=True avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    torch.save(model.state_dict(), MODEL_PATH)