In [None]:
#@title Dataset Retrieval
!mkdir ~/.kaggle
!echo '{"username":"oriyonay2","key":"ae6fe32a1ad5e9c76204c9c526f4a3c8"}' > ~/.kaggle/kaggle.json
!chmod 600 /root/.kaggle/kaggle.json

# download and unzip the data
!kaggle datasets download -d alxmamaev/flowers-recognition
!unzip -qq flowers-recognition.zip

Downloading flowers-recognition.zip to /content
 93% 209M/225M [00:01<00:00, 134MB/s]
100% 225M/225M [00:01<00:00, 144MB/s]


In [None]:
#@title Utilities

'''
utils.py: general utilities
'''

import matplotlib.pyplot as plt
import os
from PIL import Image
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torchvision import datasets
from torchvision.utils import make_grid

# neat function for getting the next batch without epochs
def infinite_dataloader(dataloader):
    '''
    yields items from `dataloader` forever, starting a fresh pass over it
    every time the previous pass is exhausted

    dataloader: any re-iterable object (iterated again on each pass)
    '''
    while True:
        yield from dataloader

def plot_images(images):
    '''
    shows a batch of images side by side in a single matplotlib figure

    images: batched tensor; each item is joined along its last axis and then
            permuted (1, 2, 0), so items are presumably channels-first (C, H, W)
    '''
    plt.figure(figsize=(32, 32))

    # join the batch items into one wide strip along the last (width) axis
    strip = torch.cat(list(images.cpu()), dim=-1)

    # move the leading axis to the end before handing the strip to imshow
    strip = strip.permute(1, 2, 0)

    plt.imshow(strip)
    plt.show()

def save_images(images, path, **kwargs):
    '''
    tiles a batch of images into one grid and writes it to `path`

    images: batch of images passed to torchvision's make_grid
    path: destination file; PIL picks the format when saving
    kwargs: forwarded unchanged to make_grid
    '''
    # make_grid output is permuted (1, 2, 0) before being handed to PIL
    pixels = make_grid(images, **kwargs).permute(1, 2, 0).cpu().numpy()
    Image.fromarray(pixels).save(path)

def get_data(args):
    '''
    builds the endless stream of training batches

    args: settings object; reads dataset_path, image_size and batch_size
    returns: a generator that yields batches from a shuffled DataLoader forever
    '''
    preprocessing = [
        T.Resize(80),
        T.RandomResizedCrop(args.image_size, scale=(0.8, 1.0)),
        T.ToTensor(),
        # (x - 0.5) / 0.5 on every channel
        T.Normalize(0.5, 0.5),
    ]
    dataset = datasets.ImageFolder(args.dataset_path, transform=T.Compose(preprocessing))
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    return infinite_dataloader(loader)

def setup_logging(run_name):
    '''
    creates the output directories for a run if they do not exist yet:
    models/, results/, models/<run_name>/ and results/<run_name>/
    '''
    roots = ('models', 'results')
    for root in roots:
        os.makedirs(root, exist_ok=True)
    for root in roots:
        os.makedirs(os.path.join(root, run_name), exist_ok=True)

class EMA:
    '''
    maintains an exponential moving average of a model's weights
    EMA update: w = (beta * w_old) + ((1 - beta) * w_new)

    beta: EMA parameter (weight kept from the previous average)
    step_start_ema: number of warmup steps; until then the EMA model is
                    simply overwritten with the live model's state
    '''
    def __init__(self, beta, step_start_ema=2000):
        self.step = 0
        self.beta = beta
        self.step_start_ema = step_start_ema

    def step_ema(self, ema_model, model):
        # copy during warmup, blend afterwards; the step counter always advances
        warming_up = self.step < self.step_start_ema
        sync = self.reset_params if warming_up else self.update_model_average
        sync(ema_model, model)
        self.step += 1

    def update_model_average(self, ema_model, model):
        # walk both parameter lists in lockstep and replace each EMA weight
        pairs = zip(model.parameters(), ema_model.parameters())
        for live, averaged in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        keep = self.beta
        return (keep * old) + ((1 - keep) * new)

    def reset_params(self, ema_model, model):
        # overwrite the EMA model with an exact copy of the live model's state
        ema_model.load_state_dict(model.state_dict())

In [None]:
#@title Models

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


class Diffusion:
    '''
    the diffusion class

    noise_steps: number of diffusion noising steps
    betas: start and end for the variance schedule
    img_size: generated image size
    device: the device to use for training
    '''
    def __init__(self, noise_steps=1000, betas=(1e-4, 2e-2), img_size=64, device='cuda'):
        self.noise_steps = noise_steps
        self.betas = betas
        self.img_size = img_size
        self.device = device

        # prepare the noise variance schedule and various constants:
        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1 - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_hat = torch.sqrt(self.alpha_hat)
        self.sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat)

    # computes a linear noise schedule: noise_steps values evenly spaced
    # between betas[0] and betas[1]
    def prepare_noise_schedule(self):
        return torch.linspace(*self.betas, self.noise_steps)

    # noises the data x to timestep t
    # x: batch of images (the per-sample constants are broadcast over 3 trailing dims)
    # t: 1-D integer tensor with one timestep per batch item
    # returns: (noised x, the gaussian noise that was mixed in)
    def noise_data(self, x, t):
        # get constants
        sqrt_alpha_hat = self.sqrt_alpha_hat[t][:, None, None, None]
        sqrt_one_minus_alpha_hat = self.sqrt_one_minus_alpha_hat[t][:, None, None, None]

        # generate noise
        e = torch.randn_like(x)

        # compute scaled signal and noise
        signal = sqrt_alpha_hat * x
        noise = sqrt_one_minus_alpha_hat * e
        return signal + noise, e

    # samples time steps: n random ints in [1, noise_steps), returned as a
    # tensor on the default device (the caller moves it where needed)
    def sample_timesteps(self, n):
        return torch.randint(1, self.noise_steps, size=(n,))

    # sample n datapoints from the model
    # labels: optional labels if the model was trained conditionally
    # cfg_scale: classifier-free guidance weight used in
    #            lerp(unconditional, conditional, cfg_scale); 0 keeps the
    #            unconditional prediction, 1 the conditional one, and larger
    #            values extrapolate past the conditional prediction
    # returns: uint8 tensor of shape (n, 3, img_size, img_size) with values in [0, 255]
    @torch.no_grad()
    def sample(self, model, n, labels=None, cfg_scale=3):
        # remember the caller's mode so an eval-mode model (e.g. an EMA copy)
        # is not silently switched to training mode by sampling
        was_training = getattr(model, 'training', True)
        model.eval()

        # start with gaussian noise
        x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)

        # denoise data, walking the timesteps from noise_steps - 1 down to 1
        # (a plain range has a length, so tqdm can display the total)
        for i in tqdm(range(self.noise_steps - 1, 0, -1), position=0):
            # create the timestep tensor that encodes the current timestep
            t = (torch.ones(n) * i).long().to(self.device)

            # predict the noise
            predicted_noise = model(x, t, labels)

            # classifier-free guidance: linearly interpolate between
            # unconditional and conditional (above) samples.
            # `labels` is a tensor, so it must be compared against None:
            # truth-testing a multi-element tensor raises a RuntimeError
            if labels is not None:
                uncond_predicted_noise = model(x, t)
                predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)

            # compute scaling constants
            alpha = self.alpha[t][:, None, None, None]
            alpha_hat = self.alpha_hat[t][:, None, None, None]
            beta = self.beta[t][:, None, None, None]

            # remove a small bit of noise from x (no fresh noise on the last step)
            noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
            signal_scale = 1 / torch.sqrt(alpha)
            pred_noise_scale = (1 - alpha) / (torch.sqrt(1 - alpha_hat))
            scaled_noise = torch.sqrt(beta) * noise
            signal = x - (pred_noise_scale * predicted_noise)

            x = (signal_scale * signal) + scaled_noise

        # put the model back into training mode only if it was training before
        if was_training:
            model.train()

        # clamp and rescale x values to [0, 1] (output was [-1, 1]):
        x = (x.clamp(-1, 1) + 1) / 2

        # convert x to valid pixel range:
        x = (x * 255).type(torch.uint8)

        return x


class PatchingLayer(nn.Module):
    '''
    Cuts a (batch, channels, height, width) image tensor into non-overlapping
    square patches and flattens every patch into one token vector.

    patch_size: side length of each square patch
    num_channels: number of image channels (stored, not used by forward)
    '''
    def __init__(self, patch_size, num_channels):
        super().__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels

    def forward(self, x):
        p = self.patch_size
        batch, channels, _, _ = x.shape

        # slide a p-wide, p-strided window over height, then over width:
        # (batch, channels, rows, cols, p, p)
        windows = x.unfold(2, p, p)
        windows = windows.unfold(3, p, p)

        # bring the patch grid forward so each patch's pixels are contiguous:
        # (batch, rows, cols, channels, p, p)
        tokens = windows.permute(0, 2, 3, 1, 4, 5).contiguous()

        # one flat vector per patch: (batch, rows * cols, channels * p * p)
        return tokens.view(batch, -1, channels * p * p)


class PositionalEncoding(nn.Module):
    '''
    Sinusoidal positional encoding over the token axis.

    forward(x) returns a tensor shaped like x whose even embedding indices
    hold sin(position * frequency) and odd indices cos(position * frequency);
    the values of x itself are not used, only its shape, dtype and device.

    d_model: embedding dimension (stored; forward reads the size from x)
    '''
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x):
        _, num_tokens, emb_dim = x.size()  # batch, tokens, embedding dim

        # token positions 0..N-1, shaped [1, N, 1]
        positions = torch.arange(0, num_tokens).unsqueeze(0).unsqueeze(-1).to(x.device).float()

        # one frequency per sin/cos pair, shaped [1, 1, E//2]
        log_step = -(math.log(10000.0) / emb_dim)
        frequencies = torch.exp(torch.arange(0, emb_dim, 2).float() * log_step).to(x.device)
        frequencies = frequencies.unsqueeze(0).unsqueeze(0)

        angles = positions * frequencies  # [1, N, E//2]

        encoding = torch.zeros_like(x)  # [B, N, E]
        encoding[:, :, 0::2] = torch.sin(angles)  # even indices
        encoding[:, :, 1::2] = torch.cos(angles)  # odd indices
        return encoding


class OutputLayer(nn.Module):
    '''
    Turns a sequence of tokens back into an image: an MLP maps every token to
    the pixels of one patch, and the patches are stitched into a square grid.

    d_model: size of each incoming token
    patch_size: side length of each square patch
    num_channels: number of image channels
    '''
    def __init__(self, d_model, patch_size, num_channels):
        super().__init__()
        self.patch_size = patch_size
        self.num_channels = num_channels
        pixels_per_patch = num_channels * patch_size * patch_size
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, pixels_per_patch)
        )

    def forward(self, x):
        # project every token to the flattened pixels of its patch
        patches = self.mlp(x)
        batch, num_tokens, _ = patches.size()
        p, c = self.patch_size, self.num_channels

        # the tokens are assumed to form a square grid of patches
        side = int(num_tokens ** 0.5)
        grid = patches.view(batch, side, side, c, p, p)

        # (batch, c, grid row, patch row, grid col, patch col) so that merging
        # neighbouring axes yields full image rows and columns
        image = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
        return image.view(batch, c, side * p, side * p)


class ImageTransformer(nn.Module):
    '''
    Full model architecture: the image is cut into patches, every patch is
    projected to a d_model-sized token, a timestep embedding (plus an optional
    class embedding) and a positional encoding are added, the tokens go
    through a transformer encoder, and the result is mapped back to an image.

    d_model: token size (must be even for the sin/cos time embedding)
    nhead: number of attention heads
    num_layers: number of transformer encoder layers
    patch_size: side length of each square patch
    num_classes: number of classes for conditional generation (None/0 builds
                 an unconditional model without a label embedding)
    num_channels: number of image channels
    dropout: dropout used inside the encoder layers
    '''
    def __init__(self, d_model, nhead, num_layers, patch_size, num_classes=None, num_channels=3, dropout=0.05):
        super(ImageTransformer, self).__init__()
        self.d_model = d_model
        self.patching_layer = PatchingLayer(patch_size, num_channels)
        self.projection = nn.Linear(num_channels * patch_size * patch_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout),
            num_layers
        )
        self.output_layer = OutputLayer(d_model, patch_size, num_channels)

        if num_classes:
            self.label_emb = nn.Embedding(num_classes, d_model)

    def forward(self, x, t, label=None):
        '''
        x: batch of images, (batch, num_channels, height, width)
        t: 1-D tensor with one timestep per batch item
        label: optional 1-D tensor of class indices, one per batch item
        returns: tensor produced by the output layer (same layout as x)
        '''
        # sinusoidal embedding of the timestep, shaped (batch, 1, d_model) so
        # it broadcasts over the token axis
        t = t.unsqueeze(-1).float()
        t = self._time_embedding(t)

        # class-conditioning. `label` is a tensor, so compare against None:
        # truth-testing a multi-element tensor raises a RuntimeError
        if label is not None:
            if not hasattr(self, 'label_emb'):
                raise ValueError('a label was given but the model was built without num_classes')
            # (batch, d_model) -> (batch, 1, d_model); without the extra axis
            # the sum would broadcast to (batch, batch, d_model)
            t = t + self.label_emb(label).unsqueeze(1)

        x = self.patching_layer(x)
        x = self.projection(x)
        residual = x
        x = x + t + self.positional_encoding(x)
        x = self.transformer_encoder(x) + residual
        x = self.output_layer(x)
        return x

    def _time_embedding(self, t):
        # t arrives as (batch, 1); the result is (batch, 1, d_model)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, self.d_model, 2, dtype=torch.float32).to(t.device) / self.d_model))

        # Create the sine and cosine encodings
        pos_enc_sin = torch.sin(t.unsqueeze(1).float() * inv_freq)
        pos_enc_cos = torch.cos(t.unsqueeze(1).float() * inv_freq)

        # Concatenate the sine and cosine encodings
        pos_enc = torch.cat([pos_enc_sin, pos_enc_cos], dim=-1)

        return pos_enc

    @property
    def n_params(self):
        # number of trainable parameters
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
#@title Training

'''
training the diffusion model
'''

import argparse
import copy
import numpy as np
import os
import torch
import torch.nn as nn
from tqdm import trange


def train(args):
    '''
    trains the diffusion model described by `args`

    args: settings object; reads run_name, device, dataset_path, image_size,
          batch_size, lr, ema_beta, step_start_ema, n_iters, n_labels and
          checkpoint_every

    every `checkpoint_every` iterations a grid of samples is written to
    results/<run_name>/ and the model and EMA weights to models/<run_name>/
    '''
    print(f'Training on {args.device}')

    # set up training run
    setup_logging(args.run_name)
    device = args.device
    dataloader = get_data(args)
    model = ImageTransformer(
        d_model=128,
        nhead=8,
        num_layers=4,
        patch_size=4,
        # needed so the model owns a label embedding for conditional runs
        num_classes=args.n_labels,
        num_channels=3
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    criterion = nn.MSELoss()
    diffusion = Diffusion(img_size=args.image_size, device=device)
    ema = EMA(args.ema_beta, args.step_start_ema)
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)

    print(f'Model contains {model.n_params} trainable parameters')

    pbar = trange(args.n_iters)
    for i in pbar:
        # prepare next batch
        images, labels = next(dataloader)
        images = images.to(device)
        labels = labels.to(device) if args.n_labels else None
        batch_size = images.shape[0]

        # set labels to None 10% of the time, even if this is
        # conditional generation (classifier-free guidance)
        if np.random.random() < 0.1:
            labels = None

        # forward pass: noise the batch to random timesteps and predict that noise
        t = diffusion.sample_timesteps(batch_size).to(device)
        x_t, noise = diffusion.noise_data(images, t)
        predicted_noise = model(x_t, t, labels)
        loss = criterion(noise, predicted_noise)

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.step_ema(ema_model, model)

        # update progress bar
        pbar.set_description(f'loss: {loss.item():.3f}')

        # sample images every few iterations
        if (i+1) % args.checkpoint_every == 0:
            sampled_images = diffusion.sample(model, n=batch_size)
            save_path = os.path.join('results', args.run_name, f'{i:06}.jpg')
            save_images(sampled_images, save_path)

            # checkpoint the model weights
            torch.save(model.state_dict(), os.path.join('models', args.run_name, 'checkpoint.pt'))

            # also keep the EMA weights, which are otherwise lost when training ends
            torch.save(ema_model.state_dict(), os.path.join('models', args.run_name, 'ema_checkpoint.pt'))

class Arguments:
    '''empty attribute container; settings are attached to an instance as plain attributes'''

def launch():
    '''
    collects the settings for a training run on an Arguments object
    (used here in place of an argparse namespace) and starts training
    '''
    settings = {
        'run_name': 'image_transformer_test',
        'n_iters': 40000,
        'batch_size': 16,
        'image_size': 64,
        'checkpoint_every': 1000,
        'n_labels': None,  # number of labels for conditional generation
        'dataset_path': 'flowers',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'lr': 3e-4,
        'ema_beta': 0.995,
        'step_start_ema': 2000,  # warmup steps before EMA
    }

    args = Arguments()
    for name, value in settings.items():
        setattr(args, name, value)

    train(args)

# start a training run when this cell (or the file as a script) is executed directly
if __name__ == '__main__':
    launch()

Training on cuda
Model contains 2469040 trainable parameters


999it [00:10, 96.78it/s]
999it [00:10, 95.40it/s]
999it [00:10, 95.28it/s]
loss: 0.055:   9%|▉         | 3664/40000 [05:39<56:10, 10.78it/s]


KeyboardInterrupt: ignored