## Energy-Based Model, with Hamiltonian Monte Carlo Sampling

### Tips

+ Training with naive contrastive-divergence algorithm will make your model diverge quickly (think about why). Therefore, you need to add a L2 regularization term $\alpha(E_\theta(x+)^2 + E_\theta(x-)^2)$ to stabilize training.

+ Take a look at the paper [Implicit Generation and Generalization in Energy Based Models](https://arxiv.org/pdf/1903.08689.pdf) to learn more about useful tricks to get your model working.

## Set Up Code

If you use Colab in this coding project, please uncomment the code, fill the `GOOGLE_DRIVE_PATH_AFTER_MYDRIVE` and run the following cells to mount your Google drive. Then, the notebook can find the required file (i.e., utils.py). If you run the notebook locally, you can skip the following cells.

In [1]:
%load_ext autoreload
%autoreload 2
# %matplotlib widget

# from google.colab import drive
# drive.mount('/content/drive')

# import sys
# sys.path.append(GOOGLE_DRIVE_PATH)

!nvidia-smi
!which python

Sat Sep 21 17:02:41 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.86.10              Driver Version: 535.86.10    CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  | 00000004:04:00.0 Off |                    0 |
| N/A   55C    P0              71W / 300W |      0MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
from ebm_hamiltonian.models.utils import hello
hello()

Files already downloaded and verified
Files already downloaded and verified
Good luck!


In [3]:
from collections import deque
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm import tqdm

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F

from math import sqrt

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

from ebm_hamiltonian.models.utils import save_model, load_model, corruption, train_set, val_set

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

os.makedirs('./ebm', exist_ok=True)

## Enenrgy Model

In [4]:
from ebm_hamiltonian.models.resnet import ResNet_MNIST, CNN_MNIST, SNCNN_MNIST
from ebm_hamiltonian.models.basicblock import BasicBlock

# input size: 1x28x28

def MLP(hidden_dim=None):
    model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28*28, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512), 
        nn.ReLU(),
        # nn.Linear(512, 512), 
        # nn.ReLU(),
        nn.Linear(512, 1)
    )
    return model

def resnet34(hidden_dim=128):
    """Constructs a ResNet-34 model."""
    model = ResNet_MNIST(block=BasicBlock, 
                         hidden_dim=hidden_dim,
                   layers=[5, 5, 5]
                   )
    return model

def Swish(x):
    return x * torch.sigmoid(x)

def cnn(hidden_dim=128, activation=Swish):
    model = CNN_MNIST(block=BasicBlock, hidden_dim=hidden_dim, activation=activation)
    return model

def sncnn(hidden_dim=128, activation=Swish):
    model = SNCNN_MNIST(block=BasicBlock, hidden_dim=hidden_dim, activation=activation)
    return model


## Sampling

Langevin dynamics. Pay attention to the gradients of both your energy model and input.

Hamiltonian Monte Carlo. TODO

In [5]:
def get_energy_grad(energy_model, x):
    x = x.clone().detach()
    x.requires_grad = True
    energy_model.zero_grad()
    energy = energy_model(x)
    energy.sum().backward()
    energy_model.zero_grad()
    # print('x.grad:', x.grad.min(), x.grad.max())
    return x.grad

def langevin_step(energy_model, x, eps, step_lr, max_grad_norm=None):
    """
    Perform one step of Langevin dynamics sampling.

    Args:
        energy_model (nn.Module): The energy-based model used for sampling.
        x (torch.Tensor): The input tensor to update via Langevin dynamics.
        step_lr (float): The step size of the optimizer used to update the input. 
        eps (float): noise level
        max_grad_norm (float or None): The maximum norm of the gradient for gradient clipping.

        x <- x + eps^2/2 * grad_x log p(x) + eps * N(0, I)
            = x - eps^2/2 * grad_x E(x) + eps * N(0, I)

    Returns:
        torch.Tensor: The updated input tensor after one step of Langevin dynamics.
    """
    gradient = get_energy_grad(energy_model, x)

    noise = torch.randn_like(x) * (eps**0.5)

    # if max_grad_norm is not None:
    #     gradient = gradient.clamp(-max_grad_norm, max_grad_norm)
    # gradient = gradient.reshape(x.shape[0],-1)

    # if max_grad_norm is not None:
    #     do_clip = gradient.norm(dim=-1,keepdim=True) > max_grad_norm
    #     gradient = torch.where(do_clip, gradient / gradient.norm(dim=-1, keepdim=True) * max_grad_norm, gradient)
    # gradient = gradient.reshape_as(x)
    # print('gradient.range:', gradient.min(), gradient.max())

    x_new = x - step_lr * gradient + noise
    
    x_new = torch.clamp(x_new, 0, 1)  # Ensure pixel values are between 0 and 1

    return x_new

def hamiltonian_step(energy_model, x, A, B, n_steps, max_grad_norm=None, reject=False, bound=True):
    """
    Perform one step of Hamiltonian MCMC.

    Args:
        energy_model (nn.Module): The energy-based model used for sampling.
        x (torch.Tensor): The input tensor to update via Langevin dynamics.
        A: the magnitude of change of image per pixel
        B: should be the same magnitude as the gradient of energy
        eps (float): step size
        p_norm: the initial momentum norm
        n_steps (int): number of leapfrog steps before resampling p (or the L in the paper)
        max_grad_norm (float or None): The maximum norm of the gradient for gradient clipping. (optional)
        reject: if True, we may reject due to energy difference

    Returns:
        torch.Tensor: The updated input tensor after one step of Langevin dynamics.
    """
    eps = sqrt(A/B)
    p_norm = sqrt(A*B)
    x_init = x.clone().detach()
    gradient = get_energy_grad(energy_model, x)
        
    p = torch.randn_like(x) * p_norm
    assert p.shape == torch.Size([x.shape[0], 1, 28, 28])

    init_total_energy = energy_model(x).squeeze(1) + 0.5 * p.pow(2).sum(dim=(1, 2, 3))

    p = p - 0.5 * eps * gradient

    for i in range(n_steps):
        x = x + eps * p
        if bound:
            p = torch.where(x < 0, -p, p)
            p = torch.where(x > 1, -p, p)
            x = torch.where(x < 0, -x, x)
            x = torch.where(x > 1, 2 - x, x)
        x = torch.clamp(x, 0, 1)
        if i != n_steps - 1:
            gradient = get_energy_grad(energy_model, x)
            p = p - eps * gradient
        # print('x:',x[0,0,0,0])
        # print('p:',p[0,0,0,0])

    gradient = get_energy_grad(energy_model, x)
    p = p - 0.5 * eps * gradient

    current_energy = energy_model(x).squeeze(1) + 0.5 * p.pow(2).sum(dim=(1, 2, 3))

    # print("total energy difference: ", (current_energy - init_total_energy).mean().item())
    assert init_total_energy.shape == current_energy.shape and init_total_energy.shape == torch.Size([x.shape[0]])

    # optional: decide whether to reject
    if reject:
        accept_prob = torch.exp(init_total_energy - current_energy)
        accept = torch.rand(x.shape[0]).to(x.device) < accept_prob
        x = torch.where(accept[:, None, None, None], x, x_init)

    return x

In [6]:
# model = cnn().to(device)

In [7]:
# import math
# sqrt = math.sqrt
# def test(use_hamiltonian=False):
#     x = torch.rand(256, 1, 28, 28).to(device)
#     # print(x)
#     # model = MLP().to(device)
#     print("energy: ", model(x).sum().item())
#     if use_hamiltonian:
#         A = 2e-2 # the change of x
#         B = 2e-4 # the change of p / eps, which should be relative to the gradient
#         print("using hamiltonian")
#         for i in range(5):
#             x = hamiltonian_step(model, x, A=A, B=B, n_steps=10, max_grad_norm=0.03, reject=True)
#             # print(x)
#             print("energy:" ,model(x).sum().item())
#     else:
#         print("using langevin")
#         for i in range(100):
#             x = langevin_step(model, x, eps=0.01, step_lr=1e3, max_grad_norm=0.03)
#             # print(x)
#             print("energy:" ,model(x).sum().item())

# test(use_hamiltonian=1)
# # eps * p_norm = A = 3e-2
# # or: eps * grad (1e-2) = 5e-2
# # p_norm = eps * B (B = 1e-2)
# # eps = sqrt(A/B)
# # p_norm = sqrt(A*B)

MLP: 结果差不多, HMC稍微好一点

__langevin__: 0.2s, 60 steps, step_lr=10, eps=0.01, energy -> -210
          0.3s, 100 steps, step_lr=10, eps=0.01, energy -> -205

__hmc__: 0.0s, 1 group of 10 steps, eps = sqrt3, p_norm = sqrt(3e-4), energy -> -117

0.3s, 10 groups of 10 steps, eps = sqrt3, p_norm = sqrt(3e-4), energy -> -230

0.3s, 10 groups of 5 steps, eps = sqrt3, p_norm = sqrt(3e-4), energy -> -203
     
0.4s, 20 groups of 5 steps, eps = sqrt3, p_norm = sqrt(3e-4), energy -> -224

resnet: 不太稳定

cnn: 

__langevin__: 1.9s, 100 steps, step_lr=1000, eps=0.01, energy 12 -> 9

__hmc__: 0.6s, 3 groups of 10 steps, A = 2e-2, B = 2e-4, energy 12 -> 9

## Inpainting

Implement the inpainting procedure. Think about the difference between sampling and inpainting.

In [8]:
def inpainting(energy_model, x, mask, n_steps, step_lr, max_grad_norm, eps=0.005, add_noise=True):
    """
    Inpainting function that completes an image given a masked input using Langevin dynamics.

    Args:
        energy_model (nn.Module): The energy-based model used to generate the image.
        x (torch.Tensor): The input tensor, a masked image that needs to be completed.
        mask (torch.Tensor): The mask tensor, with the same shape as x, where 1 indicates the corresponding
                             pixel is visible and 0 indicates it is missing.
        n_steps (int): The number of steps of Langevin dynamics to run.
        step_lr (float): The step size of Langevin dynamics.
        max_grad_norm (float or None): The maximum gradient norm to be used for gradient clipping. If None, 
                                       no gradient clipping is performed.

    Returns:
        torch.Tensor: The completed image tensor.
    """
    reverse_mask = 1-mask


    for _ in range(n_steps):
        gradient = get_energy_grad(energy_model, x)
        if max_grad_norm is not None:
            gradient = gradient.clamp(-max_grad_norm, max_grad_norm)

        add_noise = False
        if add_noise:
            noise = eps * torch.randn_like(x) * reverse_mask
        else:
            noise = 0
        if _ == n_steps-1:
            noise = 0

        x = x - step_lr * gradient * reverse_mask + noise
        x = torch.clamp(x, 0, 1)  # Ensure pixel values are between 0 and 1


    return x

def inpainting_hmc(energy_model, x, mask, n_steps_per_sample, n_samples, A, B, max_grad_norm, reject=False, bound=True):
    """
    Inpainting function that completes an image given a masked input using MHC.

    Args:
        energy_model (nn.Module): The energy-based model used to generate the image.
        x (torch.Tensor): The input tensor, a masked image that needs to be completed.
        mask (torch.Tensor): The mask tensor, with the same shape as x, where 1 indicates the corresponding
                             pixel is visible and 0 indicates it is missing.
        n_steps_per_sample (int): The number of leapfrog steps to run for each sample.
        n_samples (int): The number of samples of p to generate.
        A & B: the same as in hamiltonian_step
        max_grad_norm (float or None): The maximum gradient norm to be used for gradient clipping. If None, 
                                       no gradient clipping is performed.
        reject: if True, we may reject due to energy difference

    Returns:
        torch.Tensor: The completed image tensor.
    """

    eps = sqrt(A/B)
    p_norm = sqrt(A*B)

    reverse_mask = 1-mask

    current_energy = energy_model(x).squeeze(1)

    for _ in range(n_samples):
        x_init = x.clone().detach()
        gradient = get_energy_grad(energy_model, x)
            
        p = torch.randn_like(x) * p_norm * reverse_mask # we only allow moving the pixels that are free
        assert p.shape == torch.Size([x.shape[0], 1, 28, 28])

        init_total_energy = current_energy + 0.5 * p.pow(2).sum(dim=(1, 2, 3))

        p = p - 0.5 * eps * gradient * reverse_mask

        for i in range(n_steps_per_sample):
            x = x + eps * p * reverse_mask
            if bound:
                p = torch.where(x < 0, -p, p)
                p = torch.where(x > 1, -p, p)
                x = torch.where(x < 0, -x, x)
                x = torch.where(x > 1, 2 - x, x)
            x = torch.clamp(x, 0, 1)
            if i != n_steps_per_sample - 1:
                gradient = get_energy_grad(energy_model, x)
                p = p - eps * gradient * reverse_mask
            # print('x:',x[0,0,13,15])
            # print('p:',p[0,0,13,15])

        gradient = get_energy_grad(energy_model, x)
        p = p - 0.5 * eps * gradient * reverse_mask

        new_energy = energy_model(x).squeeze(1)
        new_total_energy = new_energy + 0.5 * p.pow(2).sum(dim=(1, 2, 3))

        # print("total energy difference: ", (new_total_energy - init_total_energy).mean().item())

        # optional: decide whether to reject
        if reject:
            accept_prob = torch.exp(init_total_energy - new_total_energy)
            accept = torch.rand(x.shape[0]).to(x.device) < accept_prob
            x = torch.where(accept[:, None, None, None], x, x_init)
            current_energy = torch.where(accept, new_energy, current_energy)
        else:
            current_energy = new_energy

    return x

In [9]:
def evaluate(energy_model, val_loader, use_hmc, sh, device=device):
    """
    Evaluates the energy model on the validation set and returns the corruption MSE,
    recovered MSE, corrupted images, and recovered images for visualization.

    Args:
        energy_model (nn.Module): Trained energy-based model.
        val_loader (torch.utils.data.DataLoader): Validation data loader.
        use_hmc (bool)
        sh (sampling_hyperparamters): a dictionary of parameters, depending on whether to use hmc
        langevin:
            n_sample_steps (int): Number of Langevin dynamics steps to take when sampling.
            step_lr (float): Learning rate to use during Langevin dynamics.
            langevin_grad_norm (float): Maximum L2 norm of the Langevin dynamics gradient.
            eps & add_noise
        hmc:
            n_steps_per_sample (int): Number of leapfrog steps to run for each sample.
            n_samples (int): Number of samples of p to generate.
            A & B: the same as in hamiltonian_step
            max_grad_norm (float or None): The maximum gradient norm to be used for gradient clipping.
            reject: if True, we may reject due to energy difference
        device (str): Device to use (default='cuda').
    """
    mse = corruption_mse = 0
    energy_before_sampling = energy_after_sampling = 0
    n_batches = 0
    energy_model.eval()

    if use_hmc:
        n_steps_per_sample = sh['n_steps_per_sample']
        n_samples = sh['n_samples']
        A = sh['A']
        B = sh['B']
        max_grad_norm = sh['max_grad_norm']
        reject = sh['reject']
        bound = sh.get('bound', True)

    else:
        n_sample_steps = sh['n_sample_steps']
        step_lr = sh['step_lr']
        langevin_grad_norm = sh['langevin_grad_norm']
        add_noise = sh['add_noise']
        eps = sh['eps']

    pbar = tqdm(total=len(val_loader.dataset))
    pbar.set_description('Eval')
    for data, _ in val_loader:
        n_batches += data.shape[0]
        data = data.to(device)
        broken_data, mask = corruption(data, type_='ebm')
        energy_before_sampling += energy_model(broken_data).sum().item()
        if use_hmc:
            recovered_img = inpainting_hmc(energy_model, broken_data, mask,
                                           n_steps_per_sample, n_samples, A, B, max_grad_norm, reject, bound)
        else:
            recovered_img = inpainting(energy_model, broken_data, mask,
                                   n_sample_steps, step_lr, langevin_grad_norm, add_noise=add_noise, eps=eps)
        energy_after_sampling += energy_model(recovered_img).sum().item()

        mse += np.mean((data.detach().cpu().numpy().reshape(-1, 28 * 28) - recovered_img.detach().cpu().numpy().reshape(-1, 28 * 28)) ** 2, -1).sum().item()
        corruption_mse += np.mean((data.detach().cpu().numpy().reshape(-1, 28 * 28) - broken_data.detach().cpu().numpy().reshape(-1, 28 * 28)) ** 2, -1).sum().item()

        pbar.update(data.shape[0])
        pbar.set_description('Corruption MSE: {:.6f}, Recovered MSE: {:.6f}, E Before Sampling: {:.6f}, E After Sampling: {:.6f}'.format(
            corruption_mse / n_batches, mse / n_batches, energy_before_sampling / n_batches, energy_after_sampling / n_batches))
        # show the image
        # import matplotlib.pyplot as plt
        # plt.imshow(recovered_img[0].detach().cpu().numpy().reshape(28, 28), cmap='gray')
        # plt.show()
        # break

    pbar.close()
    return (corruption_mse / n_batches, mse / n_batches, data[:100].detach().cpu(), broken_data[:100].detach().cpu(), recovered_img[:100].detach().cpu())


## Training

In [10]:
import random
import matplotlib.pyplot as plt

def train(n_epochs, energy_model, train_loader, val_loader, optimizer, use_hmc, sh, l2_alpha, device=device, buffer_maxsize=int(1e4), replay_ratio=0.95, 
          save_interval=1, max_norm=None):
    if use_hmc:
        n_steps_per_sample = sh['n_steps_per_sample']
        n_samples = sh['n_samples']
        A = sh['A']
        B = sh['B']
        max_grad_norm = sh.get('max_grad_norm', None)
        reject = sh.get('reject', False)
        bound = sh.get('bound', True)
    else:
        n_sample_steps = sh['n_sample_steps']
        step_lr = sh['step_lr']
        langevin_grad_norm = sh['langevin_grad_norm']
        langevin_eps = sh['langevin_eps']
        eval_step_lr = sh.get('eval_step_lr', None)
        if eval_step_lr is None:
            eval_step_lr = step_lr
    energy_model.to(device)
    buffer_maxsize = int(1e4)
    buffer_size = buffer_ptr = 0
    replay_buffer = torch.zeros(buffer_maxsize, 1, 28, 28)
    buffer_size = buffer_ptr = 0
    best_mse = np.inf

    def show_replay_buffer():
        for i in range(4):
            idx = np.random.randint(0, buffer_size)
            # print(replay_buffer[idx])
            plt.imshow(replay_buffer[idx].squeeze().cpu().numpy(), cmap='gray')
            # plt.show()
            if os.path.exists(f'./ebm/replay_buffer/{epoch + 1}') == False:
                os.makedirs(f'./ebm/replay_buffer/{epoch + 1}')
            plt.savefig(f'./ebm/replay_buffer/{epoch + 1}/{idx}.png')

    for epoch in range(n_epochs):
        train_loss = energy_before = energy_plus = energy_minus = n_batches = 0
        pbar = tqdm(total=len(train_loader.dataset))
        pbar.set_description('Train')
        for i, (x_plus, _) in enumerate(train_loader):
            n_batches += x_plus.shape[0]
            bs = x_plus.shape[0]
            # x_plus shape:  torch.Size([256, 1, 28, 28])

            # init negative samples
            if buffer_size == 0:
                x_minus = torch.rand_like(x_plus)
            else:
                x_minus = torch.zeros_like(x_plus)

                for j in range(bs):
                    p = random.random()
                    if (p < replay_ratio): 
                    # Sample a random index from the replay buffer
                        idx = np.random.randint(0, buffer_size)
                    # Get the corresponding sample from the replay buffer
                        x_minus[j] = replay_buffer[idx]
                    else:
                        x_minus[j] = torch.rand_like(x_plus[j])

            x_minus = x_minus.to(device)
            energy_before += energy_model(x_minus).sum().item()

            # sample negative samples

            if use_hmc:
                for _ in range(n_samples):
                    x_minus = hamiltonian_step(energy_model, x_minus, A, B, n_steps_per_sample, max_grad_norm=max_grad_norm, reject=reject, bound=bound)
            else:
                for _ in range(n_sample_steps):
                    x_minus = langevin_step(energy_model, x_minus, langevin_eps, step_lr=step_lr, max_grad_norm=langevin_grad_norm)

            # extend buffer
            if buffer_ptr + bs <= buffer_maxsize:
                replay_buffer[buffer_ptr: buffer_ptr +
                              bs] = ((x_minus * 255).to(torch.uint8).float() / 255).cpu()
            else:
                x_minus_ = (
                    (x_minus * 255).to(torch.uint8).float() / 255).cpu()
                replay_buffer[buffer_ptr:] = x_minus_[
                    :buffer_maxsize - buffer_ptr]
                remaining = bs - (buffer_maxsize - buffer_ptr)
                replay_buffer[:remaining] = x_minus_[
                    buffer_maxsize - buffer_ptr:]
            buffer_ptr = (buffer_ptr + bs) % buffer_maxsize
            buffer_size = min(buffer_maxsize, buffer_size + bs)

            # compute loss
            energy_model.train()
            x_plus = x_plus.to(device)
            x_minus = x_minus.to(device)
            e_plus = energy_model(x_plus)
            e_minus = energy_model(x_minus)

            # loss = e_plus + torch.clamp(delta - e_minus, min=0) 
            energy_loss = e_plus - e_minus
            # add a term: l2alpha x (e_plus ^ 2 + e_minus ^ 2)
            l2_loss = l2_alpha * (e_plus ** 2 + e_minus ** 2)
            loss = energy_loss.mean() + l2_loss.mean()
            optimizer.zero_grad()
            loss.backward()
            
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(energy_model.parameters(), max_norm=max_norm)  # Gradient clipping
            optimizer.step()

            ##############################################################################

            train_loss += loss.sum().item()
            energy_plus += e_plus.sum().item()
            energy_minus += e_minus.sum().item()

            pbar.update(x_plus.size(0))
            pbar.set_description("Epoch {}, Train Loss: {:.6f}, ".format(epoch + 1, train_loss / n_batches) +
                                 "E Before Sampling: {:.6f}, ".format(energy_before / n_batches) +
                                 "After: {:.6f}, ".format(energy_minus / n_batches) +
                                 "Ground Truth: {:.6f}".format(energy_plus / n_batches))
        pbar.close()
        show_replay_buffer()

        if (epoch + 1) % save_interval == 0:
            os.makedirs(f'./ebm/{epoch + 1}', exist_ok=True)
            energy_model.eval()
            save_model(f'./ebm/{epoch + 1}/ebm.pth',
                       energy_model, optimizer, replay_buffer)

            # evaluate inpaiting
            # feel free to change the inpainting parameters!
            if use_hmc:
                sh = {
                    'n_steps_per_sample': n_steps_per_sample,
                    'n_samples': n_samples,
                    'A': A,
                    'B': B,
                    'max_grad_norm': max_grad_norm,
                    'reject': reject,
                    'bound': bound
                }
            else:
                sh = {
                    'n_sample_steps': n_sample_steps,
                    'step_lr': eval_step_lr,
                    'langevin_grad_norm': langevin_grad_norm,
                    'eps': langevin_eps,
                    'add_noise': False
                }
            c_mse, r_mse, original, broken, recovered = evaluate(
                energy_model, val_loader, use_hmc, sh)
            torchvision.utils.save_image(
                original, f"./ebm/{epoch + 1}/groundtruth.png", nrow=10)
            torchvision.utils.save_image(
                broken, f"./ebm/{epoch + 1}/corrupted.png", nrow=10)
            torchvision.utils.save_image(
                recovered, f"./ebm/{epoch + 1}/recovered.png", nrow=10)
            if r_mse < best_mse:
                print(f'Current best MSE: {best_mse} -> {r_mse}')
                best_mse = r_mse
                save_model('./ebm/ebm_best.pth', energy_model)

Here we define the model!

In [11]:
model = cnn(hidden_dim=1024).to(device)
# model = MlpBackbone((1, 28, 28), 1024).to(device)
model = model.to(device)
# for param in model.parameters():
#     if param.dim() > 1:
#         nn.init.orthogonal_(param)


print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.0, 0.999))

train_loader = DataLoader(train_set, 256, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(val_set, 500, shuffle=True, drop_last=False, pin_memory=True)

CNN_MNIST(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pool1): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (conv3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (res_layer): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (2): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (fc1)

Now start training!

In [12]:
# training hyper-parameters!
use_hmc = True
epochs = 5
l2_alpha = 0.01
max_norm = 1.0

# when using langevin dynamics
n_sample_steps = 60
step_lr = 100
eps = 5e-3
langevin_grad_norm = 0.3
replay_ratio = 0.95
eval_step_lr = 10

# when using hamiltonian dynamics
n_steps_per_sample = 10
n_samples = 5
A = 2e-2
B = 2e-4
max_grad_norm = 0.03
reject = True
bound = True

if use_hmc:
      sh = {
            'n_steps_per_sample': n_steps_per_sample,
            'n_samples': n_samples,
            'A': A,
            'B': B,
            'max_grad_norm': max_grad_norm,
            'reject': reject,
            'bound': bound
      }

else:
      sh = {
            'n_sample_steps': n_sample_steps,
            'step_lr': step_lr,
            'langevin_grad_norm': langevin_grad_norm,
            'langevin_eps': eps,
            'add_noise': False
      }

train(epochs, model, train_loader, val_loader, optimizer, use_hmc, sh, l2_alpha, device=device, max_norm=max_norm)

# n_sample_steps, step_lr, langevin_eps, langevin_grad_norm, l2_alpha

Epoch 1, Train Loss: -0.018160, E Before Sampling: 2.272072, After: 2.223962, Ground Truth: -3.667976:  14%|▏| 8448/60000 [00:35<03:30, 244.51it/s]  

KeyboardInterrupt: 

Epoch 1, Train Loss: -0.018160, E Before Sampling: 2.272072, After: 2.223962, Ground Truth: -3.667976:  14%|▏| 8448/60000 [00:49<03:30, 244.51it/s]

In [14]:
import matplotlib.pyplot as plt

def show_recover(device = device) :
    for data, _ in val_loader:
        x = data[0]
        x = x.to(device)
        broken_data, mask = corruption(x, type_='ebm')
        # print(broken_data)
        # show the energy of broken_data, first resize the broken_data to (1, 1, 28, 28)
        broken_data = broken_data.unsqueeze(0)
        
        # Calculate the energy of the broken_data using the energy_model
        model.eval()
        energy = model(broken_data)
        print("Energy of broken_data:", energy.item())

        # Perform inpainting on the broken_data
        recovered_img = inpainting(model, broken_data, mask, 500, 0.05, langevin_grad_norm)
        print(recovered_img - broken_data)
        # show the energy of recovered_img
        recovered_img = recovered_img.squeeze(0)
        recovered_img = recovered_img.unsqueeze(0)
        energy = model(recovered_img)
        print("Energy of recovered_img:", energy.item())
        # show the recovered image and the corrupted image
        plt.imshow(recovered_img.squeeze().cpu().numpy(), cmap='gray')
        plt.show()
        plt.imshow(broken_data.squeeze().cpu().numpy(), cmap='gray')

        # show the image
        plt.imshow(broken_data[0].squeeze().cpu().numpy(), cmap='gray')
        break
# show_recover()

## Evaluation

In [27]:
# evaluation parameters!
use_hmc = 1

# when using langevin dynamics
n_sample_steps = 60
step_lr = 50
langevin_grad_norm = 0.3
add_noise = True
eps = 0.02

# when using hamiltonian dynamics
n_steps_per_sample = 10
n_samples = 10
A = 2e-2
B = 2e-4
max_grad_norm = 0.03
reject = True
bound = False


In [28]:
if use_hmc:
    sh = {
        'n_steps_per_sample': n_steps_per_sample,
        'n_samples': n_samples,
        'A': A,
        'B': B,
        'max_grad_norm': max_grad_norm,
        'reject': reject,
        'bound': bound
    }

else:
    sh = {
        'n_sample_steps': n_sample_steps,
        'step_lr': step_lr,
        'langevin_grad_norm': langevin_grad_norm,
        'add_noise': add_noise,
        'eps': eps
    }

model.load_state_dict(load_model('./ebm/ebm_best.pth')[0])
corruption_mse, mse, original, broken, recovered = evaluate(
    model, val_loader, use_hmc, sh)
os.makedirs(f'./ebm/eval', exist_ok=True)

torchvision.utils.save_image(
                original, f"./ebm/eval/groundtruth.png", nrow=10)
torchvision.utils.save_image(
                broken, f"./ebm/eval/corrupted.png", nrow=10)
s = f"./ebm/eval/recovered_langevin.png" if use_hmc == 0 else f"./ebm/eval/recovered_hmc.png"
torchvision.utils.save_image(
                recovered, s, nrow=10)


print(f'Corruption MSE: {corruption_mse}')
print(f'Recovered MSE: {mse}')


Corruption MSE: 0.062915, Recovered MSE: 0.010558, E Before Sampling: 0.002118, E After Sampling: -0.002617:  15%|▏| 1500/10000 [00:10<01:00, 141.63it/s]

KeyboardInterrupt: 

hmc 5*10 36.8s
langevin 100 1m2.8s
langevin 60 39.1s