---

## Deep Learning Coding Project 3-1: Energy-Based Model

Before we start, please put your **Chinese** name and student ID in following format:

Name, 0000000000 // e.g.) 傅炜, 2021123123

YOUR ANSWER HERE

## Introduction

We will use Python 3, [NumPy](https://numpy.org/), and [PyTorch](https://pytorch.org/) packages for implementation. This notebook has been tested under the latest stable release version.

In this coding project, you will implement 4 generative models, i.e., energy-based model, flow-based model, variational auto-encoder, and generative adverserial network, to generate MNIST images.

**We will implement an energy-based model in this notebook.**

In some cells and files you will see code blocks that look like this:

```Python
##############################################################################
#                  TODO: You need to complete the code here                  #
##############################################################################
raise NotImplementedError()
##############################################################################
#                              END OF YOUR CODE                              #
##############################################################################
```

You should replace `raise NotImplementedError()` with your own implementation based on the context, such as:

```Python
##############################################################################
#                  TODO: You need to complete the code here                  #
##############################################################################
y = w * x + b
##############################################################################
#                              END OF YOUR CODE                              #
##############################################################################

```

When completing the notebook, please adhere to the following rules:

+ Do not write or modify any code outside of code blocks
+ Do not add or delete any cells from the notebook.
+ Run all cells before submission. We will not re-run the entire codebook during grading.

**Finally, avoid plagiarism! Any student who violates academic integrity will be seriously dealt with and receive an F for the course.**

### Task

The energy-based method aims to train a parameterized model $E = f(x;\theta)$ to
model the unnormalized data distribution $p(x)\propto \exp(-E)$. In this notebook, we instantiate
$E = f(x;\theta)$ as an MLP. Your tasks are as follows:

1. **Implement all the missing parts in the contrastive-divergence training pipeline.**

Basically, we want to decrease
the energy of positive samples while increase the energy of negative samples. The positive samples are from the training set, and the negative
samples are sampled using Langevin dynamics starting from either random noise or previously generated samples.

2. **Implement an inpainting procedure to recover the original image.**

We corrupt the images by adding noise to the pixels in even rows (see
below). Please implement an inpainting procedure to recover the original
image, then report the
mean squared difference between your recovered images and the ground
truth images.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib import rcParams

%matplotlib inline

# figure size in inches optional
rcParams['figure.figsize'] = 11, 8

# read images
img_A = mpimg.imread('./ebm/groundtruth.png')
img_B = mpimg.imread('./ebm/corrupted.png')

# display images
fig, ax = plt.subplots(1, 2)
ax[0].imshow(img_A)
ax[1].imshow(img_B)

### Submission

You need to submit your code (this notebook), your trained model (named `./ebm/ebm_best.pth`), and your report:

+ **Code**

Remember to run all the cells before submission. Remain your tuned hyperparameters unchanged.

+ **Model**

In this notebook, we select the best model according to the MSE of inpainting. You can also manually test your models and select the best one. **Please do not submit any other checkpoints except for `./ebm/ebm_best.pth`!**

+ **Report**

Please include inpainting examples and the inpainting MSE on validation set in your
report. Note that you only need to write a single report for this coding project.

### Grading

Your implementation will be graded based on **the mean squared error
of inpainting**.

### Tips

+ Training with naive contrastive-divergence algorithm will make your model diverge quickly (think about why). Therefore, you need to add a L2 regularization term $\alpha(E_\theta(x+)^2 + E_\theta(x-)^2)$ to stabilize training.

+ Keep track of the generated samples during training to get a sense of how well your model is evolving.

+ You can take a look at the paper [Implicit Generation and Generalization in Energy Based Models](https://arxiv.org/pdf/1903.08689.pdf) to learn more about useful tricks to get your model working.

+ Make sure your code runs fine with the evaluation cell in this notebook.

## Set Up Code

If you use Colab in this coding project, please uncomment the code, fill the `GOOGLE_DRIVE_PATH_AFTER_MYDRIVE` and run the following cells to mount your Google drive. Then, the notebook can find the required file (i.e., utils.py). If you run the notebook locally, you can skip the following cells.

In [1]:
%load_ext autoreload
%autoreload 2

In [22]:
# from google.colab import drive
# drive.mount('/content/drive')

In [23]:
# import sys
# sys.path.append(GOOGLE_DRIVE_PATH)

In [2]:
from utils import hello
hello()

Good luck!


Finally, please run the following cell to import some base classes for implementation (no matter whether you use colab).

In [3]:
from collections import deque
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm.autonotebook import tqdm

import numpy as np
import os
import time
import torch
import torch.nn as nn
import torchvision

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

from utils import save_model, load_model, corruption, train_set, val_set

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True

device = torch.device(
    "cuda") if torch.cuda.is_available() else torch.device("cpu")

os.makedirs('./ebm', exist_ok=True)

  from tqdm.autonotebook import tqdm


## MLP Model

We have provided an example MLP implementation. Feel free to modify the following cell the implement your own model.

**Note that your model should be an MLP!**

In [4]:
import torch.nn.functional as F
class MlpBackbone(nn.Module):
    def __init__(self, input_shape, hidden_size, activation=F.elu):
        super(MlpBackbone, self).__init__()
        self.input_shape = input_shape  # (C, H, W)
        self.hidden_size = hidden_size

        # Layers
        self.fc1 = nn.Linear(np.prod(self.input_shape), self.hidden_size)

        self.fc2 = nn.Linear(self.hidden_size, int(self.hidden_size/2))

        self.fc3 = nn.Linear(int(self.hidden_size/2), int(self.hidden_size/4))

        #self.fc4 = nn.Linear(int(self.hidden_size/2), int(self.hidden_size/4))

        self.fc5 = nn.Linear(int(self.hidden_size/4), 1)  # Output layer

        self.activation = activation

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)  # Flatten input

        x = self.fc1(x)
        x = self.activation(x)

        x = self.fc2(x)
        x = self.activation(x)

        x = self.fc3(x)
        x = self.activation(x)

        #x = self.fc4(x)
        #x = self.activation(x)

        out = self.fc5(x)                   # Output layer (no BatchNorm or activation)
        return out

## Sampling

Implement Langevin dynamics in the following cell. Pay attention to the gradients of both your energy model and input.

In [5]:
def langevin_step(energy_model, x, step_lr=10, eps=0.005, max_grad_norm=0.03, steps=1, annealing=False):
    """
    Perform multiple steps of Langevin dynamics sampling.

    Args:
        energy_model (nn.Module): The energy-based model used for sampling.
        x (torch.Tensor): The input tensor to update via Langevin dynamics.
        step_lr (float): The learning rate of the optimizer used to update the input.
        eps (float): The step size of the Langevin dynamics update.
        max_grad_norm (float or None): The maximum norm of the gradient for gradient clipping.
        steps (int): The number of Langevin dynamics steps to perform.

    Returns:
        torch.Tensor: The updated input tensor after multiple steps of Langevin dynamics.
    """
    # Set the energy model to evaluation mode and disable gradient
    is_training = energy_model.training
    energy_model.eval()
    x = x.detach().clone().requires_grad_(True)
    lr_now = step_lr
    eps_now = eps
    for p in energy_model.parameters():
        p.requires_grad = False
    
    for step in range(steps):
        # Compute the gradient of the energy w.r.t. the input tensor
        energy = energy_model(x)
        energy.backward(torch.ones_like(energy))
        grad = x.grad

        if annealing:
            lr_now = step_lr * (0.5 + 0.25*(1 + np.cos(np.pi * step / steps)))
            eps_now = eps * np.sqrt(0.1 + 0.45*(1 + np.cos(np.pi * step / steps)))
        # Optionally clip the gradient norm
        if max_grad_norm is not None:
            grad_norm = grad.norm(p=2)
            grad = grad * torch.clamp(grad_norm / max_grad_norm, max=1.0)
        
        # Update the input tensor using Langevin dynamics
        noise = eps_now * torch.randn_like(x)  # Add noise for stochasticity
        x = x - lr_now * grad + noise
        
        # Clamp values to [0,1] range since these are image pixels
        x = torch.clamp(x, 0, 1)
        
        # Reset gradients for the input tensor
        x = x.detach().clone().requires_grad_(True)
    
    for p in energy_model.parameters():
        p.requires_grad = True
    energy_model.train(is_training)
    
    return x.detach()


## Inpainting

Implement the inpainting procedure. Think about the difference between sampling and inpainting.

In [83]:
import torch
import numpy as np
from torchmetrics.functional import structural_similarity_index_measure as ssim_fn

def inpainting(
    energy_model, 
    x, 
    mask, 
    n_steps, 
    step_lr, 
    max_grad_norm, 
    max_noise=0.005,
    momentum=0.9, 
    use_momentum=False, 
    lr_schedule='cosine',
    adaptive_step=False,
    adapt_lr_factor=1.0,
    adapt_decay=0.9,
    eps_min=1e-6,
    early_stopping=True,
    tol=1e-4,
    patience=10,
    metric='ssim'
):
    """
    Inpainting function that completes an image given a masked input using Langevin
    dynamics with optional momentum, scheduling, adaptive step size, adaptive noise,
    and early stopping.

    Args:
        energy_model (nn.Module): The energy-based model used to generate the image.
        x (torch.Tensor): Masked image tensor to be completed.
        mask (torch.Tensor): Binary mask tensor (1 for observed pixel, 0 for missing).
        n_steps (int): Maximum number of Langevin steps to run.
        step_lr (float): Base step size (learning rate) for updates.
        max_grad_norm (float or None): If not None, clip gradient norms to this value.
        max_noise (float): Base noise amplitude for Langevin updates.
        momentum (float): Momentum factor for heavy-ball updates.
        use_momentum (bool): Whether to use momentum-based updates.
        lr_schedule (str): One of ['cosine', 'fixed', 'linear'] for scheduling step size/noise.
        adaptive_step (bool): Whether to adapt step size (and noise) based on gradient norms.
        adapt_lr_factor (float): Scaling factor for the adaptive step size.
        adapt_decay (float): EMA decay for running average of gradient norms.
        eps_min (float): Minimum value added to denominator (prevents division by zero).
        early_stopping (bool): If True, stops early if changes in the image are below `tol`.
        tol (float): Tolerance for the change in the image to consider convergence.
        patience (int): Number of consecutive steps with change below `tol` before stopping.
        metric (str): 'l2' or 'ssim' for measuring change in early stopping.

    Returns:
        torch.Tensor: The completed (inpainted) image tensor.
    """

    # --- PREPARE THE MODEL FOR EVALUATION ---
    was_training = energy_model.training
    energy_model.eval()
    for param in energy_model.parameters():
        param.requires_grad_(False)

    # --- INITIAL SETUP ---
    x = x.clone()  # Copy to avoid modifying the original
    device = x.device

    # Momentum velocity
    if use_momentum:
        velocity = torch.zeros_like(x)
    else:
        velocity = 0

    # Running average of gradient norms for adaptive step/noise
    if adaptive_step:
        running_avg_grad_norm = torch.tensor(1.0, device=device)  # avoid div by zero at init

    # Variables for early stopping
    if early_stopping:
        prev_x = x.clone()
        stop_counter = 0

    # --- HELPER FUNCTION FOR SCHEDULING ---
    def get_scheduled_val(step, total_steps, base_val, schedule_type):
        """
        Return a scaled value for either step size or noise, depending on schedule_type.
        """
        alpha = float(step) / float(total_steps)
        if schedule_type == "cosine":
            # Cosine schedule from base_val to 0.2 * base_val
            scale = 0.2 + 0.8 * (0.5 * (1.0 + np.cos(np.pi * alpha)))
            return base_val * scale
        elif schedule_type == "linear":
            return base_val * (1.0 - alpha)
        else:  # 'fixed'
            return base_val

    # --- MAIN LANGEVIN LOOP ---
    for step in range(n_steps):

        # 1. DETACH/REQUIRES_GRAD
        x_detach = x.detach().clone().requires_grad_(True)

        # 2. FORWARD AND BACKWARD
        energy = energy_model(x_detach)
        energy.backward(torch.ones_like(energy))
        grad = x_detach.grad

        # 3. GRADIENT CLIPPING
        if max_grad_norm is not None:
            grad_norm = grad.norm(p=2)
            if grad_norm > max_grad_norm:
                grad = grad * (max_grad_norm / grad_norm)
        else:
            grad_norm = grad.norm(p=2)

        # 4. ADAPTIVE STEP AND NOISE (IF ENABLED)
        if adaptive_step:
            # Update the running average of gradient norms
            running_avg_grad_norm = adapt_decay * running_avg_grad_norm + \
                                    (1 - adapt_decay) * grad_norm

            # Inversely scale step size by the running average grad norm
            step_lr_now = adapt_lr_factor / (running_avg_grad_norm + eps_min)
            
            # Inversely scale noise by running average grad norm
            # (If you prefer noise to remain fixed, you can skip this.)
            noise_now = max_noise / (running_avg_grad_norm + eps_min)
            noise_now = torch.clamp(noise_now, min=eps_min)

        else:
            # Use scheduled (or fixed) step sizes
            step_lr_now = get_scheduled_val(step, n_steps, step_lr, lr_schedule)
            noise_now   = get_scheduled_val(step, n_steps, max_noise, lr_schedule)

        # 5. LANGEVIN UPDATE
        langevin_noise = noise_now * torch.randn_like(x)
        if use_momentum:
            velocity = momentum * velocity - step_lr_now * grad + langevin_noise
            x = x + velocity
        else:
            x = x - step_lr_now * grad + langevin_noise

        # Enforce pixel range [0,1] (if desired)
        x = torch.clamp(x, 0, 1)

        # Preserve the observed pixels (mask=1) while updating only the missing ones (mask=0)
        x = mask * x.detach() + (1 - mask) * x

        # 6. EARLY STOPPING CHECK
        if early_stopping:
            if metric == 'l2':
                change = torch.norm(x - prev_x).item()
            elif metric == 'ssim':
                # Note that SSIM is 1.0 if the images are identical, so we measure 1 - ssim.
                change = 1.0 - ssim_fn(x, prev_x, data_range=1.0).item()
            else:
                raise ValueError("Unsupported metric for early stopping. Use 'l2' or 'ssim'.")

            if change < tol:
                stop_counter += 1
                if stop_counter >= patience:
                    print(f"Early stopping at step {step + 1}/{n_steps} with change={change:.6f}.")
                    break
            else:
                stop_counter = 0

            prev_x = x.clone()

    # --- RESTORE MODEL STATE ---
    for param in energy_model.parameters():
        param.requires_grad_(True)
    energy_model.train(was_training)

    return x


In [7]:
def evaluate(energy_model, val_loader, n_sample_steps, step_lr, langevin_grad_norm, device='cuda'):
    """
    Evaluates the energy model on the validation set and returns the corruption MSE,
    recovered MSE, corrupted images, and recovered images for visualization.

    Args:
        energy_model (nn.Module): Trained energy-based model.
        val_loader (torch.utils.data.DataLoader): Validation data loader.
        n_sample_steps (int): Number of Langevin dynamics steps to take when sampling.
        step_lr (float): Learning rate to use during Langevin dynamics.
        langevin_grad_norm (float): Maximum L2 norm of the Langevin dynamics gradient.
        device (str): Device to use (default='cuda').
    """
    mse = corruption_mse = 0
    energy_before_sampling = energy_after_sampling = 0
    n_batches = 0
    energy_model.eval()

    pbar = tqdm(total=len(val_loader.dataset))
    pbar.set_description('Eval')
    for data, _ in val_loader:
        n_batches += data.shape[0]
        data = data.to(device)
        broken_data, mask = corruption(data, type_='ebm')
        energy_before_sampling += energy_model(broken_data).sum().item()
        recovered_img = inpainting(energy_model, broken_data, mask,
                                   n_sample_steps, step_lr, langevin_grad_norm)
        energy_after_sampling += energy_model(recovered_img).sum().item()

        mse += np.mean((data.detach().cpu().numpy().reshape(-1, 28 * 28) - recovered_img.detach().cpu().numpy().reshape(-1, 28 * 28)) ** 2, -1).sum().item()
        corruption_mse += np.mean((data.detach().cpu().numpy().reshape(-1, 28 * 28) - broken_data.detach().cpu().numpy().reshape(-1, 28 * 28)) ** 2, -1).sum().item()

        pbar.update(data.shape[0])
        pbar.set_description('Corruption MSE: {:.6f}, Recovered MSE: {:.6f}, Energy Before Sampling: {:.6f}, Energy After Sampling: {:.6f}'.format(
            corruption_mse / n_batches, mse / n_batches, energy_before_sampling / n_batches, energy_after_sampling / n_batches))

    pbar.close()
    return (corruption_mse / n_batches, mse / n_batches, data[:100].detach().cpu(), broken_data[:100].detach().cpu(), recovered_img[:100].detach().cpu())

In [30]:
def neg_sample_gen (energy_model, buffer_size, replay_buffer, x_plus, batchsize, replay_ratio, n_steps, step_lr=10, eps=0.005, max_grad_norm=0.03, device='cuda'):
    """
    Generate negative samples using Langevin dynamics with a replay buffer.

    Args:
        energy_model (nn.Module): The energy-based model used for sampling.
        buffer_size (int): The size of the replay buffer.
        replay_buffer (torch.Tensor): The replay buffer containing samples for negative sampling.
        x_plus (torch.Tensor): The positive samples to generate negative samples for.
        batchsize (int): The batch size for generating negative samples.
        replay_ratio (float): The probability of sampling from the replay buffer.
        n_steps (int): The number of steps of Langevin dynamics to run.
        step_lr (float): The step size of Langevin dynamics.
        eps (float): The step size of Langevin dynamics.
        max_grad_norm (float): The maximum gradient norm to be used for gradient clipping.
        device (str): Device to use (default='cuda').

    Returns:
        x_minus (torch.Tensor): The generated negative samples.
        energy_before (float): The energy of the negative samples before Langevin dynamics.
    """
    is_training = energy_model.training
    energy_model.eval()
    energy_before = 0

    if buffer_size == 0:
        x_minus = torch.rand_like(x_plus)
    else:
        buffer_indices = torch.randint(0, buffer_size, (batchsize,))
        x_replay = replay_buffer[buffer_indices].to(device)
        random_samples = torch.rand_like(x_plus).to(device)
        mask = torch.rand(batchsize, 1, 1, 1, device=device) < replay_ratio
        x_minus = torch.where(mask, x_replay, random_samples)
    x_minus = x_minus.to(device)

    energy_before += energy_model(x_minus).sum().item()
   
    x_minus = langevin_step(energy_model, x_minus, step_lr, eps, max_grad_norm, n_steps)
    
    energy_model.train(is_training)
    
    return x_minus, energy_before
       

## Training
Fill the missing parts in the `train` function. There are some comments implying what to do in the corresponding blocks.

In [31]:
def train(n_epochs, energy_model, train_loader, val_loader, optimizer, n_sample_steps, step_lr, langevin_eps, l2_alpha, langevin_grad_norm =None,
          device='cuda', buffer_maxsize=int(1e5), replay_ratio=0.95, save_interval=1):
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)
    energy_model.to(device)
    replay_buffer = torch.zeros(buffer_maxsize, 1, 28, 28)
    buffer_size = buffer_ptr = 0
    best_mse = np.inf

    for epoch in range(n_epochs):
        train_loss = energy_before = energy_plus = energy_minus = n_batches = 0
        pbar = tqdm(total=len(train_loader.dataset))
        pbar.set_description('Train')
        for i, (x_plus, _) in enumerate(train_loader):
            n_batches += x_plus.shape[0]
            bs = x_plus.shape[0]

             # init negative samples
            if buffer_size == 0:
                x_minus = torch.rand_like(x_plus)
            else:
                ############################### CODE HERE ######################################
                buffer_indices = torch.randint(0, buffer_size, (bs,))
                x_replay = replay_buffer[buffer_indices].to(device)
                random_samples = torch.rand_like(x_plus).to(device)
                mask = torch.rand(bs, 1, 1, 1, device=device) < replay_ratio
                x_minus = torch.where(mask, x_replay, random_samples)
                ##############################################################################
            x_minus = x_minus.to(device)

            energy_before += energy_model(x_minus).sum().item()

            # sample negative samples
            ################################# CODE HERE ##################################
            # YOUR CODE HERE
            x_minus = langevin_step(energy_model, x_minus, step_lr, langevin_eps, langevin_grad_norm, n_sample_steps,True)
                
            # extend buffer
            if buffer_ptr + bs <= buffer_maxsize:
                replay_buffer[buffer_ptr: buffer_ptr +
                              bs] = ((x_minus * 255).to(torch.uint8).float() / 255).cpu()
            else:
                x_minus_ = (
                    (x_minus * 255).to(torch.uint8).float() / 255).cpu()
                replay_buffer[buffer_ptr:] = x_minus_[
                    :buffer_maxsize - buffer_ptr]
                remaining = bs - (buffer_maxsize - buffer_ptr)
                replay_buffer[:remaining] = x_minus_[
                    buffer_maxsize - buffer_ptr:]
            buffer_ptr = (buffer_ptr + bs) % buffer_maxsize
            buffer_size = min(buffer_maxsize, buffer_size + bs)

            # compute loss
            energy_model.train()
            x_plus = x_plus.to(device)
            x_minus = x_minus.to(device)
            e_plus = energy_model(x_plus)
            e_minus = energy_model(x_minus)
            
            ##############################################################################
            l2_regularization = l2_alpha * ((e_plus ** 2).mean() + (e_minus ** 2).mean()) 
            loss = e_plus.mean() - e_minus.mean() + l2_regularization
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ##############################################################################

            train_loss += loss.sum().item()
            energy_plus += e_plus.sum().item()
            energy_minus += e_minus.sum().item()

            pbar.update(x_plus.size(0))
            pbar.set_description("Train Epoch {}, Train Loss: {:.6f}, ".format(epoch + 1, 100*train_loss / n_batches) +
                                 "Energy Before Sampling: {:.6f}, ".format(energy_before / n_batches) +
                                 "Energy After Sampling: {:.6f}, ".format(energy_minus / n_batches) +
                                 "Energy of Ground Truth: {:.6f}".format(energy_plus / n_batches))
        pbar.close()
        scheduler.step()
        if (epoch + 1) % save_interval == 0:
            os.makedirs(f'./ebm/{epoch + 1}', exist_ok=True)
            energy_model.eval()
            save_model(f'./ebm/{epoch + 1}/ebm.pth',
                       energy_model, optimizer, replay_buffer)

            # evaluate inpainting
            # feel free to change the inpainting parameters!
            c_mse, r_mse, original, broken, recovered = evaluate(energy_model, val_loader,
                                                                 500, 10, 0.03, device=device)
            torchvision.utils.save_image(
                original, f"./ebm/{epoch + 1}/groundtruth.png", nrow=10)
            torchvision.utils.save_image(
                broken, f"./ebm/{epoch + 1}/corrupted.png", nrow=10)
            torchvision.utils.save_image(
                recovered, f"./ebm/{epoch + 1}/recovered.png", nrow=10)
            if r_mse < best_mse:
                print(f'Current best MSE: {best_mse} -> {r_mse}')
                best_mse = r_mse
                save_model('./ebm/ebm_best.pth', energy_model)

In [20]:
model = MlpBackbone((1, 28, 28), 512).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0, 0.999))

train_loader = DataLoader(train_set, 512, shuffle=True, drop_last=False, pin_memory=True)
val_loader = DataLoader(val_set, 2000, shuffle=True, drop_last=False, pin_memory=True)



Now you can start your training. Please keep in mind that this cell may **NOT** be run when we evaluate your assignment!

In [None]:
# feel free the change training hyper-parameters!

train(n_epochs=50, energy_model=model, train_loader=train_loader, val_loader=val_loader, optimizer=optimizer,
       n_sample_steps=200, step_lr=10, langevin_eps=0.005, langevin_grad_norm= 0.03, l2_alpha=0.1)

## Evaluation

Make sure you can run the following evaluation cell.

In [61]:
# feel free to change evaluation parameters!
# inpainting parameters are not necessarily the same as sampling parameters
n_sample_steps = 3000
step_lr = 10
langevin_grad_norm = 0.03

In [84]:
model.load_state_dict(load_model('./ebm/ebm_best.pth')[0])
corruption_mse, mse, original_eval, broken_eval, recovered_eval = evaluate(model, val_loader, n_sample_steps, step_lr, langevin_grad_norm, device=device)
print(f'Corruption MSE: {corruption_mse}')
print(f'Recovered MSE: {mse}')
os.makedirs(f'./ebm/eval', exist_ok=True)
torchvision.utils.save_image(
                original_eval, f"./ebm/eval/groundtruth.png", nrow=10)
torchvision.utils.save_image(
                broken_eval, f"./ebm/eval/corrupted.png", nrow=10)
torchvision.utils.save_image(
                recovered_eval, f"./ebm/eval/recovered.png", nrow=10)

  checkpoint = torch.load(load_path)


  0%|          | 0/10000 [00:00<?, ?it/s]



Corruption MSE: 0.06340540924072266
Recovered MSE: 0.04328401107788086
