In [None]:
import torch
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

In [None]:
from utils import CNNModel

In [None]:
SEED = 42
torch.manual_seed(SEED)  # seed torch's global PRNG so later torch.rand/randn draws are reproducible

## Swish activation function

$$
y = x \cdot \text{sigmoid}(x) = x \cdot \frac{1}{1 + e^{-x}}
$$

One of the many possible smooth versions of ReLU.

In [None]:
# Evaluate Swish(x) = x * sigmoid(x) on a dense grid over [-5, 5)
x_range = torch.arange(-5, 5, 0.1)
y_range = torch.sigmoid(x_range) * x_range

# Explicit figure/axes interface instead of the pyplot state machine
fig, ax = plt.subplots()
ax.plot(x_range, y_range, '-', color='C1')
ax.set_xlim([x_range[0], x_range[-1]])
ax.grid(linestyle=':')
ax.set_xlabel('x')
ax.set_ylabel('Swish(x)')
plt.show()

## Manual gradients with PyTorch

Suppose we have some scalar function
$$
y = f(x), \quad x \in \mathbb{R}^{d}
$$

and we want to compute its gradient
$$
\nabla_x y = \left[ \frac{\partial y}{\partial x_1}, \cdots,  \frac{\partial y}{\partial x_d} \right]^{\top}
$$

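This can be done with <code>torch.autograd.grad</code>, which returns a tuple containing one gradient per input tensor (hence the `grad, = ...` unpacking below).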

In [None]:
# Leaf tensor [0., 1., 2.] (float64) that autograd tracks from creation
x = torch.arange(3, dtype=torch.float64, requires_grad=True)

# Scalar output: y = sum_i x_i^2
y = (x ** 2).sum()

# torch.autograd.grad returns a tuple (one entry per input); unpack the single gradient dy/dx
(grad,) = torch.autograd.grad(y, x)

In [None]:
# Guess the result before running: y = sum(x**2) evaluated at x = [0, 1, 2].
# The last expression of the cell is displayed by the notebook (no print needed).
grad

## Sampling via Langevin Markov Chain Monte Carlo (MCMC)

Iterative sampling starting from uniform noise $x(0)\sim \mathcal{U}(-1, 1)$:

$$
x(t+1) = x(t) + \epsilon\cdot \nabla_{x(t)} f_{\theta}(x(t)) + z(t)
$$
where $f_{\theta}$ is the model output (playing the role of an unnormalized log-density) and
$$
z(t) \sim \mathcal{N}(0, 2\epsilon I)
$$

In practice, we tune the standard deviation of $z(t)$ manually instead of using the theoretical value $\sqrt{2\epsilon}$.

The default (untweaked) version of this procedure is very unstable.

Tweaks to stabilize the training:
- clip the gradient $\nabla_{x(t)} f_{\theta}(x(t))$ (either its values or its norm)
- clip the values of $x(t)$ to $[-1, 1]$ after adding the noise and again after applying the gradient step

In [None]:
def generate_samples(model, x_hat, steps=60, step_size=10, noise_std=0.005, grad_clip=0.03):
    """
    Simplified version of the generate_samples function in utils.py.

    Runs `steps` iterations of stabilized Langevin MCMC starting from `x_hat`.
    Each iteration adds Gaussian noise and clamps to [-1, 1]. It then takes a
    gradient-ascent step on the summed model output w.r.t. the input, with the
    gradient clamped elementwise to [-grad_clip, grad_clip]. Finally it clamps
    to [-1, 1] again.

    Args:
        model: module scoring a batch of inputs; its output is summed before
            differentiating w.r.t. the input.
        x_hat: starting samples. Not modified: a detached copy is used.
        steps: number of MCMC iterations.
        step_size: multiplier applied to the clipped gradient in each step.
        noise_std: standard deviation of the noise added each step. The
            theoretical value sqrt(2 * step_size) turned out to be far too
            large (bad results); 0.005 worked better.
        grad_clip: elementwise bound on the input gradient, for stability.

    Returns:
        Detached tensor of shape (steps, *x_hat.shape) with the samples after
        each iteration (size 0 along dim 0 when steps == 0).

    The model parameters' requires_grad flags are switched off during sampling
    (only gradients w.r.t. the input are needed). They are restored to their
    previous values afterwards, even if an exception is raised.
    """
    params = list(model.parameters())
    saved_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)

    # Work on a detached copy: the caller's tensor is never mutated, and no
    # autograd history from the caller leaks into the sampling loop.
    x_hat = x_hat.detach().clone()

    # Allocated once and refilled in place each iteration (cheaper than
    # creating a new tensor every step).
    noise = torch.randn(x_hat.shape, device=x_hat.device, dtype=x_hat.dtype)

    x_steps = []
    try:
        # Gradients w.r.t. the input are required even if the caller is
        # inside a torch.no_grad() block.
        with torch.enable_grad():
            for _ in range(steps):
                noise.normal_(0, noise_std)
                # Fresh leaf each step, so the graph never spans iterations
                # (otherwise memory grows linearly with `steps`).
                x_hat = torch.clamp(x_hat + noise, -1, 1).requires_grad_(True)

                # Gradient of the summed model output w.r.t. the input.
                out = model(x_hat).sum()
                grad, = torch.autograd.grad(out, x_hat)

                # For stability: limit the magnitude of each gradient entry.
                grad = torch.clamp(grad, -grad_clip, grad_clip)

                # Langevin step (ascent on the model output), then clip again.
                x_hat = torch.clamp(x_hat.detach() + step_size * grad, -1, 1)

                # x_hat is a new tensor every iteration and is never modified
                # in place afterwards, so it can be stored without cloning.
                x_steps.append(x_hat)
    finally:
        # Restore each parameter's original flag rather than forcing True.
        for p, flag in zip(params, saved_flags):
            p.requires_grad_(flag)

    if not x_steps:
        return x_hat.new_empty((0, *x_hat.shape))
    return torch.stack(x_steps, dim=0)

In [None]:
EBM_CHECKPOINT = "saved_models/EBM.pt"

model = CNNModel()

# map_location="cpu" remaps tensors saved during a CUDA run so the checkpoint
# can be evaluated on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
loaded_state_dict = torch.load(EBM_CHECKPOINT, map_location="cpu")
model.load_state_dict(loaded_state_dict)

# eval() returns the module itself, so the notebook displays its repr
model.eval()

In [None]:
N_SAMPLES = 10      # independent chains (one row each in the figure)
N_MCMC_STEPS = 500
N_SNAPSHOTS = 10    # intermediate states shown per chain (columns)

x_0 = torch.rand((N_SAMPLES, 1, 28, 28)) * 2 - 1  # start from uniform noise in [-1, 1)
x_steps = generate_samples(model, x_0, steps=N_MCMC_STEPS, step_size=10)

# Evenly spaced step indices that always include the final sample
# (robust to changes of N_MCMC_STEPS, unlike a hard-coded `i // 50`).
snapshot_idx = torch.linspace(0, x_steps.shape[0] - 1, N_SNAPSHOTS).round().long().tolist()

n_chains = x_steps.shape[1]
fig, axes = plt.subplots(nrows=n_chains, ncols=len(snapshot_idx),
                         figsize=(14, 1.4 * n_chains), squeeze=False)
for k in range(n_chains):
    for ax, i in zip(axes[k], snapshot_idx):
        # Fixed vmin/vmax because samples live in [-1, 1]: per-image autoscaling
        # would make early noise and late digits look equally contrasted.
        ax.imshow(x_steps[i, k].view(28, 28).detach().cpu().numpy(),
                  cmap='binary', vmin=-1, vmax=1)
        ax.set_axis_off()
for ax, i in zip(axes[0], snapshot_idx):
    ax.set_title(f'step {i + 1}', fontsize=8)
plt.show()