In [1]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange #pip install einops
from typing import List
import random
import math
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from timm.utils import ModelEmaV3 #pip install timm
from tqdm import tqdm #pip install tqdm
import matplotlib.pyplot as plt #pip install matplotlib
import torch.optim as optim
import numpy as np

class SinusoidalEmbeddings(nn.Module):
    def __init__(self, time_steps:int, embed_dim: int):
        super().__init__()
        position = torch.arange(time_steps).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, embed_dim, 2).float() * -(math.log(10000.0) / embed_dim))
        embeddings = torch.zeros(time_steps, embed_dim, requires_grad=False)
        embeddings[:, 0::2] = torch.sin(position * div)
        embeddings[:, 1::2] = torch.cos(position * div)
        self.embeddings = embeddings

    def forward(self, x, t):
        embeds = self.embeddings[t].to(x.device)
        return embeds[:, :, None, None]

In [2]:
# Residual Blocks
class ResBlock(nn.Module):
    def __init__(self, C: int, num_groups: int, dropout_prob: float):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.gnorm1 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.gnorm2 = nn.GroupNorm(num_groups=num_groups, num_channels=C)
        self.conv1 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(C, C, kernel_size=3, padding=1)
        self.dropout = nn.Dropout(p=dropout_prob, inplace=True)

    def forward(self, x, embeddings):
        x = x + embeddings[:, :x.shape[1], :, :]
        r = self.conv1(self.relu(self.gnorm1(x)))
        r = self.dropout(r)
        r = self.conv2(self.relu(self.gnorm2(r)))
        return r + x

In [3]:
class Attention(nn.Module):
    def __init__(self, C: int, num_heads:int , dropout_prob: float):
        super().__init__()
        self.proj1 = nn.Linear(C, C*3)
        self.proj2 = nn.Linear(C, C)
        self.num_heads = num_heads
        self.dropout_prob = dropout_prob

    def forward(self, x):
        h, w = x.shape[2:]
        x = rearrange(x, 'b c h w -> b (h w) c')
        x = self.proj1(x)
        x = rearrange(x, 'b L (C H K) -> K b H L C', K=3, H=self.num_heads)
        q,k,v = x[0], x[1], x[2]
        x = F.scaled_dot_product_attention(q,k,v, is_causal=False, dropout_p=self.dropout_prob)
        x = rearrange(x, 'b H (h w) C -> b h w (C H)', h=h, w=w)
        x = self.proj2(x)
        return rearrange(x, 'b h w C -> b C h w')

In [4]:
class UnetLayer(nn.Module):
    def __init__(self,
            upscale: bool,
            attention: bool,
            num_groups: int,
            dropout_prob: float,
            num_heads: int,
            C: int):
        super().__init__()
        self.ResBlock1 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        self.ResBlock2 = ResBlock(C=C, num_groups=num_groups, dropout_prob=dropout_prob)
        if upscale:
            self.conv = nn.ConvTranspose2d(C, C//2, kernel_size=4, stride=2, padding=1)
        else:
            self.conv = nn.Conv2d(C, C*2, kernel_size=3, stride=2, padding=1)
        if attention:
            self.attention_layer = Attention(C, num_heads=num_heads, dropout_prob=dropout_prob)

    def forward(self, x, embeddings):
        x = self.ResBlock1(x, embeddings)
        if hasattr(self, 'attention_layer'):
            x = self.attention_layer(x)
        x = self.ResBlock2(x, embeddings)
        return self.conv(x), x

In [5]:
class UNET(nn.Module):
    def __init__(self,
            Channels: List = [64, 128, 256, 512, 512, 384],
            Attentions: List = [False, True, False, False, False, True],
            Upscales: List = [False, False, False, True, True, True],
            num_groups: int = 32,
            dropout_prob: float = 0.1,
            num_heads: int = 8,
            input_channels: int = 1,
            output_channels: int = 1,
            time_steps: int = 1000):
        super().__init__()
        self.num_layers = len(Channels)
        self.shallow_conv = nn.Conv2d(input_channels, Channels[0], kernel_size=3, padding=1)
        out_channels = (Channels[-1]//2)+Channels[0]
        self.late_conv = nn.Conv2d(out_channels, out_channels//2, kernel_size=3, padding=1)
        self.output_conv = nn.Conv2d(out_channels//2, output_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.embeddings = SinusoidalEmbeddings(time_steps=time_steps, embed_dim=max(Channels))
        for i in range(self.num_layers):
            layer = UnetLayer(
                upscale=Upscales[i],
                attention=Attentions[i],
                num_groups=num_groups,
                dropout_prob=dropout_prob,
                C=Channels[i],
                num_heads=num_heads
            )
            setattr(self, f'Layer{i+1}', layer)

    def forward(self, x, t):
        x = self.shallow_conv(x)
        residuals = []
        for i in range(self.num_layers//2):
            layer = getattr(self, f'Layer{i+1}')
            embeddings = self.embeddings(x, t)
            x, r = layer(x, embeddings)
            residuals.append(r)
        for i in range(self.num_layers//2, self.num_layers):
            layer = getattr(self, f'Layer{i+1}')
            x = torch.concat((layer(x, embeddings)[0], residuals[self.num_layers-i-1]), dim=1)
        return self.output_conv(self.relu(self.late_conv(x)))

In [6]:
class DDPM_Scheduler(nn.Module):
    def __init__(self, num_time_steps: int=1000):
        super().__init__()
        self.beta = torch.linspace(1e-4, 0.02, num_time_steps, requires_grad=False)
        alpha = 1 - self.beta
        self.alpha = torch.cumprod(alpha, dim=0).requires_grad_(False)

    def forward(self, t):
        return self.beta[t], self.alpha[t]

In [7]:
def set_seed(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

In [8]:
def train(batch_size: int=64,
          num_time_steps: int=1000,
          num_epochs: int=15,
          seed: int=-1,
          ema_decay: float=0.9999,
          lr=2e-5,
          checkpoint_path: str=None):
    set_seed(random.randint(0, 2**32-1)) if seed == -1 else set_seed(seed)

    train_dataset = datasets.MNIST(root='./data', train=True, download=True,transform=transforms.ToTensor())
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)

    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    model = UNET().cuda()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    ema = ModelEmaV3(model, decay=ema_decay)
    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['weights'])
        ema.load_state_dict(checkpoint['ema'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    criterion = nn.MSELoss(reduction='mean')

    for i in range(num_epochs):
        total_loss = 0
        for bidx, (x,_) in enumerate(tqdm(train_loader, desc=f"Epoch {i+1}/{num_epochs}")):
            x = x.cuda()
            x = F.pad(x, (2,2,2,2))
            t = torch.randint(0,num_time_steps,(batch_size,))
            e = torch.randn_like(x, requires_grad=False)
            a = scheduler.alpha[t].view(batch_size,1,1,1).cuda()
            x = (torch.sqrt(a)*x) + (torch.sqrt(1-a)*e)
            output = model(x, t)
            optimizer.zero_grad()
            loss = criterion(output, e)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
            ema.update(model)
        print(f'Epoch {i+1} | Loss {total_loss / (60000/batch_size):.5f}')

    checkpoint = {
        'weights': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'ema': ema.state_dict()
    }
    torch.save(checkpoint, 'checkpoints/ddpm_checkpoint')

In [11]:
def display_reverse(images: List):
    fig, axes = plt.subplots(1, 10, figsize=(10,1))
    for i, ax in enumerate(axes.flat):
        x = images[i].squeeze(0)
        x = rearrange(x, 'c h w -> h w c')
        x = x.numpy()
        ax.imshow(x)
        ax.axis('off')
    plt.show()

def inference(checkpoint_path: str=None,
              num_time_steps: int=1000,
              ema_decay: float=0.9999, ):
    checkpoint = torch.load(checkpoint_path)
    model = UNET().cuda()
    model.load_state_dict(checkpoint['weights'])
    ema = ModelEmaV3(model, decay=ema_decay)
    ema.load_state_dict(checkpoint['ema'])
    scheduler = DDPM_Scheduler(num_time_steps=num_time_steps)
    times = [0,15,50,100,200,300,400,550,700,999]
    images = []

    with torch.no_grad():
        model = ema.module.eval()
        for i in range(10):
            z = torch.randn(1, 1, 32, 32)
            for t in reversed(range(1, num_time_steps)):
                t = [t]
                temp = (scheduler.beta[t]/( (torch.sqrt(1-scheduler.alpha[t]))*(torch.sqrt(1-scheduler.beta[t])) ))
                z = (1/(torch.sqrt(1-scheduler.beta[t])))*z - (temp*model(z.cuda(),t).cpu())
                if t[0] in times:
                    images.append(z)
                e = torch.randn(1, 1, 32, 32)
                z = z + (e*torch.sqrt(scheduler.beta[t]))
            temp = scheduler.beta[0]/( (torch.sqrt(1-scheduler.alpha[0]))*(torch.sqrt(1-scheduler.beta[0])) )
            x = (1/(torch.sqrt(1-scheduler.beta[0])))*z - (temp*model(z.cuda(),[0]).cpu())

            images.append(x)
            x = rearrange(x.squeeze(0), 'c h w -> h w c').detach()
            x = x.numpy()
            plt.imshow(x)
            plt.show()
            display_reverse(images)
            images = []

In [9]:
def main():
    # Train from scratch
    train(checkpoint_path=None, lr=2e-5, num_epochs=75)

    # After training, save the checkpoint inside train() or manually:
    ckpt_path = "ddpm_model.pth"
    torch.save({
        'weights': model.state_dict(),
        'ema': ema.state_dict(),
    }, ckpt_path)

    # Run inference
    inference(ckpt_path)

if __name__ == "__main__":
    main()


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:01<00:00, 5145852.72it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 135414.78it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1287886.29it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6601014.82it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



Epoch 1/75: 100%|██████████| 937/937 [04:41<00:00,  3.33it/s]


Epoch 1 | Loss 0.13567


Epoch 2/75: 100%|██████████| 937/937 [04:47<00:00,  3.26it/s]


Epoch 2 | Loss 0.05157


Epoch 3/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 3 | Loss 0.03896


Epoch 4/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 4 | Loss 0.03373


Epoch 5/75: 100%|██████████| 937/937 [04:51<00:00,  3.22it/s]


Epoch 5 | Loss 0.02983


Epoch 6/75: 100%|██████████| 937/937 [04:49<00:00,  3.24it/s]


Epoch 6 | Loss 0.02680


Epoch 7/75: 100%|██████████| 937/937 [04:49<00:00,  3.24it/s]


Epoch 7 | Loss 0.02408


Epoch 8/75: 100%|██████████| 937/937 [04:48<00:00,  3.24it/s]


Epoch 8 | Loss 0.02226


Epoch 9/75: 100%|██████████| 937/937 [04:48<00:00,  3.24it/s]


Epoch 9 | Loss 0.02088


Epoch 10/75: 100%|██████████| 937/937 [04:46<00:00,  3.27it/s]


Epoch 10 | Loss 0.01966


Epoch 11/75: 100%|██████████| 937/937 [04:41<00:00,  3.33it/s]


Epoch 11 | Loss 0.01871


Epoch 12/75: 100%|██████████| 937/937 [04:43<00:00,  3.30it/s]


Epoch 12 | Loss 0.01835


Epoch 13/75: 100%|██████████| 937/937 [04:43<00:00,  3.30it/s]


Epoch 13 | Loss 0.01767


Epoch 14/75: 100%|██████████| 937/937 [04:40<00:00,  3.34it/s]


Epoch 14 | Loss 0.01746


Epoch 15/75: 100%|██████████| 937/937 [04:49<00:00,  3.23it/s]


Epoch 15 | Loss 0.01696


Epoch 16/75: 100%|██████████| 937/937 [04:49<00:00,  3.23it/s]


Epoch 16 | Loss 0.01647


Epoch 17/75: 100%|██████████| 937/937 [04:41<00:00,  3.32it/s]


Epoch 17 | Loss 0.01623


Epoch 18/75: 100%|██████████| 937/937 [04:40<00:00,  3.34it/s]


Epoch 18 | Loss 0.01614


Epoch 19/75: 100%|██████████| 937/937 [04:39<00:00,  3.35it/s]


Epoch 19 | Loss 0.01578


Epoch 20/75: 100%|██████████| 937/937 [04:40<00:00,  3.34it/s]


Epoch 20 | Loss 0.01535


Epoch 21/75: 100%|██████████| 937/937 [04:40<00:00,  3.34it/s]


Epoch 21 | Loss 0.01547


Epoch 22/75: 100%|██████████| 937/937 [04:41<00:00,  3.33it/s]


Epoch 22 | Loss 0.01511


Epoch 23/75: 100%|██████████| 937/937 [04:45<00:00,  3.29it/s]


Epoch 23 | Loss 0.01485


Epoch 24/75: 100%|██████████| 937/937 [04:44<00:00,  3.29it/s]


Epoch 24 | Loss 0.01497


Epoch 25/75: 100%|██████████| 937/937 [04:45<00:00,  3.28it/s]


Epoch 25 | Loss 0.01483


Epoch 26/75: 100%|██████████| 937/937 [04:42<00:00,  3.32it/s]


Epoch 26 | Loss 0.01444


Epoch 27/75: 100%|██████████| 937/937 [04:49<00:00,  3.23it/s]


Epoch 27 | Loss 0.01458


Epoch 28/75: 100%|██████████| 937/937 [04:50<00:00,  3.22it/s]


Epoch 28 | Loss 0.01424


Epoch 29/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 29 | Loss 0.01420


Epoch 30/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 30 | Loss 0.01419


Epoch 31/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 31 | Loss 0.01413


Epoch 32/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 32 | Loss 0.01394


Epoch 33/75: 100%|██████████| 937/937 [04:50<00:00,  3.23it/s]


Epoch 33 | Loss 0.01390


Epoch 34/75: 100%|██████████| 937/937 [02:11<00:00,  7.13it/s]


Epoch 34 | Loss 0.01374


Epoch 35/75: 100%|██████████| 937/937 [02:08<00:00,  7.29it/s]


Epoch 35 | Loss 0.01356


Epoch 36/75: 100%|██████████| 937/937 [02:09<00:00,  7.22it/s]


Epoch 36 | Loss 0.01360


Epoch 37/75: 100%|██████████| 937/937 [02:43<00:00,  5.73it/s]


Epoch 37 | Loss 0.01349


Epoch 38/75: 100%|██████████| 937/937 [02:45<00:00,  5.68it/s]


Epoch 38 | Loss 0.01343


Epoch 39/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 39 | Loss 0.01347


Epoch 40/75: 100%|██████████| 937/937 [02:46<00:00,  5.64it/s]


Epoch 40 | Loss 0.01345


Epoch 41/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 41 | Loss 0.01335


Epoch 42/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 42 | Loss 0.01348


Epoch 43/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 43 | Loss 0.01320


Epoch 44/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 44 | Loss 0.01327


Epoch 45/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 45 | Loss 0.01313


Epoch 46/75: 100%|██████████| 937/937 [02:44<00:00,  5.68it/s]


Epoch 46 | Loss 0.01301


Epoch 47/75: 100%|██████████| 937/937 [02:46<00:00,  5.64it/s]


Epoch 47 | Loss 0.01305


Epoch 48/75: 100%|██████████| 937/937 [02:44<00:00,  5.68it/s]


Epoch 48 | Loss 0.01281


Epoch 49/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 49 | Loss 0.01294


Epoch 50/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 50 | Loss 0.01289


Epoch 51/75: 100%|██████████| 937/937 [02:44<00:00,  5.70it/s]


Epoch 51 | Loss 0.01287


Epoch 52/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 52 | Loss 0.01273


Epoch 53/75: 100%|██████████| 937/937 [02:44<00:00,  5.69it/s]


Epoch 53 | Loss 0.01277


Epoch 54/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 54 | Loss 0.01284


Epoch 55/75: 100%|██████████| 937/937 [02:46<00:00,  5.63it/s]


Epoch 55 | Loss 0.01259


Epoch 56/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 56 | Loss 0.01265


Epoch 57/75: 100%|██████████| 937/937 [02:46<00:00,  5.63it/s]


Epoch 57 | Loss 0.01267


Epoch 58/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 58 | Loss 0.01273


Epoch 59/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 59 | Loss 0.01257


Epoch 60/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 60 | Loss 0.01249


Epoch 61/75: 100%|██████████| 937/937 [02:44<00:00,  5.69it/s]


Epoch 61 | Loss 0.01253


Epoch 62/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 62 | Loss 0.01253


Epoch 63/75: 100%|██████████| 937/937 [02:45<00:00,  5.68it/s]


Epoch 63 | Loss 0.01245


Epoch 64/75: 100%|██████████| 937/937 [02:45<00:00,  5.68it/s]


Epoch 64 | Loss 0.01238


Epoch 65/75: 100%|██████████| 937/937 [02:45<00:00,  5.68it/s]


Epoch 65 | Loss 0.01228


Epoch 66/75: 100%|██████████| 937/937 [02:44<00:00,  5.69it/s]


Epoch 66 | Loss 0.01236


Epoch 67/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 67 | Loss 0.01234


Epoch 68/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 68 | Loss 0.01227


Epoch 69/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 69 | Loss 0.01221


Epoch 70/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 70 | Loss 0.01230


Epoch 71/75: 100%|██████████| 937/937 [02:44<00:00,  5.71it/s]


Epoch 71 | Loss 0.01214


Epoch 72/75: 100%|██████████| 937/937 [02:45<00:00,  5.66it/s]


Epoch 72 | Loss 0.01219


Epoch 73/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 73 | Loss 0.01212


Epoch 74/75: 100%|██████████| 937/937 [02:45<00:00,  5.67it/s]


Epoch 74 | Loss 0.01220


Epoch 75/75: 100%|██████████| 937/937 [02:45<00:00,  5.65it/s]


Epoch 75 | Loss 0.01206


RuntimeError: Parent directory checkpoints does not exist.