In [1]:
!mkdir -p contents_lap

In [1]:
# From https://stackoverflow.com/questions/61058798/python-relative-import-in-jupyter-notebook
import os, sys

dir2 = os.path.abspath("")
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)

In [2]:
import torch
import random
import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Tuple
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from accelerate import Accelerator
import src.funcs as funcs
from src.cnn_module import CNN
from src.custom_deg_module import Col_Averaging, Row_Averaging
from src.ddpm_module import DDPM

  from .autonotebook import tqdm as notebook_tqdm


# 1. Some functions

In [3]:
def get_fid(generator, real_data, num_images, device):
    with torch.no_grad():
        # Sample images from the real images dataset
        num_samples = len(real_data)
        idx = random.sample(range(num_samples), num_images)
        real_img = torch.stack([real_data[i][0].clone() for i in idx])

        # Make the image have 3 identical channels
        # so that it can be processed by the FID metric
        real_img = real_img.to(torch.uint8).expand(-1, 3, -1, -1)

        # Sample images from the generator (DDPM)
        if generator.__class__.__name__ == "DDPM":
            gen_img = generator.sample(num_images, (1, 28, 28), device)
        else:
            _, __, ___, gen_img = generator.sample(
                num_images, real_data, (1, 28, 28), device
            )

        # Make the image have 3 identical channels
        gen_img = gen_img.expand(-1, 3, -1, -1)

        # Put the images in the same device
        gen_img = gen_img.to(real_img.device)

        # Initialize the FID metric
        fid = FrechetInceptionDistance(normalize=True)

        fid.update(real_img, real=True)
        fid.update(gen_img, real=False)
        fid_score = fid.compute()

        return fid_score


def get_is(data_source, is_real, num_images, device):
    with torch.no_grad():
        if is_real:
            # Sample images from the real images dataset
            num_samples = len(data_source)
            idx = random.sample(range(num_samples), num_images)
            img = torch.stack([data_source[i][0].clone() for i in idx])
            img = img.expand(-1, 3, -1, -1)

        else:
            # Sample images from the generator (DDPM)
            img = data_source.sample(num_images, (1, 28, 28), device)

            # Make the image have 3 identical channels
            img = img.expand(-1, 3, -1, -1)

            img = img.to("cpu")

        # Initialize the IS metric
        is_score = InceptionScore("logits_unbiased", normalize=True)

        is_score.update(img)
        is_score = is_score.compute()

        return is_score


def get_is_custom(generator, data_source, is_real, num_images, device):
    with torch.no_grad():
        if is_real:
            # Sample images from the real images dataset
            num_samples = len(data_source)
            idx = random.sample(range(num_samples), num_images)
            img = torch.stack([data_source[i][0].clone() for i in idx])
            img = img.expand(-1, 3, -1, -1)

        else:
            # Sample images from the generator (DDPM)
            _, __, ___, img = generator.sample(
                num_images, data_source, (1, 28, 28), device
            )

            # Make the image have 3 identical channels
            img = img.expand(-1, 3, -1, -1)

            img = img.to("cpu")

        # Initialize the IS metric
        is_score = InceptionScore("logits_unbiased", normalize=True)

        is_score.update(img)
        is_score = is_score.compute()

        return is_score


# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# Plotting functions
# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------


def plot_losses(losses, avg_losses, num_epochs, model):
    """!@brief Plot the losses over the training process

    @param losses: list of losses
    @param avg_losses: list of average losses
    @param num_epochs: number of epochs
    @param model: used model name

    @return None
    """

    x = np.arange(len(losses)) / (len(losses) // num_epochs)
    x_epoch = x[:: (len(losses) // num_epochs)].copy() + 1

    plt.style.use("ggplot")
    plt.figure(figsize=(10, 5))
    plt.plot(x, losses, color="green", label="Loss")
    plt.plot(
        x_epoch,
        avg_losses,
        linestyle="--",
        marker="+",
        color="black",
        label="Epoch average loss",
    )
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.yscale("log")  # Put the y-axis on a log scale
    plt.title(f"Losses over {num_epochs} epochs for {model}")
    plt.legend()
    plt.show()


def plot_fid(fid_scores, num_epochs, model):
    """!@brief Plot the FID scores over the training process

    @param fid_scores: list of FID scores
    @param num_epochs: number of epochs
    @param model: used model name

    @return None
    """
    x = np.arange(num_epochs)
    plt.style.use("ggplot")
    plt.figure(figsize=(10, 5))
    plt.plot(x, fid_scores, color="green")
    plt.xlabel("Epochs")
    plt.ylabel("FID")
    plt.title(f"FID over {num_epochs} epochs for {model}")

    plt.show()


def plot_is(is_scores, num_epochs, model):
    """!@brief Plot the IS scores over the training process

    @param is_scores: list of IS scores
    @param num_epochs: number of epochs
    @param model: used model name

    @return None
    """
    x = np.arange(num_epochs)
    plt.style.use("ggplot")
    plt.figure(figsize=(10, 5))
    plt.plot(x, is_scores, color="green")
    plt.xlabel("Epochs")
    plt.ylabel("IS")
    plt.title(f"IS over {num_epochs} epochs for {model}")

    plt.show()

# 1. Loading a pre-trained model

In [4]:
# Load the data
tf = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0))])
dataset = MNIST("./data", train=True, download=True, transform=tf)
dataloader = DataLoader(
    dataset, batch_size=128, shuffle=True, num_workers=4, drop_last=True
)

In [5]:
gt = CNN(in_channels=1, expected_shape=(28, 28), n_hidden=(16, 32, 32, 16), act=nn.GELU)

dif_model = Col_Averaging(
    gt=gt, col_order=3, n_T=7, grouping="7"
)  # For column averaging with column grouping
# dif_model = Row_Averaging(gt=gt, row_order=3, n_T=7, grouping="7") # For row averaging with row grouping
# dif_model = Col_Averaging(gt=gt, col_order=3, n_T=28, grouping="28") # For column averaging with no grouping
# dif_model = Row_Averaging(gt=gt, row_order=3, n_T=28, grouping="28") # For row averaging with no grouping
# dif_model = DDPM(gt=gt,betas=(1e-4, 0.02), n_T=1000) # For DDPM

optim = torch.optim.Adam(dif_model.parameters(), lr=1e-3)

In [6]:
accelerator = Accelerator()

# We wrap our model, optimizer, and dataloaders with `accelerator.prepare`,
# which lets HuggingFace's Accelerate handle the device placement and gradient accumulation.
dif_model, optim, dataloader = accelerator.prepare(dif_model, optim, dataloader)

In [7]:
# Choose the appropirate model to load:
# - ddpm_mnist_{num_epoch}_{hyperparameters}.pth
# - custom_mnist_{num_epoch}_col_{hyperparameters}.pth, default_7 or default_28 to set grouping
# - custom_mnist_{num_epoch}_row_{hyperparameters}.pth, default_7 or default_28 to set grouping

# Load the model
dif_model.load_state_dict(
    torch.load("./ddpm_mnist.pth", map_location=torch.device("cpu"))
)
dif_model.eval()

Col_Averaging(
  (gt): CNN(
    (blocks): ModuleList(
      (0): CNNBlock(
        (net): Sequential(
          (0): Conv2d(1, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((16, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (1): CNNBlock(
        (net): Sequential(
          (0): Conv2d(16, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((32, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (2): CNNBlock(
        (net): Sequential(
          (0): Conv2d(32, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((32, 28, 28), eps=1e-05, elementwise_affine=True)
          (2): GELU(approximate='none')
        )
      )
      (3): CNNBlock(
        (net): Sequential(
          (0): Conv2d(32, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): LayerNorm((16, 28

# 3. Explore the model

In [None]:
# Set random seeds
torch.manual_seed(75016)
np.random.seed(75016)

In [8]:
# Sample some images

# For DDPM
# degraded, xh = dif_model.sample(16, (1, 28, 28), accelerator.device)

# For custom degradation
original, degraded, direct, xh = dif_model.sample(
    16, dataset, (1, 28, 28), accelerator.device
)

# Can get device explicitly with `accelerator.device`
for i in range(16):
    deg_min = torch.min(degraded[i])
    deg_max = torch.max(degraded[i])
    # Bump up the contrast, by scaling the values back to -0.5 to 0.5 scale
    degraded[i] = (degraded[i] - deg_min) * (0.5 - (-0.5)) / (deg_max - deg_min) - 0.5

grid1 = make_grid(degraded, nrow=4)
grid = make_grid(xh, nrow=4)
grid2 = make_grid(original, nrow=4)
grid3 = make_grid(direct, nrow=4)

# Plot the original, the degraded, the direct and the restored images
save_image(grid1, f"./contents_lap/degraded_sample.png")
save_image(grid, f"./contents_lap/generated_sample.png")
save_image(grid2, f"./contents_lap/original_sample.png")
save_image(grid3, f"./contents_lap/direct_sample.png")

In [12]:
# Compute some metrics

# For DDPM
# fid_score = get_fid(dif_model, dataset, 1000, accelerator.device)
# is_score_real = get_is(dataset, True, 1000, accelerator.device)
# is_score_gen = get_is(dif_model, False, 1000, accelerator.device)

# For custom degradation
fid_score = get_fid(dif_model, dataset, 10, accelerator.device)
is_score_real = get_is_custom(dif_model, dataset, True, 10, accelerator.device)
is_score_gen = get_is_custom(dif_model, dataset, False, 10, accelerator.device)

print(f"FID: {fid_score}")
print(f"IS real: {is_score_real}")
print(f"IS gen: {is_score_gen}")



FID: 634.5020141601562
IS real: (tensor(1.), tensor(1.0131e-07))
IS gen: (tensor(1.), tensor(0.))


In [10]:
funcs.plot_ddpm_degrade(dataset, 100, (1e-4, 0.02), 1000)
funcs.plot_ddpm_degrade(dataset, 900, (1e-4, 0.02), 1000)