## Miscellaneous code snippets written while learning diffusion models

In [2]:
%load_ext autoreload

In [3]:
%autoreload
# import libraries
import numpy as np
import pickle as pkl
import os
import sys
import torchvision.utils as vutils

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch
%matplotlib inline

from celeba_dataset import CelebA
from unet_diffusion import UNet_Diffusion, get_time_embedding
from noise_scheduler import LinearNoiseScheduler

In [None]:
# Image/batch settings shared by the cells below.
img_size = (64, 64)
batch_size = 8

# Linear beta schedule for the forward (noising) process.
num_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02
lns = LinearNoiseScheduler(num_timesteps=num_timesteps,
                           beta_start=beta_start,
                           beta_end=beta_end)


---------------------------------------------------------
## Inference

In [1]:
# # Create two random vectors and interpolate between them.
# rand_a = torch.randn(3, 64, 64)
# rand_b = torch.randn(3, 64, 64)
# delta_ab = rand_a - rand_b
# print(delta_ab.shape)
# num_samples = 10

# samples = []
# samples.append(rand_a)
# delt = 1.0/num_samples
# for i in range(1, 9, 1):
#     s = rand_a + (i * delt) * delta_ab
#     samples.append(s)

# samples.append(rand_b)
# print('len(samples):', len(samples))


In [None]:
# Inference settings. Time-embedding width (256; previously 128).
time_emb_dim = 256 #128


import torchvision
import argparse
import yaml
import os
from torchvision.utils import make_grid
from unet_diffusion import UNet_Diffusion
from diffusion_lightning import DDPM
from tqdm import tqdm

# Number of images to generate; make_grid places num_grid_rows images per row.
num_samples, num_grid_rows = 25, 5

# Generated image geometry: square images whose side is taken from img_size.
im_channels = 3
im_size = img_size[0]

# Linear beta schedule used when sampling.
num_timesteps, beta_start, beta_end = 1000, 0.0001, 0.02

# Output folder name and checkpoint filename.
task_name, ckpt_name = 'default', 'model_ckpt.pth'

# Prefer the second GPU; fall back to CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:1')
else:
    device = torch.device('cpu')


def sample(model, scheduler):
    """
    Run the reverse diffusion process and save periodic image snapshots.

    Starts from Gaussian noise of shape
    (num_samples, im_channels, im_size, im_size) on ``device``. It steps backward
    from timestep ``num_timesteps - 1`` down to 0, one timestep at a time. At each
    step ``scheduler`` turns the model's noise prediction into the next ``xt``.

    On every timestep divisible by 200, and on the first reverse step, the
    current ``xt`` batch is saved:
      1. clamped to [-1, 1] and rescaled to [0, 1];
      2. tiled by ``make_grid`` with ``num_grid_rows`` images per row;
      3. written to ``<task_name>/samples/x0_<i>.png``.
    The i=0 snapshot is taken after the last reverse step.

    NOTE(review): despite the ``x0_`` filenames, the saved images are ``xt``,
    not the scheduler's ``x0_pred``.

    Args:
        model: callable ``model(xt, t)``, where ``t`` has shape (1,). Returns
            the predicted noise for ``xt``.
        scheduler: provides ``sample_prev_timestep(xt, noise_pred, t)``, which
            returns ``(next_xt, x0_pred)``.

    Reads module-level settings: num_samples, im_channels, im_size,
    num_timesteps, num_grid_rows, task_name, device.
    """
    samples_dir = os.path.join(task_name, 'samples')
    # makedirs also creates task_name when it is missing; the previous
    # os.mkdir raised FileNotFoundError in that case. Done once, not per step.
    os.makedirs(samples_dir, exist_ok=True)

    xt = torch.randn((num_samples, im_channels, im_size, im_size)).to(device)

    # reversed(range(...)) has no len(), so tell tqdm the total explicitly.
    for i in tqdm(reversed(range(num_timesteps)), total=num_timesteps):
        t = torch.as_tensor(i).to(device)

        # Get prediction of noise
        noise_pred = model(xt, t.unsqueeze(0))

        # Use scheduler to get the next xt (and its x0 estimate)
        xt, x0_pred = scheduler.sample_prev_timestep(xt, noise_pred, t)

        # Save a snapshot every 200th step and on the first reverse step.
        if i % 200 == 0 or i == num_timesteps - 1:
            ims = torch.clamp(xt, -1., 1.).detach().cpu()
            ims = (ims + 1) / 2
            grid = make_grid(ims, nrow=num_grid_rows)
            img = torchvision.transforms.ToPILImage()(grid)
            img.save(os.path.join(samples_dir, 'x0_{}.png'.format(i)))
            img.close()


def infer(checkpoint_path='/home/mark/dev/diffusion/lightning_logs/version_12/checkpoints/epoch=9-step=182340.ckpt'):
    """Load a trained DDPM checkpoint and generate samples with it.

    The checkpoint is loaded onto the module-level ``device``. The EMA copy of
    the network is then discarded and the parameter count printed. Finally
    ``sample`` runs on the wrapped network (``model.model``) under
    ``torch.no_grad()``.

    Args:
        checkpoint_path: Lightning checkpoint to load. Defaults to the
            previously hard-coded version_12 checkpoint.
    """
    # Map every stored tensor directly onto `device`. The previous
    # {'cuda:0': 'cuda:1'} mapping ignored the CPU fallback built into
    # `device`, so loading failed on machines without cuda:1.
    model = DDPM.load_from_checkpoint(checkpoint_path=checkpoint_path,
                                      map_location=device)

    # Dump the extra EMA model to reduce memory. Sampling below uses
    # model.model, not the EMA weights.
    model.ema_model = None

    total_params = sum(param.numel() for param in model.parameters())
    print('Model has:', int(total_params//1e6), 'M parameters')

    # Alternative path (plain UNet weights instead of a Lightning checkpoint):
    # model = UNet_Diffusion(time_emb_dim).to(device)
    # model.load_state_dict(torch.load(os.path.join(task_name, ckpt_name), map_location=device))
    model.eval()
    model.to(device)

    # Create the noise scheduler
    scheduler = LinearNoiseScheduler(num_timesteps=num_timesteps,
                                     beta_start=beta_start,
                                     beta_end=beta_end)
    with torch.no_grad():
        sample(model.model, scheduler)



#----------------------------------------------------
# Run the inference
#----------------------------------------------------
infer()


In [None]:
# Draw 1000 samples from N(mu, sigma^2) and sanity-check them.
mu, sigma = 0, 0.1 # mean and standard deviation
s = np.random.normal(mu, sigma, 1000)


# Verify the mean and the standard deviation. A notebook cell only displays
# its last bare expression, so print both checks explicitly. Each difference
# should be close to 0; the exact values vary from run to run.
print('|mu - mean(s)|:', abs(mu - np.mean(s)))
print('|sigma - std(s)|:', abs(sigma - np.std(s, ddof=1)))


# Display the histogram of the samples, along with the probability density function:
count, bins, ignored = plt.hist(s, 30, density=True)
pdf = 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-(bins - mu)**2 / (2 * sigma**2))
plt.plot(bins, pdf, linewidth=2, color='r')
plt.show()

In [None]:
%autoreload
import os
import torch
from torch import utils
from torch import nn
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.transforms.v2 import Resize, Compose, ToDtype, RandomHorizontalFlip, RandomVerticalFlip 
from torchvision.transforms.v2 import RandomResizedCrop, RandomRotation, GaussianBlur, RandomErasing


#--------------------------------------------------------------------
# Dataset, Dataloader
#--------------------------------------------------------------------
from pathlib import Path
image_dir_train = Path('../data/img_align_celeba/img_align_celeba/')

img_size = (64, 64)
batch_size = 8


# Training-time transforms. Only a horizontal flip is active; the other
# augmentations are kept (commented out) for experimentation.
train_transforms = Compose([
    # Cast to float32 without rescaling the values (scale=False).
    ToDtype(torch.float32, scale=False),
    RandomHorizontalFlip(p=0.50),
    # RandomVerticalFlip(p=0.25),
    # transforms.RandomApply(nn.ModuleList([GaussianBlur(kernel_size=7)]), p=0.5),
    # transforms.RandomApply(nn.ModuleList([RandomRotation(10.0)]), p=0.5),
    # RandomResizedCrop(size=img_size, scale=(0.3, 1.0), antialias=True),
    # RandomErasing(p=0.5, scale=(0.02, 0.20)),
    Resize(img_size, antialias=True),
])

train_dataset = CelebA(image_dir_train, transform=train_transforms)
train_loader = utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=5,
    persistent_workers=True,
)


In [None]:
class UnNormalize(object):
    """Undo a mean/std normalization: return ``img * std + mean``.

    The defaults (mean=std=127.5) map values in [-1, 1] to [0, 255], which
    matches the original hard-coded ``img*127.5 + 127.5``.

    Args:
        mean: offset added after scaling.
        std: scale factor applied to the input.
    """
    def __init__(self, mean=127.5, std=127.5):
        self.mean = mean
        self.std = std

    def __call__(self, img):
        # Works for tensors, arrays or plain numbers (elementwise arithmetic).
        return img * self.std + self.mean

unorm  = UnNormalize()

In [None]:
import matplotlib.pyplot as plt

# Fetch one batch and show up to rows*cols of its images in a grid.
images, _  = next(iter(train_loader))
print(images.shape)
print(torch.min(images[0]), ', ', torch.max(images[0]))


cols = 4
rows = 4
print('num rows:', rows, ', num cols:', cols)
plt.figure(figsize=(10, 10))
for idx, img in enumerate(images[:rows * cols], start=1):
    # Clamp before the uint8 cast: converting out-of-range floats to uint8
    # does not saturate (values wrap / behavior is undefined).
    img = unorm(img).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)

    ax = plt.subplot(rows, cols, idx)
    ax.axis('off')
    plt.imshow(img)



In [None]:
# Noise the first image of `images` at a fixed timestep to visualize the
# forward process.
shape = images.shape
print(shape)
# The noise must have the same shape as the image(s) being noised, as in
# DDPM.common_forward (torch.randn_like(imgs)). The previous
# torch.randn(H, W) broadcast one H x W noise map across all channels.
noise = torch.randn_like(images[0:1])
print(noise.shape)
print(images[0:5].shape)

imgs_n = lns.add_noise(images[0:1], noise, 50)
print(imgs_n.shape)

In [None]:
import matplotlib.pyplot as plt

# Show the clean image and its noised version side by side.
cols = 2
rows = 1
print('num rows:', rows, ', num cols:', cols)
plt.figure(figsize=(5, 5))

# The noisy image has values outside [-1, 1], so unorm maps them outside
# [0, 255]. Clamp before the uint8 cast, which otherwise wraps instead of
# saturating.
img   = unorm(images[0]).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
img_n = unorm(imgs_n[0]).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)

for idx, picture in enumerate((img, img_n), start=1):
    ax = plt.subplot(rows, cols, idx)
    ax.axis('off')
    plt.imshow(picture)



In [None]:
time_emb_dim = 128
# A batch of 512 identical float timesteps, all equal to 999.
time_steps = torch.full((512,), 999.0)
print(time_steps.shape)

# Add a trailing axis: (512,) -> (512, 1).
blah = time_steps.unsqueeze(1)
print(blah.shape)

# Tile each timestep across half the embedding width: (512, 1) -> (512, 64).
poo = blah.repeat(1, time_emb_dim // 2)
print(poo.shape)


t_emb = get_time_embedding(time_steps, time_emb_dim)
print(t_emb.shape)
print(t_emb)

-------------------------------------------
## Training

In [None]:
#--------------------------------------------------------------------
#
# DDPM Diffusion Model
# as a pytorch lightning module.
#
#--------------------------------------------------------------------

import torch
from pytorch_lightning.core import LightningModule
from torch import nn
import pytorch_lightning as pl
import copy

from unet_diffusion import UNet_Diffusion
from noise_scheduler import LinearNoiseScheduler


# -------------------------------------------------------------------
# Exponential moving average for more stable training
# copied from https://github.com/dome272/Diffusion-Models-pytorch
# -------------------------------------------------------------------
class EMA:
    """Exponential moving average (EMA) of a model's parameters.

    Each update sets, parameter by parameter,
    ``ema = beta * ema + (1 - beta) * current``. During the first
    ``step_start_ema`` calls to ``step_ema`` the EMA model is instead
    overwritten with the current model's full state_dict.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta  # weight kept for the previous average
        self.step = 0     # number of step_ema calls made so far

    def update_model_average(self, ma_model, current_model):
        """Blend each parameter of current_model into the matching ma_model parameter."""
        paired = zip(current_model.parameters(), ma_model.parameters())
        for cur_param, avg_param in paired:
            avg_param.data = self.update_average(avg_param.data, cur_param.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one step.

        Before ``step_start_ema`` steps the EMA model is hard-copied from
        ``model``; afterwards it is blended. The step counter always increments.
        """
        if self.step >= step_start_ema:
            self.update_model_average(ema_model, model)
        else:
            self.reset_parameters(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Copy model's full state_dict (parameters and buffers) into ema_model."""
        ema_model.load_state_dict(model.state_dict())



class DDPM(LightningModule):
    """DDPM diffusion training wrapper as a LightningModule.

    Holds:
      * ``self.model``: the UNet_Diffusion noise-prediction network;
      * ``self.scheduler``: a LinearNoiseScheduler (1000 steps, betas 1e-4..0.02);
      * ``self.ema_model``: a frozen copy of the network, updated by ``self.ema``
        after every training batch.

    Training objective: for each image, pick a random timestep and noise the
    image with ``self.scheduler.add_noise``. The network's prediction is then
    regressed onto the added noise with MSE.
    """
    def __init__(self,
                **kwargs):
        # kwargs are not used by this class; presumably they are recorded by
        # save_hyperparameters() below.
        super().__init__()
        self.criterion = nn.MSELoss()
        self.num_timesteps = 1000
        self.beta_start = 0.0001
        self.beta_end = 0.02
        self.time_emb_dim = 256
        # NOTE(review): not read anywhere in this class; training length is
        # controlled by Trainer(max_epochs=...).
        self.num_epochs = 500
        self.model = UNet_Diffusion(self.time_emb_dim)
        self.scheduler = LinearNoiseScheduler(self.num_timesteps, self.beta_start, self.beta_end)
        self.ema = EMA(0.995)
        # Frozen copy of the network that tracks an EMA of its weights.
        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)

        print('self.current_epoch:', self.current_epoch)

        self.save_hyperparameters()

    def forward(self, noisy_im, t):
        """Predict the noise in ``noisy_im`` at timestep(s) ``t``."""
        return self.model(noisy_im, t)

    def common_forward(self, batch):
        """Compute the noise-prediction MSE loss for one batch.

        ``batch[0]`` holds the images; any other batch entries are ignored.
        """
        imgs = batch[0]
        # Random noise
        noise = torch.randn_like(imgs)
        # One random timestep per image
        tstep = torch.randint(0, self.num_timesteps, (imgs.shape[0],))
        # Add noise to images according to timestep
        noisy_imgs = self.scheduler.add_noise(imgs, noise, tstep).to(imgs)
        # Model tries to learn the noise that was added to im to make noise_im.
        # tstep.to(imgs) moves the timesteps to the images' device and dtype.
        noise_pred = self.forward(noisy_imgs, tstep.to(imgs))
        # Loss is our predicted noise relative to actual noise
        loss = self.criterion(noise_pred, noise)
        return loss

    # ---------------------------------------------------------------
    # Training step:
    # ---------------------------------------------------------------
    def training_step(self, batch, batch_idx):
        """Compute and log the training loss."""
        loss = self.common_forward(batch)
        self.log_dict({"loss": loss}, prog_bar=True, sync_dist=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """After every batch, apply the EMA-based weights update."""
        self.ema.step_ema(self.ema_model, self.model)

    # ---------------------------------------------------------------
    # Validation step:
    # ---------------------------------------------------------------
    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        val_loss = self.common_forward(batch)
        self.log_dict({"val_loss": val_loss}, prog_bar=True, sync_dist=True)
        return val_loss

    def on_load_checkpoint(self, checkpoint):
        """Print checkpoint diagnostics and re-sync the EMA state on restore."""
        print("\nRestarting from checkpoint")
        print(type(checkpoint))
        print(checkpoint.keys())
        print('epoch:', checkpoint['epoch'])
        print('global_step:', checkpoint['global_step'])
        # These entries are absent from weights-only checkpoints. Use .get so
        # the diagnostics can never abort loading.
        print('lr_schedulers:', checkpoint.get('lr_schedulers'))
        print('loops:', checkpoint.get('loops'))
        print('hyper_parameters:', checkpoint.get('hyper_parameters'))
        optimizer_states = checkpoint.get('optimizer_states') or [None]
        print('type(optimizer_states):', type(optimizer_states[0]))
        print('self.current_epoch;', self.current_epoch)
        # current_epoch is a read-only LightningModule property backed by the
        # Trainer. The former `self.current_epoch = checkpoint['epoch']` raised
        # AttributeError. When resuming via Trainer.fit(ckpt_path=...) the
        # Trainer restores the epoch itself.

        # Rebuild the EMA copy with the same structure as self.model. Its
        # weights presumably come from the checkpoint's state_dict, which
        # Lightning loads after this hook (verify for your Lightning version).
        self.ema_model = copy.deepcopy(self.model).eval().requires_grad_(False)
        # Resume the EMA warm-up counter from the global step so the warm-up
        # (hard copy) phase is not repeated after a restart.
        self.ema.step = checkpoint['global_step']
        print('on_load_checkpoint: calling self.ema.step:', self.ema.step)

    def configure_optimizers(self):
        """Adam on the network's parameters with a StepLR decay (x0.9 every 2 scheduler steps)."""
        print('calling configure_optimizers')
        lr = 0.0002
        b1 = 0.5
        b2 = 0.999
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(b1, b2))
        # I have no evidence to suggest scheduler is an improvement, but let's give it a whirl anyway :)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.9)
        return [optimizer], [scheduler]

 

In [None]:
# Load weights from the version_10 run, remapping tensors saved on cuda:0
# onto cuda:1.
map_location = {'cuda:0':'cuda:1'}
ckpt_file = '/home/mark/dev/diffusion/lightning_logs/version_10/checkpoints/epoch=3-step=72936.ckpt'
model = DDPM.load_from_checkpoint(checkpoint_path=ckpt_file, map_location=map_location)



In [None]:
# Quick single-epoch training run on one GPU.
trainer = pl.Trainer(max_epochs=1, devices=1, accelerator='gpu')
trainer.fit(model=model, train_dataloaders=train_loader)


In [None]:

# Keep the 10 best checkpoints by training 'loss' (lower is better), checked
# every epoch.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=10,
    every_n_epochs=1,
    monitor='loss',
    mode='min',
)

map_location = {'cuda:0':'cuda:1'}
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')

# Checkpoint to resume from.
resume_ckpt = '/home/mark/dev/diffusion/lightning_logs/version_10/checkpoints/epoch=3-step=72936.ckpt'

# Take the logger from the same package as the Trainer (pytorch_lightning)
# rather than mixing in `lightning.pytorch` objects.
from pytorch_lightning.loggers import TensorBoardLogger
logger = TensorBoardLogger(save_dir=os.getcwd(), name="lightning_logs", default_hp_metric=False)
trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=500,
                     logger=logger, log_every_n_steps=1000, callbacks=[checkpoint_callback])

# `checkpoint_path` is not a Trainer argument (it raised TypeError). The
# checkpoint to resume from goes to fit() as `ckpt_path`.
trainer.fit(model=model, train_dataloaders=train_loader, ckpt_path=resume_ckpt)
