In [92]:
from VAE.dataset import MNISTDataModule
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import math
from tqdm import tqdm

import torchvision
import matplotlib.pyplot as plt

class ChannelShuffle(nn.Module):
    '''
    Interleave the channels of ``groups`` equally sized channel groups.

    For an input of shape (n, c, h, w) the channel axis is viewed as
    (groups, c // groups), the two axes are swapped, and the result is
    flattened back to a single channel axis. The output shape equals the
    input shape.
    '''
    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        batch, channels, height, width = x.shape
        per_group = channels // self.groups
        # Split the channel axis into (group, channel-within-group).
        grouped = x.view(batch, self.groups, per_group, height, width)
        # Swap the two axes so that flattening interleaves the groups.
        interleaved = grouped.transpose(1, 2).contiguous()
        return interleaved.view(batch, -1, height, width)

class ConvBnSiLu(nn.Module):
    '''
    Conv2d -> BatchNorm2d -> SiLU packaged as a single layer.

    The three layers live in ``self.module`` (an ``nn.Sequential``) in that
    order; the SiLU is applied in place.
    '''
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                         stride=stride, padding=padding)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.SiLU(inplace=True)
        self.module = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        return self.module(x)

class ResidualBottleneck(nn.Module):
    '''
    Two-branch unit modelled on the ShuffleNet V2 basic unit
    (https://arxiv.org/pdf/1807.11164.pdf).

    The input is split in half along the channel axis. Each half runs through
    its own branch (both branches contain convolutions), the branch outputs
    are concatenated along the channel axis and the channels are shuffled
    across the two branches. All convolutions use stride 1 (3x3 with
    padding 1, or 1x1), and each branch emits ``out_channels // 2`` channels.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_in = in_channels // 2
        half_out = out_channels // 2

        # Branch 1: depthwise 3x3 conv + BN, then a 1x1 projection.
        self.branch1 = nn.Sequential(
            nn.Conv2d(half_in, half_in, 3, 1, 1, groups=half_in),
            nn.BatchNorm2d(half_in),
            ConvBnSiLu(half_in, half_out, 1, 1, 0),
        )
        # Branch 2: 1x1 conv, depthwise 3x3 conv + BN, then a 1x1 projection.
        self.branch2 = nn.Sequential(
            ConvBnSiLu(half_in, half_in, 1, 1, 0),
            nn.Conv2d(half_in, half_in, 3, 1, 1, groups=half_in),
            nn.BatchNorm2d(half_in),
            ConvBnSiLu(half_in, half_out, 1, 1, 0),
        )
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        first_half, second_half = x.chunk(2, dim=1)
        merged = torch.cat(
            [self.branch1(first_half), self.branch2(second_half)], dim=1
        )
        # Mix channels between the two branches.
        return self.channel_shuffle(merged)

class ResidualDownsample(nn.Module):
    '''
    Two-branch unit modelled on the ShuffleNet V2 spatial down-sampling unit
    (https://arxiv.org/pdf/1807.11164.pdf).

    Both branches receive the full input and each contains one depthwise 3x3
    convolution with stride 2 and padding 1, which reduces the spatial size.
    Each branch emits ``out_channels // 2`` channels; the outputs are
    concatenated along the channel axis and shuffled across the two branches.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        half_out = out_channels // 2

        # Branch 1: strided depthwise 3x3 conv + BN, then a 1x1 projection.
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, 2, 1, groups=in_channels),
            nn.BatchNorm2d(in_channels),
            ConvBnSiLu(in_channels, half_out, 1, 1, 0),
        )
        # Branch 2: 1x1 projection, strided depthwise 3x3 conv + BN, 1x1 conv.
        self.branch2 = nn.Sequential(
            ConvBnSiLu(in_channels, half_out, 1, 1, 0),
            nn.Conv2d(half_out, half_out, 3, 2, 1, groups=half_out),
            nn.BatchNorm2d(half_out),
            ConvBnSiLu(half_out, half_out, 1, 1, 0),
        )
        self.channel_shuffle = ChannelShuffle(2)

    def forward(self, x):
        merged = torch.cat([self.branch1(x), self.branch2(x)], dim=1)
        # Mix channels between the two branches.
        return self.channel_shuffle(merged)

class TimeMLP(nn.Module):
    '''
    Inject a timestep embedding into a feature map.

    The embedding ``t`` is passed through a two-layer MLP
    (Linear -> SiLU -> Linear) producing ``out_dim`` values per sample. Two
    trailing singleton axes are appended so the result broadcasts over the
    trailing two axes of ``x``; it is added to ``x`` and the sum is passed
    through SiLU.
    '''
    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        # The same SiLU instance is used inside the MLP and on the output.
        self.act = nn.SiLU()

        layers = [
            nn.Linear(embedding_dim, hidden_dim),
            self.act,
            nn.Linear(hidden_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, t):
        # Append two singleton axes so the per-channel offset broadcasts.
        offset = self.mlp(t)[..., None, None]
        return self.act(x + offset)

class TargetMLP_and_CONV(nn.Module):
    '''
    Inject a target (label) embedding into a feature map as an extra channel.

    The embedding ``y`` is mapped by a two-layer MLP
    (Linear -> LeakyReLU -> Linear) to ``side * side`` values, reshaped into a
    single-channel ``side x side`` map, concatenated to ``x`` along the
    channel axis, and fused back to ``out_dim`` channels by a 3x3 convolution
    followed by LeakyReLU.

    ``side`` is looked up from ``size_map`` using ``out_dim``, so only
    ``out_dim`` values of 32, 64 and 128 are supported; any other value
    raises ``KeyError`` in the constructor.
    '''
    def __init__(self, embedding_dim, hidden_dim, out_dim):
        super().__init__()
        self.out_dim = out_dim
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        # The same LeakyReLU instance is used inside the MLP and on the output.
        self.act = nn.LeakyReLU()

        # Channel count -> spatial side length of the feature map it is used on.
        self.size_map = {
            32: 28,
            64: 14,
            128: 7,
        }
        side = self.size_map[out_dim]

        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            self.act,
            nn.Linear(hidden_dim, side ** 2),
        )

        # out_dim feature channels + 1 label channel -> out_dim channels.
        self.conv = nn.Sequential(
            nn.Conv2d(out_dim + 1, out_dim, 3, 1, 1),
        )

    def forward(self, x, y):
        side = self.size_map[self.out_dim]
        # One label map per sample, shaped (batch, 1, side, side).
        label_map = self.mlp(y).view(-1, 1, side, side)
        stacked = torch.cat([x, label_map], dim=1)
        return self.act(self.conv(stacked))
        

class EncoderBlock(nn.Module):
    '''
    One encoder stage: feature extraction, conditioning, then down-sampling.

    ``conv0`` (four ResidualBottleneck units) maps the input to
    ``out_channels // 2`` channels; its output is also returned as the skip
    connection. The optional timestep and target embeddings are then injected
    one after the other, and ``conv1`` (ResidualDownsample) produces
    ``out_channels`` channels at reduced spatial size.
    '''
    def __init__(self,in_channels,out_channels,time_embedding_dim, target_embedding_dim):
        super().__init__()
        self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
                                    ResidualBottleneck(in_channels,out_channels//2))

        self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim, hidden_dim=out_channels, out_dim=out_channels//2)
        self.target_mlp = TargetMLP_and_CONV(embedding_dim=target_embedding_dim, 
                                    hidden_dim = out_channels, out_dim=out_channels//2)
        self.conv1=ResidualDownsample(out_channels//2,out_channels)
    
    def forward(self,x,t=None, y=None):
        '''
        Args:
            x: input feature map with ``in_channels`` channels.
            t: optional timestep embedding passed to ``time_mlp``.
            y: optional target embedding passed to ``target_mlp``.

        Returns:
            ``[x, x_shortcut]``: the down-sampled features and the
            pre-conditioning ``conv0`` output used as the skip connection.
        '''
        x_shortcut=self.conv0(x)
        # Always continue from the conv0 features. Previously, with neither t
        # nor y given, the raw input skipped conv0 and went straight to conv1.
        x=x_shortcut
        if t is not None:
            x=self.time_mlp(x,t)
        if y is not None:
            # Condition the running features, not x_shortcut: feeding
            # x_shortcut here overwrote the time-conditioned result, so the
            # timestep information was lost whenever y was given.
            x=self.target_mlp(x,y)
        x=self.conv1(x)

        return [x,x_shortcut]
        
class DecoderBlock(nn.Module):
    '''
    One decoder stage: up-sample, merge the skip connection, condition, project.

    The input is bilinearly up-sampled by a factor of 2 and concatenated with
    the encoder skip connection along the channel axis (the result must have
    ``in_channels`` channels). ``conv0`` (four ResidualBottleneck units)
    reduces this to ``in_channels // 2`` channels, the optional timestep and
    target embeddings are injected one after the other, and ``conv1`` (a
    ResidualBottleneck) produces the block output.
    '''
    def __init__(self,in_channels,out_channels,time_embedding_dim, target_embedding_dim):
        super().__init__()
        self.upsample=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=False)
        self.conv0=nn.Sequential(*[ResidualBottleneck(in_channels,in_channels) for i in range(3)],
                                    ResidualBottleneck(in_channels,in_channels//2))

        self.time_mlp=TimeMLP(embedding_dim=time_embedding_dim,hidden_dim=in_channels,out_dim=in_channels//2)
        self.target_mlp = TargetMLP_and_CONV(embedding_dim=target_embedding_dim, 
                                    hidden_dim = in_channels, out_dim=in_channels//2)
        self.conv1=ResidualBottleneck(in_channels//2,out_channels//2)

    def forward(self,x,x_shortcut,t=None, y=None):
        '''
        Args:
            x: features from the previous (coarser) stage.
            x_shortcut: skip connection from the matching encoder stage.
            t: optional timestep embedding passed to ``time_mlp``.
            y: optional target embedding passed to ``target_mlp``.

        Returns:
            The output of ``conv1`` applied to the conditioned features.
        '''
        x=self.upsample(x)
        x=torch.cat([x,x_shortcut],dim=1)
        x=self.conv0(x)
        if t is not None:
            x=self.time_mlp(x,t)
        if y is not None:
            # Condition the decoded features, not the raw skip connection:
            # feeding x_shortcut here discarded the up-sampled path and the
            # time conditioning whenever y was given. After conv0, x has
            # in_channels//2 channels, as target_mlp expects.
            x=self.target_mlp(x,y)
        x=self.conv1(x)
        return x

class Unet(nn.Module):
    '''
    U-Net with optional timestep and target (class label) conditioning.

    Args:
        timesteps: number of rows in the timestep embedding table.
        time_embedding_dim: width of the timestep embedding.
        in_channels: channels of the input image.
        out_channels: channels of the output.
        base_dim: channels after the initial convolution; must be even.
        dim_mults: channel multipliers, one per encoder/decoder stage.
        num_targets: number of distinct target labels (rows of the label
            embedding table).
        target_embedding_dim: width of the label embedding.

    NOTE: every stage builds a TargetMLP_and_CONV, whose ``size_map`` only
    covers 32, 64 and 128 channels. Configurations that need other channel
    counts (including the default ``dim_mults`` with ``base_dim=32``) raise
    ``KeyError`` during construction.
    '''
    def __init__(self,timesteps,time_embedding_dim, in_channels=3,out_channels=2,base_dim=32,dim_mults=(2,4,8,16),
                 num_targets=10,target_embedding_dim=10):
        super().__init__()
        assert isinstance(dim_mults,(list,tuple)), 'dim_mults must be a list or tuple'
        assert base_dim%2==0, 'base_dim must be even'

        channels=self._cal_channels(base_dim,dim_mults)

        self.init_conv=ConvBnSiLu(in_channels,base_dim,3,1,1)
        self.time_embedding=nn.Embedding(timesteps,time_embedding_dim)
        self.target_embedding=nn.Embedding(num_targets,target_embedding_dim)

        self.encoder_blocks=nn.ModuleList([EncoderBlock(c[0],c[1],time_embedding_dim, target_embedding_dim) for c in channels])
        # Decoder stages mirror the encoder stages in reverse order.
        self.decoder_blocks=nn.ModuleList([DecoderBlock(c[1],c[0],time_embedding_dim, target_embedding_dim) for c in channels[::-1]])

        self.mid_block=nn.Sequential(*[ResidualBottleneck(channels[-1][1],channels[-1][1]) for i in range(2)],
                                        ResidualBottleneck(channels[-1][1],channels[-1][1]//2))

        self.final_conv=nn.Conv2d(in_channels=channels[0][0]//2,out_channels=out_channels,kernel_size=1)

    def forward(self,x,t=None, y=None):
        '''
        Args:
            x: input image batch.
            t: optional integer timestep indices, looked up in
                ``time_embedding``.
            y: optional integer target labels, looked up in
                ``target_embedding``.

        Returns:
            The output of the final 1x1 convolution.
        '''
        x=self.init_conv(x)
        if t is not None:
            t=self.time_embedding(t)

        if y is not None:
            y=self.target_embedding(y)

        encoder_shortcuts=[]
        for encoder_block in self.encoder_blocks:
            x,x_shortcut=encoder_block(x, t, y)
            encoder_shortcuts.append(x_shortcut)

        x=self.mid_block(x)

        # The deepest skip connection is consumed first.
        for decoder_block,shortcut in zip(self.decoder_blocks,reversed(encoder_shortcuts)):
            x=decoder_block(x, shortcut, t, y)
        x=self.final_conv(x)

        return x

    def _cal_channels(self,base_dim,dim_mults):
        '''Return the (in_channels, out_channels) pair of every stage.'''
        dims=[base_dim]+[base_dim*mult for mult in dim_mults]
        return list(zip(dims[:-1],dims[1:]))
    

class ImageDiffusion(nn.Module):
    '''
    Label-conditioned diffusion model for single-channel 28x28 images.

    Wraps a ``Unet`` and precomputes a cosine variance schedule (betas, alphas
    and their cumulative products), stored as buffers so they follow the
    module across devices.

    Args:
        timesteps: number of diffusion steps.
        time_embedding_dim: width of the timestep embedding used by the Unet.
        epsilon: offset of the cosine variance schedule.
    '''
    def __init__(
        self,
        timesteps=1000,
        time_embedding_dim=64,
        epsilon=0.008,
    ):
        super().__init__()
        self.timesteps = timesteps
        self.time_embedding_dim = time_embedding_dim
        self.epsilon = epsilon

        self.model = Unet(
            timesteps = timesteps,
            time_embedding_dim=time_embedding_dim,
            in_channels=1,
            out_channels=1,
            base_dim=32,
            dim_mults=[2, 4,]
        )

        # Shape of one batch of samples; the leading -1 stands for the batch size.
        self.shape = (-1, 1, 28, 28)

        # Pass the configured epsilon on; previously the constructor argument
        # was stored but the schedule always used the method default.
        self.initialize_noise_schedule(self.epsilon)

    def initialize_noise_schedule(self, epsilon=0.008):
        '''Precompute the noise schedule and register its terms as buffers.'''
        betas = self._cosine_variance_schedule(self.timesteps, epsilon)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=-1)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))

    def forward(self, x, y, noise):
        '''
        Noise ``x`` at a random timestep per sample and run the Unet on it.

        Args:
            x: clean images, NCHW.
            y: target labels forwarded to the Unet (may be None).
            noise: noise tensor with the same shape as ``x``.

        Returns:
            The Unet output for the noised images.
        '''
        t = torch.randint(0, self.timesteps, (x.shape[0],)).to(x.device)
        x_t = self._forward_diffusion(x, t, noise)

        return self.model(x_t, t=t, y=y)

    @torch.no_grad()
    def sampling(self,n_samples,clipped_reverse_diffusion=True,device=None, y=True):
        '''
        Generate samples by running the reverse process from pure noise.

        Args:
            n_samples: number of images to generate.
            clipped_reverse_diffusion: use the x_0-clipping reverse step
                instead of the plain one.
            device: device to sample on; defaults to the device holding the
                schedule buffers, i.e. the device the module lives on.
            y: a truthy non-tensor value draws random labels in [0, 10); a
                falsy value samples without labels; a tensor is used as the
                labels directly.

        Returns:
            ``(x_t, hist, y)``: the final samples, every intermediate state
            stacked along a new leading axis (``timesteps + 1`` entries), and
            the labels that were used (or None).
        '''
        if device is None:
            # The original default was the literal "mps", which fails with a
            # device mismatch whenever the module lives anywhere else.
            device = self.betas.device

        x_t = torch.randn( (n_samples, *self.shape[1:]), device=device, dtype=torch.float32)
        hist = [x_t]

        if isinstance(y, torch.Tensor):
            y = y.to(device)
        elif y:
            y = torch.randint(0, 10, (n_samples,)).to(device)
        else:
            y = None

        if clipped_reverse_diffusion:
            reverse_step = self._reverse_diffusion_with_clip
        else:
            reverse_step = self._reverse_diffusion

        for i in tqdm(range(self.timesteps-1,-1,-1),desc="Sampling", disable=True):
            noise=torch.randn_like(x_t)
            t=torch.tensor([i]*n_samples, device=device, dtype=torch.long)
            x_t=reverse_step(x_t, y, t, noise)
            hist.append(x_t)

        hist=torch.stack(hist,dim=0)

        return x_t, hist, y

    def _cosine_variance_schedule(self, timesteps, epsilon=0.008):
        '''Return the ``timesteps`` betas of a cosine schedule, clipped to [0, 0.999].'''
        steps = torch.linspace(0, timesteps, steps=timesteps + 1, dtype=torch.float32)
        f_t = (
            torch.cos(((steps / timesteps + epsilon) / (1.0 + epsilon)) * math.pi * 0.5)
            ** 2
        )
        betas = torch.clip(1.0 - f_t[1:] / f_t[:timesteps], 0.0, 0.999)

        return betas

    def _extract(self, values, t, x):
        '''Gather ``values`` at timesteps ``t``, shaped to broadcast against ``x``.'''
        trailing = [1] * (x.dim() - 1)
        return values.gather(-1, t).reshape(x.shape[0], *trailing)

    def _forward_diffusion(self, x_0, t, noise):
        '''q(x_t | x_0): mix ``x_0`` and ``noise`` with the schedule terms at ``t``.'''
        assert x_0.shape == noise.shape

        A = self._extract(self.sqrt_alphas_cumprod, t, x_0)
        B = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_0)

        return A * x_0 + B * noise

    @torch.no_grad()
    def _reverse_diffusion(self, x_t, y, t, noise):
        """
        p(x_{t-1}|x_{t})-> mean,std

        pred_noise-> pred_mean and pred_std
        """
        pred = self.model(x_t, t=t, y=y)

        # One trailing singleton axis per non-batch axis of x_t. The previous
        # fixed (N, 1) shape could not broadcast against NCHW input.
        alpha_t = self._extract(self.alphas, t, x_t)
        alpha_t_cumprod = self._extract(self.alphas_cumprod, t, x_t)
        beta_t = self._extract(self.betas, t, x_t)
        sqrt_one_minus_alpha_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_t)
        mean = (1.0 / torch.sqrt(alpha_t)) * (
            x_t - ((1.0 - alpha_t) / sqrt_one_minus_alpha_cumprod_t) * pred
        )

        # The branch is chosen for the whole batch from the smallest timestep.
        if t.min() > 0:
            alpha_t_cumprod_prev = self._extract(self.alphas_cumprod, t - 1, x_t)
            std = torch.sqrt(
                beta_t * (1.0 - alpha_t_cumprod_prev) / (1.0 - alpha_t_cumprod)
            )
        else:
            # No noise is added on the final step.
            std = 0.0

        return mean + std * noise

    @torch.no_grad()
    def _reverse_diffusion_with_clip(self, x_t, y, t, noise):
        '''
        p(x_{0}|x_{t}),q(x_{t-1}|x_{0},x_{t})->mean,std

        pred_noise -> pred_x_0 (clip to [-1.0,1.0]) -> pred_mean and pred_std
        '''
        pred=self.model(x_t,t=t,y=y)
        alpha_t=self._extract(self.alphas,t,x_t)
        alpha_t_cumprod=self._extract(self.alphas_cumprod,t,x_t)
        beta_t=self._extract(self.betas,t,x_t)

        x_0_pred=torch.sqrt(1. / alpha_t_cumprod)*x_t-torch.sqrt(1. / alpha_t_cumprod - 1.)*pred
        x_0_pred.clamp_(-1., 1.)

        # The branch is chosen for the whole batch from the smallest timestep.
        if t.min()>0:
            alpha_t_cumprod_prev=self._extract(self.alphas_cumprod,t-1,x_t)
            mean= (beta_t * torch.sqrt(alpha_t_cumprod_prev) / (1. - alpha_t_cumprod))*x_0_pred +\
                 ((1. - alpha_t_cumprod_prev) * torch.sqrt(alpha_t) / (1. - alpha_t_cumprod))*x_t

            std=torch.sqrt(beta_t*(1.-alpha_t_cumprod_prev)/(1.-alpha_t_cumprod))
        else:
            # At t=0 the previous cumulative product is the empty product (1),
            # so the x_t term of the mean vanishes.
            mean=(beta_t / (1. - alpha_t_cumprod))*x_0_pred
            std=0.0

        return mean+std*noise

# make pl model
class ImageDiffusionModule(pl.LightningModule):
    '''
    Lightning wrapper that trains ``ImageDiffusion`` to predict the added noise.

    Args:
        criteria: loss module comparing predicted and true noise; a fresh
            ``nn.MSELoss()`` is created when omitted.
        **kwargs: optional hyperparameters ``LEARNING_RATE`` (default 0.001),
            ``TIMESTEPS`` (default 300), ``TIME_EMBEDDING_DIM`` (default 8)
            and ``EPSILON`` (default 0.008).
    '''
    def __init__(
        self,
        criteria=None,
        **kwargs,
    ):
        super().__init__()
        self.lr = kwargs.get("LEARNING_RATE", 0.001)
        self.model = ImageDiffusion(
            timesteps=kwargs.get("TIMESTEPS", 300),
            time_embedding_dim=kwargs.get("TIME_EMBEDDING_DIM", 8),
            # target_embedding_dim=kwargs.get("TARGET_EMBEDDING_DIM", 8),
            epsilon=kwargs.get("EPSILON", 0.008),
        )

        # Build the default loss per instance rather than once in the
        # signature, where a single module object would be shared.
        self.criteria = criteria if criteria is not None else nn.MSELoss()

    def forward(self, data):
        '''Return ``(noise, pred_noise)`` for a batch ``(x, y)``.'''
        x, y = data
        # Never true as written (epochs start at 0); presumably a switch for
        # an initial number of unconditional epochs.
        if self.current_epoch < 0:
            y = None
        noise = torch.randn_like(x)
        pred_noise = self.model(x, y, noise)
        return noise, pred_noise

    def training_step(self, batch, batch_idx):
        # Gradient clipping is done in on_before_optimizer_step: at this point
        # backward() has not run yet, so there are no fresh gradients to clip.
        return self._common_step(batch, stage="train")

    def on_before_optimizer_step(self, optimizer, *args, **kwargs):
        '''Clip the global gradient norm to 1 right before the optimizer step.'''
        # *args/**kwargs absorb extra hook arguments (e.g. an optimizer index),
        # in case the installed Lightning version passes them.
        torch.nn.utils.clip_grad_norm_(self.parameters(), 1)

    def _common_step(self, batch, stage='train'):
        '''Compute and log the noise-prediction loss for one batch.'''
        noise, pred_noise = self.forward(batch)
        loss = self.criteria(pred_noise, noise)
        # NOTE(review): every stage logs under the same key, so train/val/test
        # losses are not distinguishable by name; `stage` is currently unused.
        self.log('total_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, stage="val")

    def on_validation_epoch_end(self):
        '''Sample 6 images and save a grid of intermediate denoising states.'''
        import os

        use_labels = self.current_epoch >= 0
        # Sample on the module's own device instead of relying on the
        # sampler's default device.
        x_t, hist, y = self.model.sampling(6, device=self.device, y=use_labels)
        # Keep a handful of evenly spaced snapshots; the stride is at least 1
        # so very short schedules do not produce a zero slice step.
        stride = max(1, len(hist)//4)
        # presumably (snapshots, samples, 28, 28) after squeeze; verify if the
        # sample count or image shape changes
        hist = hist[::stride].squeeze().cpu()

        fig, ax = plt.subplots(hist.shape[0], hist.shape[1], figsize=(10, 10 * hist.shape[0]/hist.shape[1]))
        for i in range(hist.shape[0]):
            for j in range(hist.shape[1]):
                ax[i, j].imshow(hist[i, j], cmap='gray')
                ax[i, j].axis('off')

        # set top row titles to y (skipped when sampling without labels)
        if y is not None:
            for i, yi in enumerate(y):
                ax[0, i].set_title(f'y={yi.item()}')

        # savefig does not create missing directories.
        os.makedirs('assets/diffusion_full', exist_ok=True)
        plt.savefig(f'assets/diffusion_full/{self.current_epoch}.png', dpi=400)
        plt.close()

    def test_step(self, batch, batch_idx):
        self._common_step(batch, stage="test")

    def configure_optimizers(self):
        '''AdamW with the learning rate multiplied by 0.9 at every scheduler step.'''
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9, verbose=False)
        return [optimizer], [scheduler]



# Data: MNIST in batches of 64; setup() is called explicitly before training.
dm = MNISTDataModule(BATCH_SIZE=64)
dm.setup()

# Model: 200 diffusion steps, learning rate 0.01.
plModule = ImageDiffusionModule(TIMESTEPS=200, LEARNING_RATE=0.01)

# Train for at most 100 epochs.
trainer = pl.Trainer(max_epochs=100)
trainer.fit(model=plModule, datamodule=dm)


GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


DataModule params:
	batch_size: 64
	path: /Users/tonton/Documents/motion-synthesis/mnist_latent_diffusion/
	rotation: 0
	scale: 0
	translate: (0, 0)
	shear: 0
	normalize: (0.1307, 0.3081)
	bool: False



  | Name     | Type           | Params
--------------------------------------------
0 | model    | ImageDiffusion | 428 K 
1 | criteria | MSELoss        | 0     
--------------------------------------------
428 K     Trainable params
0         Non-trainable params
428 K     Total params
1.715     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]