## Miscellaneous code snippets written while learning about diffusion models

In [1]:
%load_ext autoreload

In [2]:
%autoreload
# import libraries
import numpy as np
import pickle as pkl
import os
import sys
import torchvision.utils as vutils

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
%matplotlib inline

from celeba_dataset import CelebA
from unet_diffusion import UNet_Diffusion, get_time_embedding
from noise_scheduler import LinearNoiseScheduler

In [3]:
# Image/batch settings and the linear beta schedule used by the exploratory cells.
img_size = (64,64)
batch_size = 8

num_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02
lns = LinearNoiseScheduler(num_timesteps=num_timesteps,
                           beta_start=beta_start,
                           beta_end=beta_end)


------------------------------------------
### Experiments that give good results:

<hr>

### Experiment 1

-  No augmentations other than horizontal flips (do NOT use torchvision's Gaussian blur: blurring the training images interferes with the scheduled noising of the images)
-  The AttentionBlock uses nn.MultiheadAttention with #heads = 4
-  Attention=True only for self.down_1, self.down_2, self.up_2, self.up_1 (all others false)
-  img_shape = (64,64), batch_size=10, two-gpu strategy='ddp_find_unused_parameters_true'
-  time_emb dimension = 256
-  Epochs = 26 (~250K batches of 10 images)
-  num_timesteps = 1000, beta_start = 0.0001, beta_end = 0.02
-  Exponential moving average with warmup of 2000 batches
-  Adam optimizer, lr = 0.0002, b1 = 0.5, b2 = 0.999
-  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

| <img src="images/x0_0_save11_works_again.png" alt="" width="300"/> | 
|:--:| 

<hr>

### Experiment 2

-  The AttentionBlock uses my own multi-head self-attention code instead of PyTorch's `nn.MultiheadAttention`.
-  Number of heads = 12
-  No augmentations other than horizontal flips
-  Attention=True only for self.down_1, self.down_2, self.up_2, self.up_1 (all others false)
-  img_shape = (64,64), batch_size=80, two-gpu strategy='ddp_find_unused_parameters_true'
-  time_emb dimension = 256
-  Epochs ~40 
-  num_timesteps = 1000, beta_start = 0.0001, beta_end = 0.02
-  Exponential moving average with warmup of 2000 batches
-  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)

| <img src="images/x0_0_save13_newAttn_more_fully_conv_12heads.png" alt="" width="300"/> | 
|:--:| 



---------------------------------------------------------
## Inference

In [4]:
# Inference configuration (the model itself is loaded inside infer()).
time_emb_dim = 256 #128


import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from unet_diffusion import UNet_Diffusion
from diffusion_lightning import DDPM
from tqdm import tqdm

num_samples = 25
num_grid_rows = 5          # passed to make_grid as nrow (images per grid row)
im_channels = 3
im_size = img_size[0]
# The scheduler settings below should match the ones used during training.
num_timesteps = 1000
beta_start = 0.0001
beta_end = 0.02
task_name = 'default'      # output images go to <task_name>/samples/
ckpt_name = 'model_ckpt.pth'

# Prefer the second GPU, but fall back to the first one on a single-GPU
# machine: 'cuda:1' would otherwise be an invalid device ordinal there.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')


def sample(model, scheduler):
    """
    Run the reverse diffusion process from pure Gaussian noise and save snapshots.

    Starts from ``xt ~ N(0, I)`` of shape
    ``(num_samples, im_channels, im_size, im_size)`` and steps backward from
    ``num_timesteps - 1`` down to 0 with ``scheduler.sample_prev_timestep``.
    At t == num_timesteps - 1 and every t divisible by 200, the current ``xt``
    is clamped to [-1, 1], rescaled to [0, 1] and written as an image grid to
    ``<task_name>/samples/x0_<t>.png``; the t == 0 grid is the final sample.

    NOTE(review): despite the ``x0_`` file prefix, the saved tensor is ``xt``,
    not the scheduler's ``x0_pred``.

    Reads module-level config: num_samples, im_channels, im_size,
    num_timesteps, num_grid_rows, task_name, device.

    Args:
        model: noise-prediction network, called as ``model(xt, t)``.
        scheduler: object whose ``sample_prev_timestep(xt, noise_pred, t)``
            returns ``(x_{t-1}, x0_pred)``.
    """
    samples_dir = os.path.join(task_name, 'samples')
    # makedirs also creates task_name itself; os.mkdir failed when it was missing.
    os.makedirs(samples_dir, exist_ok=True)

    xt = torch.randn((num_samples, im_channels, im_size, im_size)).to(device)

    # reversed(range(...)) has no len(), so give tqdm the total explicitly.
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        t = torch.as_tensor(i).to(device)

        # Predict the noise; the timestep is passed as a 1-element batch.
        noise_pred = model(xt, t.unsqueeze(0))

        # Use the scheduler to get x_{t-1} (x0_pred is not used here).
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, t)

        if i % 200 == 0 or i == num_timesteps - 1:
            ims = torch.clamp(xt, -1., 1.).detach().cpu()
            ims = (ims + 1) / 2   # [-1, 1] -> [0, 1] for ToPILImage
            grid = make_grid(ims, nrow=num_grid_rows)
            img = torchvision.transforms.ToPILImage()(grid)
            img.save(os.path.join(samples_dir, 'x0_{}.png'.format(i)))
            img.close()


def infer(checkpoint_path='/home/mark/dev/diffusion/lightning_logs/version_9/checkpoints/epoch=28-step=33060.ckpt'):
    """
    Load a trained DDPM Lightning checkpoint and write sample grids with ``sample``.

    Args:
        checkpoint_path: path to the Lightning ``.ckpt`` file. The default is
            the previously hard-coded run, so a bare ``infer()`` behaves as
            before; pass another path to sample from a different run.
    """
    # If the checkpoint was saved on another GPU, pass e.g.
    # map_location={'cuda:0': 'cuda:1'} to load_from_checkpoint.
    model = DDPM.load_from_checkpoint(checkpoint_path=checkpoint_path)

    # Dump the extra EMA model to reduce the memory footprint.
    model.ema_model = None

    total_params = sum(param.numel() for param in model.parameters())
    print('Model has:', int(total_params//1e6), 'M parameters')

    model.eval()
    model.to(device)

    # Create the noise scheduler (same schedule settings as the config cell).
    scheduler = LinearNoiseScheduler(num_timesteps=num_timesteps,
                                     beta_start=beta_start,
                                     beta_end=beta_end)
    with torch.no_grad():
        # Sample with the inner network `model.model` (presumably the UNet),
        # not the Lightning wrapper.
        sample(model.model, scheduler)


#----------------------------------------------------
# Run the inference
#----------------------------------------------------
infer()



Restarting from checkpoint
on_load_checkpoint: calling self.ema.step: 33060
Model has: 150 M parameters


1000it [00:50, 19.81it/s]


----------------
## Miscellaneous debugging code for the AttentionBlock (PyTorch's `nn.MultiheadAttention` vs. the hand-crafted version)

In [None]:
import numpy as np
import pickle as pkl
import os
import sys
import torchvision.utils as vutils


import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
from einops import rearrange
from torch import nn




    # def forward(self, x):
    #     print('\ninput x shape:', x.shape)
    #     b, c, h, w = x.shape
    #     in_attn = self.attention_norm(x)
    #     in_attn = x.reshape(b, h * w, c)
    #     # in_attn = in_attn.transpose(1, 2)  # reshape to [b, (h*w), c] i.e. [b, seq, emb_dim]
    #     print('in_attn shape:', in_attn.shape)

    #     qkv = self.to_qkv(in_attn).chunk(3, dim = -1)
    #     print('qkv, len:', len(qkv), ', qkv[0] shape:', qkv[0].shape)
    #     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
    #     print('q shape:', q.shape, ', k shape:', k.shape, ', v shape:', v.shape)
    #     dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
    #     print('q*k dot product shape:', dots.shape)
    #     attn = self.attend(dots)
    #     print('attn shape:', attn.shape)
    #     out = torch.matmul(attn, v)
    #     print('1. out shape:', out.shape)
    #     out = rearrange(out, 'b h n d -> b n (h d)')
    #     out = self.to_out(out)
    #     out = out.transpose(1, 2).reshape(b, c, h, w)
    #     print('2. out shape:', out.shape)
    #     return out 

    # def _reset_parameters(self):
    #         if self._qkv_same_embed_dim:
    #             xavier_uniform_(self.in_proj_weight)
    #         else:
    #             xavier_uniform_(self.q_proj_weight)
    #             xavier_uniform_(self.k_proj_weight)
    #             xavier_uniform_(self.v_proj_weight)

    #         if self.in_proj_bias is not None:
    #             constant_(self.in_proj_bias, 0.)
    #             constant_(self.out_proj.bias, 0.)
    #         if self.bias_k is not None:
    #             xavier_normal_(self.bias_k)
    #         if self.bias_v is not None:
    #             xavier_normal_(self.bias_v)        


import math

class AttentionBlock_new(nn.Module):
    """
    Hand-written multi-head self-attention over the spatial positions of a feature map.

    The input ``(b, c, h, w)`` (``c == dim``) is group-normalized and flattened
    to ``h * w`` tokens of size ``c``. It is attended with ``heads`` heads of
    width ``dim_head``, projected back to ``dim`` channels and reshaped to
    ``(b, c, h, w)``. No residual connection is added here.

    Args:
        dim: number of input/output channels (token embedding size).
        heads: number of attention heads.
        numgroups: number of GroupNorm groups (must divide ``dim``).
        dim_head: width of each attention head.
        dropout: dropout probability after the output projection.
    """
    def __init__(self, dim, heads = 4, numgroups=8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # With a single head as wide as `dim`, the output projection is skipped.
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.attention_norm = nn.GroupNorm(numgroups, dim)
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def _split_heads(self, t):
        """Reshape ``[b, n, heads * d]`` to ``[b, heads, n, d]``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2)

    def forward(self, x):
        b, c, h, w = x.shape
        in_attn = self.attention_norm(x)
        # [b, c, h, w] -> [b, h*w, c]. Channels must end up on the last axis;
        # a plain reshape(b, h*w, c) would interleave channel and pixel values.
        in_attn = in_attn.reshape(b, c, h * w).transpose(1, 2)
        q, k, v = map(self._split_heads, self.to_qkv(in_attn).chunk(3, dim = -1))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)                      # [b, heads, n, d]
        out = out.transpose(1, 2).reshape(b, h * w, -1)  # [b, n, heads * d]
        out = self.to_out(out)                           # [b, n, c]
        out = out.transpose(1, 2).reshape(b, c, h, w)
        return out


class AttentionBlock(nn.Module):
    """
    Multi-head self-attention over spatial positions using ``nn.MultiheadAttention``.

    The input is ``(batch, channels, h, w)`` with ``channels == out_channels``,
    and the output has the same shape. No residual connection is added here.

    Args:
        out_channels: channel count, used as the attention embedding size.
        num_heads: number of attention heads (must divide ``out_channels``).
        numgroups: number of GroupNorm groups (must divide ``out_channels``).
    """
    def __init__(self, out_channels, num_heads=4, numgroups=8):
        super().__init__()
        self.attention_norms = nn.GroupNorm(numgroups, out_channels)
        self.attentions = nn.MultiheadAttention(out_channels, num_heads, batch_first=True)

    def forward(self, x):
        batch_size, channels, h, w = x.shape
        # GroupNorm accepts (N, C, *), so normalize the flattened [b, c, h*w] view.
        in_attn = self.attention_norms(x.reshape(batch_size, channels, h * w))
        # [b, c, h*w] -> [b, h*w, c]: sequence length h*w, embedding size c (batch_first).
        in_attn = in_attn.transpose(1, 2)
        out_attn, _ = self.attentions(in_attn, in_attn, in_attn)
        # Back to [b, c, h, w]. (The per-call debug print was removed.)
        return out_attn.transpose(1, 2).reshape(batch_size, channels, h, w)


In [None]:
import math

# GPT-2 style init std 0.02 / sqrt(2 * n_layer), evaluated for n_layer = 512.
denominator = math.sqrt(2 * (512))
a = 0.02 / denominator
print(a)

In [None]:
# Step-by-step shape walk-through of hand-written multi-head self-attention
# (normalize -> tokens -> qkv -> scaled softmax attention -> project back).
dim = 256
heads = 4
dim_head = 128
inner_dim = dim_head * heads
numgroups = 8

x = torch.randn([2, 256, 32, 32])
print('in x shape:', x.shape)

b, c, h, w = x.shape
norm = nn.GroupNorm(numgroups, dim)
in_attn = norm(x)
# [b, c, h, w] -> [b, h*w, c]: move channels to the last (embedding) axis.
in_attn = in_attn.reshape(b, c, h * w).transpose(1, 2)
print('in_attn shape:', in_attn.shape)

# Compare the default, normal(0, sqrt(2/(fan_in+fan_out))) and Xavier-normal inits.
to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))
nn.init.normal_(to_qkv.weight.data, mean=0., std=np.sqrt(2 / (dim+inner_dim)))
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))
nn.init.xavier_normal_(to_qkv.weight.data)
print('to_qkv, mean:', torch.mean(to_qkv.weight.data), ', std:', torch.std(to_qkv.weight.data))

qkv = to_qkv(in_attn)
print('out shape:', qkv.shape)

qkv = qkv.chunk(3, dim = -1)
print('q shape:', qkv[0].shape)

# [b, n, heads * d] -> [b, heads, n, d]
q, k, v = (t.reshape(b, h * w, heads, dim_head).transpose(1, 2) for t in qkv)
print('q shape:', q.shape)

# Scaled dot-product scores, then softmax over the key axis.
dots = torch.matmul(q, k.transpose(-1, -2)) * dim_head ** -0.5
print('dots shape:', dots.shape)
attn = dots.softmax(dim=-1)

out = torch.matmul(attn, v)
print('1 out shape:', out.shape)

# [b, heads, n, d] -> [b, n, heads * d]
out = out.transpose(1, 2).reshape(b, h * w, inner_dim)
print('2 out shape:', out.shape)

to_out = nn.Linear(inner_dim, dim)

out = to_out(out)
print('3 out shape:', out.shape)

out = out.transpose(1, 2).reshape(b, c, h, w)
print('4 out shape:', out.shape)


In [None]:
# Shared input and the two attention implementations being compared.
b, emb_dim, h, w = 2, 256, 32, 32
c = emb_dim
groups = 8
heads = 4
dim_head = 64 #emb_dim//heads
dropout = 0
x = torch.randn((b, c, h, w))
print('input shape:', x.shape)

attn1 = AttentionBlock_new(emb_dim, heads, numgroups=groups, dim_head=dim_head)
attn2 = AttentionBlock(c, heads, groups)

In [None]:
def _print_module_type(module):
    """Print a submodule's class name; nn.Module.apply calls this for every submodule."""
    print(type(module).__name__)


attn1.apply(_print_module_type)

In [None]:
# Call the module itself rather than .forward(), so any registered hooks run.
out1 = attn1(x)
print(out1.shape, ', mean:', torch.mean(out1), ', std:', torch.std(out1))
print()


In [None]:
# Call the module itself rather than .forward(), so any registered hooks run.
out2 = attn2(x)
print(out2.shape, ', mean:', torch.mean(out2), ', std:', torch.std(out2))
print()

In [None]:
# Patchify with a conv whose kernel and stride both equal the patch size and
# no padding, so each output location covers exactly one aligned,
# non-overlapping patch. (padding=1 shifted every patch by one pixel.)
patch_size = 16
batch = torch.randn([32, 3, 128, 128])
make_patches = nn.Conv2d(3, 3, kernel_size=patch_size, stride=patch_size, padding=0)

patches = make_patches(batch)
print('patches shape:', patches.shape)

In [None]:
# Sanity check of NumPy normal sampling (adapted from the NumPy docs example).
mu, sigma = 0, 0.1 # mean and standard deviation
rng = np.random.default_rng(0)   # seeded so the figure is reproducible
s = rng.normal(mu, sigma, 1000)

# Verify the mean and the standard deviation; both errors should be small.
# (Bare expressions here would be discarded, since only a cell's last value is shown.)
print('|mu - mean(s)|   :', abs(mu - np.mean(s)))
print('|sigma - std(s)| :', abs(sigma - np.std(s, ddof=1)))

# Histogram of the samples with the analytic probability density overlaid.
count, bins, ignored = plt.hist(s, 30, density=True)
plt.plot(bins, 1/(sigma * np.sqrt(2 * np.pi)) *
               np.exp( - (bins - mu)**2 / (2 * sigma**2) ),
         linewidth=2, color='r')
plt.title('Samples from N(0, 0.1) vs. the analytic pdf')
plt.show()

In [None]:
%autoreload
import os
import torch
from torch import utils
from torch import nn
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.transforms.v2 import Resize, Compose, ToDtype, RandomHorizontalFlip, RandomVerticalFlip 
from torchvision.transforms.v2 import RandomResizedCrop, RandomRotation, GaussianBlur, RandomErasing

from celeba_dataset import CelebA

#--------------------------------------------------------------------
# Dataset, Dataloader
#--------------------------------------------------------------------
from pathlib import Path
image_dir_train = Path('../data/img_align_celeba/img_align_celeba/')

img_size = (64,64)
batch_size = 10

# Horizontal flip is the only augmentation; others such as blur interfere with
# the scheduled noising of the images.
train_transforms = Compose([
    ToDtype(torch.float32, scale=False),   # dtype cast only, values are not rescaled
    RandomHorizontalFlip(p=0.50),
    Resize(img_size, antialias=True),
])

train_dataset = CelebA(
    image_dir_train,
    transform=train_transforms,
    limit_size=True,
    size_limit=20,   # NOTE(review): presumably caps the dataset size for quick tests — confirm in CelebA
)
train_loader = utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=5,
    persistent_workers=True,
)


In [None]:
class UnNormalize:
    """
    Map values from the [-1, 1] range back to the [0, 255] pixel range.

    Computes ``img * 127.5 + 127.5`` element-wise and takes no configuration.
    """

    def __call__(self, img):
        return img * 127.5 + 127.5


unorm = UnNormalize()

In [None]:
import matplotlib.pyplot as plt

images, _  = next(iter(train_loader))
print(images.shape)
print(torch.min(images[0]), ', ', torch.max(images[0]))

cols = 5
rows = 4
print('num rows:', rows, ', num cols:', cols)
plt.figure(figsize=(10, 10))

# One subplot per image, at most rows * cols of them.
for idx, img in enumerate(images[:cols * rows], start=1):
    # CHW -> HWC uint8 for imshow (assumes values roughly in [-1, 1] — see the min/max print)
    img = unorm(img).to(torch.uint8).permute(1, 2, 0)
    ax = plt.subplot(rows, cols, idx)
    ax.axis('off')
    plt.imshow(img)



In [None]:
# Noise the first image of the batch shown above at a fixed timestep.
# (The previous version drew an extra, unused batch from the loader.)
shape = images.shape
print(shape)

# The noise must match the shape of what is being noised: a (H, W) tensor
# would broadcast the same noise pattern onto every channel.
noise = torch.randn_like(images[0:1])
print(noise.shape)
print(images[0:5].shape)

imgs_n = lns.add_noise(images[0:1], noise, 50)
print(imgs_n.shape)

In [None]:
import matplotlib.pyplot as plt

def _to_display(img):
    """CHW tensor in roughly [-1, 1] -> HWC uint8 for imshow.

    Values are clamped to [0, 255] first: noised images leave the [-1, 1]
    range, and casting out-of-range floats to uint8 can wrap around.
    """
    return unorm(img).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)


cols = 2
rows = 1
print('num rows:', rows, ', num cols:', cols)
plt.figure(figsize=(5, 5))

img   = _to_display(images[0])
img_n = _to_display(imgs_n[0])

# Original (left) and noised (right) side by side.
for idx, (panel, title) in enumerate([(img, 'original'), (img_n, 'noised')], start=1):
    ax = plt.subplot(rows, cols, idx)
    ax.axis('off')
    ax.set_title(title)
    plt.imshow(panel)
plt.show()



In [None]:
# Scratch shape check of get_time_embedding for a batch of identical timesteps.
# Uses its own name so the notebook-level `time_emb_dim` (256 in the
# inference config) is not silently overwritten.
probe_emb_dim = 128
time_steps = torch.ones((512)) * 999
print(time_steps.shape)

t_col = time_steps[:, None]                   # [512] -> [512, 1]
print(t_col.shape)

t_rep = t_col.repeat(1, probe_emb_dim // 2)   # [512, probe_emb_dim // 2]
print(t_rep.shape)

t_emb = get_time_embedding(time_steps, probe_emb_dim)
print(t_emb.shape)
print(t_emb)

-------------------------------------------
## Training

In [None]:
from unet_diffusion import UNet_Diffusion, get_time_embedding
from diffusion_lightning import DDPM, EMA

In [None]:
# map_location = {'cuda:0':'cuda:1'}
# model = DDPM.load_from_checkpoint(checkpoint_path='/home/mark/dev/diffusion/lightning_logs/version_10/checkpoints/epoch=3-step=72936.ckpt',
#                                   map_location=map_location) 



In [None]:
# Build a fresh model: no notebook-level `model` exists otherwise
# (infer() only creates a local one), so fit() raised NameError on a fresh kernel.
model = DDPM()
trainer = pl.Trainer(accelerator='cpu', devices=1, max_epochs=100)
trainer.fit(model=model, train_dataloaders=train_loader)


In [None]:

# Use the logger from the same package as `pl.Trainer` (pytorch_lightning);
# mixing in `lightning.pytorch` classes makes the Trainer's type checks fail.
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=10,
    every_n_epochs=1,
    monitor = 'loss',
    mode = 'min'
)

# Checkpoint to resume training from.
resume_ckpt_path = '/home/mark/dev/diffusion/lightning_logs/version_10/checkpoints/epoch=3-step=72936.ckpt'

map_location = {'cuda:0':'cuda:1'}   # NOTE(review): not used by the Trainer below
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

logger = TensorBoardLogger(save_dir=os.getcwd(), name="lightning_logs", default_hp_metric=False)
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=500,
                     logger=logger, log_every_n_steps=1000, callbacks=[checkpoint_callback])

model = DDPM()
# Trainer has no `checkpoint_path` argument (it raised TypeError); resuming is
# requested through fit(ckpt_path=...), which restores the saved training state.
trainer.fit(model=model, train_dataloaders=train_loader, ckpt_path=resume_ckpt_path)
