In [13]:
# let's implemet ddpm with a u-net with a score-matching (or equivalently, noise prediction) objective 
# same setup as in DiT but now with a convolutional backbone. Recall, we take noised_img [b, ch, h, w] -> noise_pred [b, ch, h, w]
# with (eps_true - eps_pred).mean() as loss 
# new additions here: groupnorm, new type of timeEmbeddings, Ublock and Unet with bottleneck structure compared to DiT 
import torchvision 
import math 
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# this conv is more involved than in eg. our resnet implementation because we will be supporting arbitrary up/downsampling of inputs 
# but the core idea of converting a conv into a batched matmul is the same, just with casework to interpolate upsample before conv or 
# add a stride to downsample during the conv, conceptually not very different
# and i'm sure i could've abstracted out some of the code to make this cleaner/shorter but nbd
class Conv(nn.Module): # [b, ch, h, w] -> [b, ch', h', w'] where ch', h', w' depend on down/upsampling ratio 
    def __init__(self, upsample_ratio=1.0, kernel_sz=3, ch_in=1, ch_out=1): 
        super().__init__()
        self.upsample_ratio = upsample_ratio
        self.kernel_sz = kernel_sz
        self.ch_in = ch_in
        self.ch_out = ch_out
        self.kernel_weights = torch.nn.Parameter(torch.randn(ch_out, ch_in, kernel_sz, kernel_sz))

    def forward(self, x): # residual conv 
        if self.upsample_ratio > 1.0:
            # interpolate up then conv to maintain 
            x = F.interpolate(x, scale_factor=self.upsample_ratio)

            # conv to maintain 
            patches = F.unfold(x, kernel_size=self.kernel_sz, padding=self.kernel_sz//2) # [b, ch_in * k * k, np]
            kernel_flat = self.kernel_weights.reshape(self.ch_out, -1) # [ch_out, ch_in * k * k]
            out = torch.einsum('bmn,om->bon', patches, kernel_flat)# [b, ch_out, np]
            b, _, h, w = x.shape
            out = out.reshape(b, self.ch_out, h, w)
            return out 
        elif self.upsample_ratio < 1.0:
            # downsample by setting stride=1//upsample_ratio
            stride = int(1/self.upsample_ratio)
            padding = int(1/self.upsample_ratio)

            patches = F.unfold(x, kernel_size=self.kernel_sz, padding=padding, stride=stride) # [b, ch_in * k * k, np]
            kernel_flat = self.kernel_weights.reshape(self.ch_out, -1) # [ch_out, ch_in * k * k]
            out = torch.einsum('bmn,om->bon', patches, kernel_flat)# [b, ch_out, np]
            b, _, h, w = x.shape

            h_out = (h + 2*padding - self.kernel_sz)//stride + 1
            w_out = (w + 2*padding - self.kernel_sz)//stride + 1
            out = out.reshape(b, self.ch_out, h_out, w_out)
            return out 
        else: # == 1 
            # this conv maintains dimensionality
            patches = F.unfold(x, kernel_size=self.kernel_sz, padding=self.kernel_sz//2) # [b, ch_in * k * k, np]
            kernel_flat = self.kernel_weights.reshape(self.ch_out, -1) # [ch_out, ch_in * k * k]
            out = torch.einsum('bmn,om->bon', patches, kernel_flat)# [b, ch_out, np]
            b, _, h, w = x.shape
            out = out.reshape(b, self.ch_out, h, w)
            return out 


class GroupNorm(nn.Module): 
    def __init__(self, ch=3, channels_per_group=3, eps = 1e-8): 
        super().__init__()
        if channels_per_group > ch: 
            channels_per_group = ch 
        self.channels_per_group = channels_per_group
        assert ch % channels_per_group == 0, \
            "GroupNorm requires number of channels to be a multiple of channels_per_group!"
        self.num_groups = int(ch/channels_per_group)
        self.shift = nn.Parameter(torch.zeros(self.num_groups))
        self.scale = nn.Parameter(torch.ones(self.num_groups))
        self.eps = eps 

    
    def forward(self, x): # [b, ch, h, w] -> [b, ch, h, w] 
        # reshape to [b, ch//g, g * h * w], normalize within the last, and reshape back 
        # g is number of groups 
        b, ch, h, w = x.shape
        out = x.reshape(b, self.num_groups, self.channels_per_group * h * w)
        out = (out - out.mean(dim=-1, keepdim=True))/(out.std(dim=-1, keepdim=True) + self.eps)
        # broadcast from [ng] to [1, ng, 1] to can mult with [b, ng, cpg * h * w]
        shift_reshaped = self.shift.view(1, self.num_groups, 1)
        scale_reshaped = self.scale.view(1, self.num_groups, 1)
        out = (out + shift_reshaped) * scale_reshaped  # apply to each group in each batch separately
        return out.reshape(b, ch, h, w)

class Attention(nn.Module): 
    def __init__(self, D=1): # ch = 1 for mnist, ch=3 for cifar-10
        super().__init__()
        self.D = D
        self.wq = nn.Linear(D, D)
        self.wk = nn.Linear(D, D)
        self.wv = nn.Linear(D, D)
        self.wo = nn.Linear(D, D)
    
    def forward(self, x): 
        # fwd is [b, ch, h, w] -> [b, ch, h * w] -> then 
        # self attn across last dim with b, ch as batch dims, reshape to [b, ch, h, w]
        b, ch, h, w = x.shape
        # [b, h*w, ch] like in a transformer now [b,s,d] with s = h*w tokens and ch = features (D)
        x = x.reshape(b, h * w, ch).transpose(-1, -2) # [b, ch, h * w]
        q, k, v = self.wq(x), self.wk(x), self.wv(x) # [b, s, d]

        scale = math.sqrt(self.D)
        A_logits = torch.bmm(q/scale, k.transpose(-1, -2))
        A = F.softmax(A_logits, dim=-1) # [b, s, s]

        out = torch.bmm(A, v) # [b, s, s] @ [b, s, d] -> [b, s, d]
        out = self.wo(out) # [b, s, d] = [b, h * w, ch]
        return out.transpose(-1, -2).reshape(b, ch, h, w) # output [b, ch, h, w]
        
class TimeEmbedding(nn.Module): # 
    def __init__(self, dim, max_period=10000, mlp_mult=4):
        super().__init__()
        self.dim = dim
        self.max_period = max_period # controls frequency of our sinusoidal embeddings
        
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * mlp_mult), 
            nn.SiLU() ,
            nn.Linear(dim * mlp_mult, dim), 
        )
    
    def forward(self, t): # [b] -> [b, dim] where then projection to [b, 2*ch] handled in UNet class
        # sinusoidal embeddings, freqs shape: [dim//2]
        half_dim = self.dim // 2
        freqs = torch.exp(
            -torch.arange(half_dim, device=t.device) * torch.log(torch.tensor(self.max_period)) / half_dim
        )
        
        args = t[:,None] * freqs[None, :] # [b, dim//2]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [b, dim]
        
        # handle odd dim
        if self.dim % 2 == 1:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            
        return self.mlp(embedding)

class UBlock(nn.Module):
    # 2x [GN, +embeddings, SiLU, Conv] with global residual 
    def __init__(self, ch_in, up=True, bottleneck=False, k=3): # gamma, beta are both [b, ch] vectors from TimeEmbeddings that we use to steer things 
        super().__init__()
        # bottleneck means 1x1 + self attn 
        self.up = up 
        self.gn1 = GroupNorm(ch=ch_in)
        self.silu1 = nn.SiLU()
        self.gn2 = GroupNorm(ch=ch_in)
        self.silu2 = nn.SiLU()

        self.bottleneck = bottleneck

        if self.up: # upsample by 2
            self.conv1 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=2, kernel_sz=k)
            self.conv2 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=2, kernel_sz=k)
        elif bottleneck: # preserve shape and use attn
            self.attn = Attention(D=ch_in)
            self.conv1 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=1, kernel_sz=k)
            self.conv2 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=1, kernel_sz=k)
        else: # downsample by 2
            self.conv1 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=0.5, kernel_sz=k)
            self.conv2 = Conv(ch_in=ch_in, ch_out=ch_in, upsample_ratio=0.5, kernel_sz=k)
            self.skip = nn.Conv2d(in_channels=ch_in, out_channels=ch_in, kernel_size=k, stride=2)

    def forward(self, x, scale, shift): # conv, norm, and skip connection, with optional self-attn
        b, ch, h, w = x.shape 
        og_x = x.clone()
        if self.up: 
            og_x = F.interpolate(og_x, scale_factor=2)
        elif not self.up and not self.bottleneck: # self.down 
            og_x = self.skip(og_x)

        print(f'in ublock fwd, input x has shape {x.shape}')

        scale = scale.view(b, ch, 1, 1)
        shift = shift.view(b, ch, 1, 1)
        
        print(f"Input x shape: {x.shape}")
        h1 = self.gn1(x) * scale + shift
        print(f"After gn1 shape: {h1.shape}")
        h1 = self.silu1(h1)
        print(f"After silu1 shape: {h1.shape}")
        if self.bottleneck: 
            h1 = self.attn(h1) # attn preserves [b, ch, h, w] shape
            print(f"After attn shape: {h1.shape}")
        h1 = self.conv1(h1)
        print(f"After conv1 shape: {h1.shape}")

        h2 = self.gn2(h1) * scale + shift
        print(f"After gn2 shape: {h2.shape}")
        h2 = self.silu2(h2)
        print(f"After silu2 shape: {h2.shape}")
        if self.bottleneck: 
            h2 = self.attn(h2)
            print(f"After attn shape: {h2.shape}")
        h2 = self.conv2(h2)
        print(f"After conv2 shape: {h2.shape}")
        
        return h2 + og_x
        

class UNet(nn.Module): # [b, ch, h, w] noised image -> [b, ch, h, w] of error (score-matching)
    def __init__(self, nblocks=11, time_embed_dim=64, ch=1, h=32, w=32, k=3): # 5 down, 1 bottleneck, 5 up 
        super().__init__()
        self.nblocks = nblocks
        self.time_embed_dim = time_embed_dim
        self.ch = ch
        self.h = h
        self.w = w
        
        self.time_embeddings = TimeEmbedding(self.time_embed_dim)
        
        self.ups = nn.Sequential(
            *[UBlock(ch, up=True, bottleneck=False, k=k) for _ in range(int(nblocks//2))]
        )
        self.bottleneck = UBlock(ch, up=False, bottleneck=True, k=k)
        self.downs = nn.Sequential(
            *[UBlock(ch, up=False, bottleneck=False, k=k) for _ in range(int(nblocks//2))]
        )
        
        self.time_to_ss = nn.ModuleList([nn.Linear(time_embed_dim, 2 * ch) for _ in range(nblocks)])


    def forward(self, x, t): # [b, ch, h, w] -> [b, ch, h, w]
        b = x.shape[0]
        h = x
        # h = F.pad(x, (2, 2, 2, 2), mode='constant', value=0) # [b, 1, 28, 28] -> [b, 1, 32, 32]
        t_embeds = self.time_embeddings(t) # add to every UBlock 
        layer_counter = 0 

        for down_block in self.downs: 
            ss_output = self.time_to_ss[layer_counter](t_embeds)  # [b, 2*ch]
            scale, shift = ss_output.chunk(2, dim=1)  # split into two [b, ch] tensors
            h = down_block(h, scale, shift)
            layer_counter += 1
        
        ss_output = self.time_to_ss[layer_counter](t_embeds)  # [b, 2*ch]
        scale, shift = ss_output.chunk(2, dim=1)  # split into two [b, ch] tensors
        h = self.bottleneck(h, scale, shift)
        layer_counter += 1

        for up_block in self.ups: 
            ss_output = self.time_to_ss[layer_counter](t_embeds)  # [b, 2*ch]
            scale, shift = ss_output.chunk(2, dim=1)  # split into two [b, ch] tensors
            h = up_block(h, scale, shift)
            layer_counter += 1

        return h # [b, ch, h, w]

    pass # Ublock(down) x N -> Ublock(up) x N with middle layers having attn 

device = 'cuda'
model = UNet().to(device)
T = 100 
model(torch.randn(4,1,32,32).to(device), torch.randint(0,T,(4,),device=device)).shape


in ublock fwd, input x has shape torch.Size([4, 1, 32, 32])
Input x shape: torch.Size([4, 1, 32, 32])
After gn1 shape: torch.Size([4, 1, 32, 32])
After silu1 shape: torch.Size([4, 1, 32, 32])
After conv1 shape: torch.Size([4, 1, 17, 17])
After gn2 shape: torch.Size([4, 1, 17, 17])
After silu2 shape: torch.Size([4, 1, 17, 17])
After conv2 shape: torch.Size([4, 1, 10, 10])


RuntimeError: The size of tensor a (10) must match the size of tensor b (15) at non-singleton dimension 3

In [None]:
# t is [b]
# freqs is [dim//2]
# t[:,None] * freqs[None, :] is [b, dim//2]
# we cat these two along dim=-1 to get [b, dim] output from TimeEmbeddings
# UNet class stores projections [b, dim] -> [b, 2*ch] for each UBlock 
# we compute t_embed = TimeEmbedding(t) once at the beginning
# then does shift_i, scale_i = projs[i](t_embeds) 
# and passes in shift_i and scale_i to UBlocks[i]

In [None]:
# TODO: noise schedule, construct every batch element with distinct noise 
# ie. one el might have x_t the other might have x_{s>t} in terms of noise
# ie. will need alphas_cumprod, etc. 
def train(model, dataloader, betas, b=16, ch=1, h=28, w=28, epochs=1, lr=3e-4, print_every_steps=10, T=200): # inputs are both [b, ch, h, w], latter is unet(real_batch + true_noise)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    opt.zero_grad()

    t = torch.randint(0, T, (b,))
    # 
    
    step = 0
    for epoch_idx in range(epochs): 
        for real_batch, _ in dataloader: 
            real_batch = real_batch.to(device=device)
            true_noise = torch.randn(b, ch, h, w) # TODO: need diff variance noise for each batch el based on time steps T choose b
            pred_noise = model(real_batch + true_noise, t) # needs t for time step embeddings 
            loss = F.mse(true_noise, pred_noise)
            loss.backward()
            opt.step()
            opt.zero_grad()

            step +=1
            if step % print_every_steps == 0: 
                print(f'Step {step}, epoch {epoch_idx}: Loss {loss.item()}')




In [None]:
if __name__ == "main": 

    # put all the other argparse stuff for real training here 
    betas = 0 
    batch_sz = 64

    model = UNet()
    dataloader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            root='./data',
            train=True,
            download=True,
            transform=torchvision.transforms.ToTensor()
        ),
        batch_size=batch_sz,
        shuffle=True
    )
    
    train(model, dataloader, betas, b=batch_sz)



In [15]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

# this conv maintains dimensionality

b, ch_in, ch_out, h, w, k = 16, 3, 1, 32, 32, 3
kernel_weights = torch.randn(ch_in, ch_out, k, k)
x = torch.randn(b, ch_in, h, w)

# to do all matmuls in parallel
patches = F.unfold(x, kernel_size=k, padding=k//2) # [b, ch_in * k * k, np]
kernel_flat = kernel_weights.reshape(ch_out, -1) # [ch_out, ch_in * k * k]
out = torch.einsum('bmn,om->bon', patches, kernel_flat)# [b, ch_out, np]
out.reshape(b, ch_out, h, w)

torch.Size([16, 1, 1024])


torch.Size([16, 1, 32, 32])

RuntimeError: shape '[1, 32, 32]' is invalid for input of size 13456