## Module 4 Project 2: Stable Diffusion
- Implement and train a simple version of [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion) using a basic dataset like Flickr30k
- Display some samples of the trained model

In [None]:
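# Assumed package list, inferred from the imports below
%pip install -q torch torchvision datasets einops matplotlib numpy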

## STEP 1: IMPORTS
- We need `torch`, `numpy`, `math`, and `matplotlib` as our usual ML imports
- We also need `einops` for tensor rearranging

In [None]:
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from datasets import load_dataset
from torch.utils.data import DataLoader
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import functools
import math
from einops import rearrange
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiplicativeLR, LambdaLR
from torch.nn.utils.rnn import pad_sequence
from torchvision.utils import make_grid

## STEP 2: DATASET
- We will be using the [Flickr30kDataset](https://paperswithcode.com/dataset/flickr30k) from our [last CLIP project](https://github.com/samherring99/NightwingCurriculum/blob/main/module_4_diffusion_models/module_4_project_1.ipynb) to make things simple
- We include a transform to resize the images to 112x112px (to reduce memory usage) and convert them to a tensor
- We set 2 possible captions per image, and a batch size of 128

In [None]:
class Flickr30kDataset(torch.utils.data.Dataset):
    def __init__(self):
        self.dataset = load_dataset("nlphuji/flickr30k", cache_dir="./huggingface_data")
        self.transform = transforms.Compose([
            transforms.Resize((112, 112)),
            transforms.ToTensor(),
        ])
        self.cap_per_image = 2

    def __len__(self):
        return self.dataset.num_rows["test"] * self.cap_per_image

    def __getitem__(self, idx):
        original_idx = idx // self.cap_per_image
        image = self.dataset["test"][original_idx]["image"].convert("RGB")
        image = self.transform(image)
        caption = self.dataset["test"][original_idx]["caption"][idx % self.cap_per_image]

        return {"image": image, "caption": caption}
    
flickr30k_dataset = Flickr30kDataset()
print(len(flickr30k_dataset))
flickr_dataloader = DataLoader(flickr30k_dataset, batch_size=128, shuffle=True, num_workers=4)
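
The index arithmetic in `__getitem__` can be checked without loading any images. The following is a minimal sketch in plain Python; `locate` is a hypothetical helper name used only for illustration and is not part of the notebook.

```python
CAP_PER_IMAGE = 2  # same value as self.cap_per_image above

def locate(idx, cap_per_image=CAP_PER_IMAGE):
    """Map a flat dataset index to (image index, caption index)."""
    return idx // cap_per_image, idx % cap_per_image

# The first six flat indices cover three images, two captions each
pairs = [locate(i) for i in range(6)]
print(pairs)
```

Consecutive flat indices therefore share an image and differ only in which caption they pick, which is why `__len__` multiplies the number of rows by `cap_per_image`.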

## STEP 3: TOKENIZATION
- We will reuse our simple tokenization from [this GPT project](https://github.com/samherring99/NightwingCurriculum/blob/main/module_2_advanced_nlp_and_transformers/module_2_project_1.ipynb)
- This makes things simpler so we can focus on the architecture and design of SD
- We provide the usual `encode` and `decode` methods to tokenize our captions

In [None]:
captions = ''.join([str(batch['caption']) for batch in flickr_dataloader])

chars = sorted(list(set(captions)))

stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
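
To see what the character-level vocabulary looks like, here is a small self-contained sketch on two made-up captions (they are not taken from Flickr30k). The `toy_` names are assumptions for this example; the logic mirrors `stoi`, `itos`, `encode`, and `decode` above.

```python
toy_captions = ["a dog runs", "a cat sits"]  # made-up captions for illustration

# Vocabulary: every distinct character, in sorted order
toy_chars = sorted(set("".join(toy_captions)))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("a cat")
print(ids, toy_decode(ids))
```

A character that never appears in the captions raises a `KeyError` in `encode`, so prompts passed to the model later have to stay within the vocabulary seen here.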

## STEP 4: PROJECTION LAYER
- We want to create a projection layer that computes Gaussian random Fourier features of the time step (a [GaussianFourierProjection](https://mathworld.wolfram.com/FourierTransformGaussian.html))
- Our objective here depends on time (noise is added to or removed from an image over `t` time steps), so we need to embed the time step and feed it to the model throughout training
- To do this, we draw a random weight vector from a Gaussian at initialization and keep it frozen (`requires_grad=False`)
- We then multiply each time step by these weights (and by 2π) and take the sine and cosine of the result
- This gives us random Fourier features of `t` that the network can use to tell noise levels apart

In [None]:
class ProjectionLayer(nn.Module):
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
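
The same projection can be written for a single scalar time step in plain Python, which makes the arithmetic easier to follow. This is only a sketch: `fourier_features` is a hypothetical name, and the small `embed_dim` is chosen for readability rather than taken from the model.

```python
import math
import random

def fourier_features(t, weights):
    """Return sin(2*pi*w*t) for each weight, followed by cos(2*pi*w*t) for each weight."""
    angles = [2 * math.pi * w * t for w in weights]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

random.seed(0)
embed_dim, scale = 8, 30.0
# Fixed random weights, drawn once and never trained
weights = [random.gauss(0.0, 1.0) * scale for _ in range(embed_dim // 2)]

features = fourier_features(0.5, weights)
print(len(features))
```

Half of the output comes from sines and half from cosines, which is why `ProjectionLayer` samples only `embed_dim // 2` weights.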

## STEP 5: FULLY CONNECTED LAYER
- This is a simple fully connected layer that has one Linear projection from `input_dim` to `output_dim`
- The forward pass appends two singleton dimensions, so the output has shape [B, C, 1, 1] and broadcasts against feature maps of shape [B, C, H, W], where B is the batch dimension, C is the number of channels, and H and W are the height and width of the feature map

In [None]:
class FCLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.dense(x)[..., None, None]

## STEP 6: SAMPLER
- Here we are building our sampler, which samples from score-based models using the [Euler-Maruyama solver](https://en.wikipedia.org/wiki/Euler%E2%80%93Maruyama_method)
- The first method, `marginal_prob_std`, calculates the standard deviation of the perturbation kernel at time `t` for a noise level that grows exponentially with `t`; the `sigma` parameter (25.0 here) is the base of that exponential
- The second method, `diffusion_coeff`, takes the same inputs and returns the [diffusion coefficient](https://en.wikipedia.org/wiki/Diffusion_equation) `sigma**t` for each entry of the time step vector
- Next, we create partial functions that fix `sigma`, and define our sampler method
- We initialize a time step tensor of ones of size `batch_size` and calculate `init_x`, our starting noise tensor, scaled by the standard deviation at `t = 1`
- We set `time_steps` to a 1D tensor of 500 evenly spaced values from 1.0 down to 0.001, and `step_size` to the gap between two consecutive values
- We then iterate over our time steps: we calculate the diffusion coefficient `g`, query our scoring model, and take one step to get `mean_x`, the intermediate denoised image before fresh noise is added
- We then add noise scaled by `g` and the square root of the step size to get the next `x`, and after the last time step we return `mean_x` so the final image carries no extra noise
- By including `y` in our scoring model we can pass in text as a 'caption', which steers generation toward an image that matches that text
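
The closed form used for the standard deviation can be sanity-checked numerically: with diffusion coefficient `sigma**t`, the variance at time `t` is the integral of `sigma**(2s)` from 0 to `t`. Below is a minimal sketch in plain Python; `std_closed_form` and `std_numeric` are hypothetical names, and the midpoint rule is just one convenient way to approximate the integral.

```python
import math

SIGMA = 25.0

def std_closed_form(t, sigma=SIGMA):
    # Same expression as marginal_prob_std, for a scalar t
    return math.sqrt((sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma)))

def std_numeric(t, sigma=SIGMA, n=20_000):
    """Midpoint-rule estimate of sqrt(integral from 0 to t of sigma^(2s) ds)."""
    h = t / n
    total = sum(sigma ** (2 * (k + 0.5) * h) for k in range(n)) * h
    return math.sqrt(total)

for t in (0.1, 0.5, 1.0):
    print(t, std_closed_form(t), std_numeric(t))
```

The two columns should agree closely, which confirms that the formula is the accumulated noise of a process whose diffusion coefficient is `sigma**t`.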

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def marginal_prob_std(t, sigma):
    # Standard deviation of the perturbation kernel at time t
    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    # Diffusion coefficient g(t) = sigma**t
    return sigma ** torch.as_tensor(t, device=device)

sigma = 25.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

def sampler(score_model,
            marginal_prob_std,
            diffusion_coeff,
            batch_size=64,
            x_shape=(3, 112, 112),
            num_steps=500,
            device=device,
            eps=1e-3, y=None):

    # Start from pure noise scaled to the standard deviation at t = 1
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, *x_shape, device=device) * marginal_prob_std(t)[:, None, None, None]

    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x

    with torch.no_grad():
        # Iterate over the time values themselves, from 1.0 down to eps
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            # Euler-Maruyama step: drift from the score, then fresh noise
            mean_x = x + (g**2)[:, None, None, None] * score_model(x, batch_time_step, y=y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)

    # Return the last noise-free estimate
    return mean_x
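
To see the update rule in isolation, the loop can be run on a one-dimensional toy problem where the score is known exactly. The sketch below assumes the "dataset" is a single point `MU`, so the perturbed distribution at time `t` is a Gaussian centred on `MU`; `true_score` and `sample_1d` are hypothetical names and stand in for the trained `score_model`.

```python
import math
import random

SIGMA = 25.0
MU = 2.0  # the single "data point" of this toy problem

def std(t):
    return math.sqrt((SIGMA ** (2 * t) - 1.0) / (2.0 * math.log(SIGMA)))

def g(t):
    return SIGMA ** t

def true_score(x, t):
    # Score of N(MU, std(t)^2): gradient of the log-density with respect to x
    return -(x - MU) / std(t) ** 2

def sample_1d(num_steps=500, eps=1e-3, rng=random):
    x = rng.gauss(0.0, 1.0) * std(1.0)        # start from wide noise
    step = (1.0 - eps) / (num_steps - 1)      # spacing of linspace(1, eps, num_steps)
    mean_x = x
    for k in range(num_steps):
        t = 1.0 - k * step
        mean_x = x + g(t) ** 2 * true_score(x, t) * step              # drift step
        x = mean_x + math.sqrt(step) * g(t) * rng.gauss(0.0, 1.0)     # add noise
    return mean_x

random.seed(0)
samples = [sample_1d() for _ in range(200)]
print(sum(samples) / len(samples))
```

With an exact score, the samples should end up clustered around `MU` even though they start from noise with a standard deviation of roughly 10. The trained network plays the role of `true_score` in the real sampler.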

## STEP 7: ATTENTION
- Now we can start building our pieces for our Transformer blocks
- We will use attention the same way as in previous projects
- This implementation is just single-headed attention for simplicity
- We compute the query with a linear layer on the input tokens, and compute the keys and values from the same tokens (self attention) or from the context (cross attention)
- In the forward pass, we use `torch.einsum` to take the inner product of the query and key tensors
- We then scale the scores, take the softmax, and multiply the result by the value tensor, which gives the output we return

In [None]:
class Attention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, context_dim=None, num_heads=1):
        super(Attention, self).__init__()

        self.hidden_dim = hidden_dim
        self.context_dim = context_dim
        self.embed_dim = embed_dim

        self.query = nn.Linear(hidden_dim, embed_dim, bias=False)
        
        if context_dim is None:
            self.self_attn = True
            self.key = nn.Linear(hidden_dim, embed_dim, bias=False)
            self.value = nn.Linear(hidden_dim, hidden_dim, bias=False)
        else:
            self.self_attn = False
            self.key = nn.Linear(context_dim, embed_dim, bias=False)
            self.value = nn.Linear(context_dim, hidden_dim, bias=False)

    def forward(self, tokens, context=None):
        if self.self_attn:
            Q = self.query(tokens)
            K = self.key(tokens)
            V = self.value(tokens)
        else:
            Q = self.query(tokens)
            K = self.key(context)
            V = self.value(context)

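        # Cross-attention context arrives as [B, 1, S, D]; drop only that singleton axis.
        # A bare torch.squeeze would also drop a batch or sequence axis of size 1.
        new_K = K.squeeze(1) if K.dim() == 4 else K
        new_V = V.squeeze(1) if V.dim() == 4 else V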

        scoremats = torch.einsum("BTH,BSH->BTS", Q, new_K)
        attnmats = F.softmax(scoremats / math.sqrt(self.embed_dim), dim=-1)
        ctx_vecs = torch.einsum("BTS,BSH->BTH", attnmats, new_V)

        return ctx_vecs
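
The two `einsum` calls and the softmax can be reproduced on tiny nested lists to make the computation concrete. This is a plain-Python sketch for one batch element, with made-up numbers; `attention` here is a hypothetical stand-alone function, not the `Attention` module above, and it omits the learned linear projections.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V):
    """Single-head attention on nested lists: Q is [T][d], K is [S][d], V is [S][h]."""
    d = len(Q[0])
    out, weights = [], []
    for q in Q:
        # Scaled dot product of this query with every key
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]
        w = softmax(scores)
        weights.append(w)
        # Weighted average of the value vectors
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out, weights

Q = [[1.0, 0.0], [0.0, 0.0]]               # two query tokens
K = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]   # three context tokens
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

out, weights = attention(Q, K, V)
print(weights)
print(out)
```

The first query lines up with the first key and puts most of its weight there, while the all-zero query scores every key equally and returns the plain average of the values.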

## STEP 8: BASIC TRANSFORMER
- Here we implement a basic Transformer block with both self and cross attention, one feed forward network, and 3 layernorms
- We perform self attention after normalizing, then add the residual connection
- Next, we perform cross attention against the caption embedding (the layer falls back to self attention if it was built without a `context_dim`) and add the residual again
- Lastly, we run our feed forward layer with GELU activation and add the final residual connection

In [None]:
class Transformer(nn.Module):
    def __init__(self, hidden_dim, context_dim):
        super(Transformer, self).__init__()
        self.attn_self = Attention(hidden_dim, hidden_dim)
        self.attn_cross = Attention(hidden_dim, hidden_dim, context_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)
        self.norm3 = nn.LayerNorm(hidden_dim)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, 3 * hidden_dim),
            nn.GELU(),
            nn.Linear(3 * hidden_dim, hidden_dim)
        )

    def forward(self, x, context=None):
        x = self.attn_self(self.norm1(x)) + x
        x = self.attn_cross(self.norm2(x), context=context) + x
        x = self.ffn(self.norm3(x)) + x

        return x

## STEP 9: IMAGE TRANSFORMER
- Using our `Transformer` class from above, we can build a spatial Transformer
- This simply helps shape the data using `einops` so the Transformer can take in images

In [None]:
class ImageTransformer(nn.Module):
    def __init__(self, hidden_dim, context_dim):
        super(ImageTransformer, self).__init__()
        self.transformer = Transformer(hidden_dim, context_dim)

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        x = rearrange(x, "b c h w -> b (h w) c")
        x = self.transformer(x, context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x + x_in
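
The two `rearrange` patterns only move data around: every pixel becomes one token whose features are that pixel's channel values. The plain-Python sketch below shows the same reshaping for a single image without `einops`; `to_tokens` and `to_image` are hypothetical helper names.

```python
def to_tokens(image):
    """image[c][h][w] -> tokens[h * W + w][c], i.e. 'c h w -> (h w) c'."""
    C, H, W = len(image), len(image[0]), len(image[0][0])
    return [[image[c][h][w] for c in range(C)] for h in range(H) for w in range(W)]

def to_image(tokens, H, W):
    """tokens[h * W + w][c] -> image[c][h][w], i.e. '(h w) c -> c h w'."""
    C = len(tokens[0])
    return [[[tokens[h * W + w][c] for w in range(W)] for h in range(H)] for c in range(C)]

# A 2-channel, 2x3 "image" whose values encode (channel, row, column)
image = [[[100 * c + 10 * h + w for w in range(3)] for h in range(2)] for c in range(2)]

tokens = to_tokens(image)
print(len(tokens), len(tokens[0]))
print(to_image(tokens, 2, 3) == image)
```

Because the second pattern needs `h` and `w` to undo the flattening, `ImageTransformer.forward` reads them from `x.shape` before rearranging.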

## STEP 10: UNET TRANSFORMER ARCHITECTURE
- We can finally build our UNet, which provides the downscaling/upscaling needed for training
- We create our time embedding projection layer followed by a linear projection layer
- We then have 2 down-convolution blocks, each made of: [Conv2d](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html), a fully connected layer for the time embedding, then `nn.GroupNorm`
- With no padding, these take the image from 112x112 to 110x110 and then to 54x54
- The next 2 down-convolution blocks are the same but have a spatial Transformer (`ImageTransformer`) inserted after the normalization to compute attention; they shrink the feature map to 26x26 and then 12x12
- Next, we have 3 up-convolution blocks, each made of: [ConvTranspose2d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html), a fully connected layer, then `nn.GroupNorm`; one more ConvTranspose2d follows to map back to image channels
- Lastly, we define our activation function, set the `marginal_prob_std` method, and create the embedding table for caption tokens
- For the forward pass, we embed our timestep vector and the caption
- We then go down the encoding path (down-scaling), applying convolution, fully connected layer, normalization, and the activation function (with attention for the last two steps)
- Next, we go through our decoding path (up-scaling), applying the same steps with transposed convolutions and adding the matching encoder features as skip connections; at the end we resize to 112x112 and normalize by the standard deviation at `t`

In [None]:
class UNet(nn.Module):
    def __init__(self, marginal_prob_std, channels=[32, 64, 128, 256], embed_dim=256,
                 text_dim=256, nClass=len(chars)):
        super().__init__()

        self.time_embed = nn.Sequential(
            ProjectionLayer(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.conv1 = nn.Conv2d(3, channels[0], 3, stride=1, bias=False)
        self.dense1 = FCLayer(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = FCLayer(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = FCLayer(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
        self.attn3 = ImageTransformer(channels[2], text_dim)

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = FCLayer(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])
        self.attn4 = ImageTransformer(channels[3], text_dim)

        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = FCLayer(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        self.tconv3 = nn.ConvTranspose2d(channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = FCLayer(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = FCLayer(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
        # Predict one score value per input channel (3 for RGB images)
        self.tconv1 = nn.ConvTranspose2d(channels[0], 3, 3, stride=1)

        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std
        self.cond_embed = nn.Embedding(nClass, text_dim)

    def forward(self, x, t, y=None):
        embed = self.act(self.time_embed(t))
        y = y.long()
        y_embed = self.cond_embed(y).unsqueeze(1)

        h1 = self.conv1(x) + self.dense1(embed)
        h1 = self.act(self.gnorm1(h1))
        h2 = self.conv2(h1) + self.dense2(embed)
        h2 = self.act(self.gnorm2(h2))
        h3 = self.conv3(h2) + self.dense3(embed)
        h3 = self.act(self.gnorm3(h3))
        h3 = self.attn3(h3, y_embed)
        h4 = self.conv4(h3) + self.dense4(embed)
        h4 = self.act(self.gnorm4(h4))
        h4 = self.attn4(h4, y_embed)

        h = self.tconv4(h4) + self.dense5(embed)
        h = self.act(self.tgnorm4(h))
        h3_resized = F.interpolate(h3, size=(h.size(2), h.size(3)), mode='bilinear', align_corners=True)
        h = self.tconv3(h + h3_resized) + self.dense6(embed)
        h = self.act(self.tgnorm3(h))
        h2_resized = F.interpolate(h2, size=(h.size(2), h.size(3)), mode='bilinear', align_corners=True)
        h = self.tconv2(h + h2_resized) + self.dense7(embed)
        h = self.act(self.tgnorm2(h))
        h1_resized = F.interpolate(h1, size=(h.size(2), h.size(3)), mode='bilinear', align_corners=True)
        h = self.tconv1(h + h1_resized)

        # Resize back to the spatial size of the input (112x112 for this dataset)
        h_resized = F.interpolate(h, size=tuple(x.shape[-2:]), mode='bilinear', align_corners=True)

        h = h_resized / self.marginal_prob_std(t)[:, None, None, None]
        return h
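
The spatial sizes along the encoder and decoder follow from the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` with a 3x3 kernel and no padding. The sketch below traces them in plain Python; `conv_out` and `tconv_out` are hypothetical helpers that assume dilation 1.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output size of nn.Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def tconv_out(size, kernel=3, stride=1, padding=0, output_padding=0):
    """Output size of nn.ConvTranspose2d along one spatial axis (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

down = [112]
for stride in (1, 2, 2, 2):                                   # conv1 .. conv4
    down.append(conv_out(down[-1], stride=stride))

up = [down[-1]]
for stride, out_pad in ((2, 0), (2, 1), (2, 1), (1, 0)):      # tconv4 .. tconv1
    up.append(tconv_out(up[-1], stride=stride, output_padding=out_pad))

print(down)
print(up)
```

The decoder sizes (25, 52, 106) do not match the encoder sizes (26, 54, 110), which is why the forward pass resizes each skip connection with `F.interpolate` before adding it and resizes the final 108x108 map back to 112x112.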

## STEP 11: LOSS FUNCTION
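- We need to implement the loss function for our model, a denoising score matching objective
- This function takes in the `x` and `y` components of the batch (the image and caption)
- We then randomly sample one time step per image, uniformly between `eps` and 1, along with Gaussian noise `z`
- We use our `marginal_prob_std` method with the sampled time steps to get the standard deviation of the perturbation kernel at `t`
- We perturb the input data `x` by adding our noise `z` scaled by that standard deviation
- Finally, we run the model on the perturbed data to get the score and return the loss: the squared error between `score * std` and `-z`, summed over each image and averaged over the batch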

In [None]:
def loss_fn(model, x, y, marginal_prob_std, eps=1e-5):
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None, None, None]
    score = model(perturbed_x, random_t, y=y)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))
    return loss
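
It can help to check what the loss rewards. For a single data point the true score of the perturbed distribution is `-z / std`, so the term inside the square vanishes. The sketch below demonstrates this in plain Python on scalars; the toy values and the name `dsm_residual` are assumptions for illustration.

```python
import random

def dsm_residual(score, std, z):
    """Per-element term inside the loss: score * std + z."""
    return score * std + z

random.seed(0)
x0, std = 0.7, 3.0                       # toy data point and noise level
ideal, zero_model = [], []
for _ in range(10_000):
    z = random.gauss(0.0, 1.0)
    perturbed = x0 + z * std
    ideal_score = -(perturbed - x0) / std ** 2      # score of N(x0, std^2)
    ideal.append(dsm_residual(ideal_score, std, z) ** 2)
    zero_model.append(dsm_residual(0.0, std, z) ** 2)

print(sum(ideal) / len(ideal))            # close to 0
print(sum(zero_model) / len(zero_model))  # close to E[z^2] = 1
```

A model that always outputs zero pays about one unit of loss per pixel, which gives a rough baseline for reading the training loss printed later.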

## STEP 12: HYPERPARAMETERS
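- First, we set up our score model to be our `UNet` from above, wrapped in `DataParallel`
- We then set 100 epochs with a learning rate of 0.001
- We will be using a basic Adam optimizer with default parameters for simplicity, and a LambdaLR scheduler that decays the learning rate by 2% per epoch down to a floor of 20% of its initial value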

In [None]:
score_model = torch.nn.DataParallel(UNet(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 100
lr = 1e-3

optimizer = Adam(score_model.parameters(), lr=lr)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: max(0.2, 0.98 ** epoch))
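
The schedule produced by this `lr_lambda` can be tabulated without building an optimizer. This is a plain-Python sketch; `lr_factor` and `first_floor_epoch` are hypothetical names.

```python
BASE_LR = 1e-3

def lr_factor(epoch):
    # Same rule as the lr_lambda passed to LambdaLR
    return max(0.2, 0.98 ** epoch)

schedule = [BASE_LR * lr_factor(e) for e in range(100)]
# First epoch at which the exponential decay drops below the 0.2 floor
first_floor_epoch = next(e for e in range(100) if 0.98 ** e < 0.2)

print(schedule[0], schedule[50], schedule[99])
print(first_floor_epoch)
```

With 100 epochs, the decay reaches its floor partway through the run, so the final stretch of training runs at a constant learning rate of 20% of the initial value.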

## STEP 13: TRAINING LOOP
- Now we can set up our training loop for our diffusion model
- We iterate over our epochs, and within each epoch we iterate through our dataloader
- For each batch in the dataloader, we load the images into `x` and the encoded, padded captions into `y`
- We then perturb the images and compute the loss using our loss function
- After that, we step our optimizer, and once all batches in the epoch are done we step the scheduler
- Finally, we save a checkpoint of the model at the end of every epoch

In [None]:
total_epochs = range(n_epochs)
for epoch in total_epochs:
    avg_loss = 0.
    num_items = 0

    for batch in flickr_dataloader:
        x = batch['image'].to(device)

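        # Encode each caption as a tensor of character indices
        caption = [torch.tensor(encode(str(string)), dtype=torch.long, device=device)
                   for string in batch['caption']]

        # Pad to the longest caption in the batch (pad value 0)
        y = pad_sequence(caption, batch_first=True)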

        loss = loss_fn(score_model, x, y, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]

    scheduler.step()
    lr_current = scheduler.get_last_lr()[0]

    print('{} Average Loss: {:5f} lr {:.1e}'.format(epoch, avg_loss / num_items, lr_current))

    torch.save(score_model.state_dict(), 'ckpt_transformer.pth')
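
Captions in a batch have different lengths, so they are padded to a common length before being stacked. The plain-Python sketch below mimics what `pad_sequence(..., batch_first=True)` does with its default pad value of 0; `pad_to_longest` is a hypothetical helper and the token ids are made up.

```python
def pad_to_longest(sequences, pad_value=0):
    """Right-pad lists of token ids to a common length (batch-first layout)."""
    longest = max(len(seq) for seq in sequences)
    return [seq + [pad_value] * (longest - len(seq)) for seq in sequences]

batch = [[5, 2, 9], [7], [3, 3]]     # made-up token ids for three captions
padded = pad_to_longest(batch)
print(padded)
```

Note that index 0 is also the id of the first character in `chars`, so the model cannot tell padding apart from that character. Reserving a dedicated padding id would be one possible refinement.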

## STEP 14: SAMPLE
- Here we are going to display some generated images by passing in a string prompt and sampling images
- We will use our `sampler` created above for this, with 500 steps and a batch size of 1
- We initialize our sampler with our standard deviation and diffusion coefficient methods created above
- And we then add our prompt into `y`
- After sample generation, we clamp the samples to range [0.0, 1.0] and display/save the image

In [None]:
sample_batch_size = 1
num_steps = 500

sample = sampler(score_model,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  num_steps=num_steps,
                  device=device,
                  # Captions are batched, so add a leading batch axis of size 1
                  y=torch.tensor(encode("a tuxedo cat sitting on a couch"), device=device).unsqueeze(0))

sample = sample.clamp(0.0, 1.0)

img = sample[0].permute(1, 2, 0)

plt.axis('off')
plt.imshow(img.cpu(), vmin=0., vmax=1.)
plt.savefig('sample_image.png')
plt.show()