In [1]:
# Create the output directory for the image grids saved later in the notebook.
# Done in Python rather than with `!mkdir -p` so it does not depend on a POSIX
# shell; `exist_ok=True` keeps the cell safe to re-run.
import os

os.makedirs("contents_lap", exist_ok=True)

In [2]:
# Add the parent of the current working directory to the import path
# (presumably the repository root that contains the `src` package).
# From https://stackoverflow.com/questions/61058798/python-relative-import-in-jupyter-notebook
import os
import sys

dir2 = os.path.abspath("")  # absolute path of the current working directory
dir1 = os.path.dirname(dir2)  # its parent directory
# Only append once so that re-running the cell does not grow sys.path.
if dir1 not in sys.path:
    sys.path.append(dir1)

In [3]:
import torch
import random
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Tuple
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from accelerate import Accelerator
from tqdm import tqdm
import src.funcs as funcs
from src.cnn_module import CNN
from src.custom_deg_module import Col_Averaging, Row_Averaging
from src.ddpm_module import DDPM

  from .autonotebook import tqdm as notebook_tqdm


# 1. Loading a pre-trained model

In [4]:
# Load the MNIST training split.
# Normalize subtracts a mean of 0.5 and divides by a std of 1.0; both are
# given as 1-tuples (one entry per image channel).
tf = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]
)
dataset = MNIST("./data", train=True, download=True, transform=tf)
# drop_last=True discards the final, smaller batch so every batch has 128 images.
dataloader = DataLoader(
    dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=True
)

Next we initialise the model. The available presets and their hyperparameters for DDPM are:
- Default:
    - betas = (1e-4, 0.02)
    - n_T = 1000
    - n_hidden = (16, 32, 32, 16)
    - batch_size = 128
    - activation = nn.GELU
- testing2:
    - betas = (1e-4, 0.02)
    - n_T = 1500
    - lr = 4e-4
    - n_hidden = (16, 32, 32, 16)
    - batch_size = 128
    - activation = nn.GELU

And for the custom degradation:
- Default_7: 
    - order = 3
    - grouping = "7"
    - n_T = 7
    - lr = 2e-4
    - n_hidden = (16, 32, 32, 16)
    - batch_size = 128
- Default_28:
    - order = 3
    - grouping = "28"
    - n_T = 28
    - lr = 2e-4
    - n_hidden = (16, 32, 32, 16)
    - batch_size = 128

In [5]:
# Network shared by every diffusion variant below.
gt = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)

# Exactly one `dif_model` definition should be active, and it has to match the
# checkpoint that is loaded afterwards.
dif_model = Col_Averaging(
    gt=gt, col_order=3, n_T=28, grouping="28"
)  # For column averaging with no grouping (Default_28)

# dif_model = Row_Averaging(gt=gt, row_order=3, n_T=7, grouping="7") # For row averaging with row grouping

# dif_model = Col_Averaging(gt=gt, col_order=3, n_T=7, grouping="7") # For column averaging with column grouping

# dif_model = Row_Averaging(gt=gt, row_order=3, n_T=28, grouping="28") # For row averaging with no grouping

# dif_model = DDPM(gt=gt,betas=(1e-4, 0.02), n_T=1000) # For DDPM

# Take care with the learning rate: the presets listed above do not all use
# the same value (e.g. 4e-4 for "testing2", 2e-4 for the custom degradations).
optim = torch.optim.Adam(dif_model.parameters(), lr=2e-4)

In [6]:
accelerator = Accelerator()

# Hand the model, optimizer and dataloader to `accelerator.prepare` so that
# HuggingFace's Accelerate takes care of device placement and gradient
# accumulation; the wrapped objects replace the originals under the same names.
(dif_model, optim, dataloader) = accelerator.prepare(
    dif_model,
    optim,
    dataloader,
)

In [7]:
# Choose the appropriate checkpoint for the model built above:
# - ddpm_mnist_{num_epoch}_{hyperparameters}.pth
# - custom_mnist_{num_epoch}_col_{hyperparameters}.pth, default_7 or default_28 to set grouping
# - custom_mnist_{num_epoch}_row_{hyperparameters}.pth, default_7 or default_28 to set grouping
checkpoint_path = "../custom_mnist_80_col_default_28.pth"

# Read the saved weights onto the CPU and copy them into the model.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))
dif_model.load_state_dict(state_dict)

# Put the model in evaluation mode; as the last expression of the cell this
# also displays the module summary.
dif_model.eval()

Col_Averaging(
  (gt): CNN(
    (blocks): ModuleList(
      (0): CNNBlock(
        (net): Sequential(
          (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((16, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (1): CNNBlock(
        (net): Sequential(
          (0): Conv2d(16, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((32, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (2): CNNBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((32, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (3): CNNBlock(
        (net): Sequential(
          (0): Conv2d(32, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((16, 28

# 2. Explore the model

In [8]:
# Set random seeds for every generator the notebook imports (Python's
# `random`, PyTorch and NumPy) from a single constant.
SEED = 75016
random.seed(SEED)
torch.manual_seed(SEED)
# Kept last: np.random.seed returns None, so the cell displays no output.
np.random.seed(SEED)

In [9]:
# Sample some images, then save them as image grids under ./contents_lap.

n_sample = 16  # number of images to sample
images_per_row = 4  # grid layout used for every saved image

# ---------------------------------------------
# For DDPM
# ---------------------------------------------
# degraded, xh = dif_model.sample(n_sample, (1, 28, 28), accelerator.device)

# ---------------------------------------------
# For custom degradation with conditional sampling
# ---------------------------------------------
# original, degraded, direct, xh = dif_model.sample(
#     n_sample, (1, 28, 28), accelerator.device
# )

# ---------------------------------------------
# For custom degradation with unconditional sampling
# ---------------------------------------------
degraded, direct, xh = dif_model.sample_unconditional(
    n_sample, (1, 28, 28), accelerator.device
)


# Normalise each degraded image for better visualisation: stretch its values
# so that its minimum maps to -0.5 and its maximum to 0.5.
for i in range(len(degraded)):
    deg_min = torch.min(degraded[i])
    deg_max = torch.max(degraded[i])
    # A constant image has max == min; clamp the range so the division below
    # yields -0.5 everywhere instead of NaN.
    deg_range = torch.clamp(deg_max - deg_min, min=1e-8)
    degraded[i] = (degraded[i] - deg_min) / deg_range - 0.5

grid = make_grid(xh, nrow=images_per_row)
grid1 = make_grid(degraded, nrow=images_per_row)
# grid2 = make_grid(original, nrow=images_per_row)
grid3 = make_grid(direct, nrow=images_per_row)

# Save the original, the degraded, the direct and the restored images
# NOTE(review): the data live in [-0.5, 0.5] while save_image appears to clamp
# to [0, 1] when writing; confirm whether `normalize=True` is wanted here.
save_image(grid, "./contents_lap/generated_sample.png")
save_image(grid1, "./contents_lap/degraded_sample.png")
# save_image(grid2, "./contents_lap/original_sample.png")
save_image(grid3, "./contents_lap/direct_sample.png")

torch.Size([16, 1, 28, 28])
tensor([[[-0.4001, -0.4450, -0.4456, -0.4666, -0.4695, -0.4765, -0.4802,
          -0.4671, -0.4007, -0.4192, -0.2221, -0.3694, -0.2326, -0.2988,
          -0.2495, -0.3327, -0.1871, -0.2764, -0.4195, -0.4834, -0.4517,
          -0.4481, -0.4510, -0.4827, -0.4873, -0.4767, -0.4622, -0.4996],
         [-0.4001, -0.4450, -0.4456, -0.4666, -0.4695, -0.4765, -0.4802,
          -0.4671, -0.4007, -0.4192, -0.2221, -0.3694, -0.2326, -0.2988,
          -0.2495, -0.3327, -0.1871, -0.2764, -0.4195, -0.4834, -0.4517,
          -0.4481, -0.4510, -0.4827, -0.4873, -0.4767, -0.4622, -0.4996],
         [-0.4001, -0.4450, -0.4456, -0.4666, -0.4695, -0.4765, -0.4802,
          -0.4671, -0.4007, -0.4192, -0.2221, -0.3694, -0.2326, -0.2988,
          -0.2495, -0.3327, -0.1871, -0.2764, -0.4195, -0.4834, -0.4517,
          -0.4481, -0.4510, -0.4827, -0.4873, -0.4767, -0.4622, -0.4996],
         [-0.4001, -0.4450, -0.4456, -0.4666, -0.4695, -0.4765, -0.4802,
          -0.4671, -

In [None]:
# Compute some metrics: FID, and Inception Score for real and generated images.

# Sample size passed to every metric call below; larger values take longer.
n_metric_samples = 100

# ---------------------------------------------
# For DDPM, may need to reduce sample size for time
# (helpers are reached through the `funcs` module; `funcs.get_is` is presumed
# to be the DDPM counterpart of `get_is_custom`, confirm in src/funcs.py)
# ---------------------------------------------
# fid_score = funcs.get_fid(dif_model, dataset, 1000, accelerator.device)
# is_score_real = funcs.get_is(dataset, True, 1000, accelerator.device)
# is_score_gen = funcs.get_is(dif_model, False, 1000, accelerator.device)

# ---------------------------------------------
# For custom degradation
# ---------------------------------------------
fid_score = funcs.get_fid(dif_model, dataset, n_metric_samples, accelerator.device)
is_score_real = funcs.get_is_custom(
    dif_model, dataset, True, n_metric_samples, accelerator.device
)
is_score_gen = funcs.get_is_custom(
    dif_model, dataset, False, n_metric_samples, accelerator.device
)

print(f"FID: {fid_score}")
print(f"IS real: {is_score_real}")
print(f"IS gen: {is_score_gen}")

In [None]:
# Visualise the DDPM degradation for two settings of the second argument
# (presumably an early and a late diffusion step; confirm against
# funcs.plot_ddpm_degrade).
ddpm_betas = (1e-4, 0.02)
ddpm_n_T = 1000  # presumably the total number of diffusion steps; confirm
funcs.plot_ddpm_degrade(dataset, 100, ddpm_betas, ddpm_n_T)
funcs.plot_ddpm_degrade(dataset, 900, ddpm_betas, ddpm_n_T)