In [77]:
import torch
import torch.nn as nn
from torchvision import models
from torch.nn.functional import relu
from diffusion_utils import *
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F

In [78]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and number of classes
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  #assume h == w. must be divisible by 4, so 28,24,20,16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
        self.e11 = nn.Conv2d(3, 64, kernel_size=3, padding=1) # output: 570x570x64
        self.e12 = nn.Conv2d(64, 64, kernel_size=3, padding=1) # output: 568x568x64
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 284x284x64
        self.e21 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # output: 282x282x128
        self.e22 = nn.Conv2d(128, 128, kernel_size=3, padding=1) # output: 280x280x128
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 140x140x128
        self.e31 = nn.Conv2d(128, 256, kernel_size=3, padding=1) # output: 138x138x256
        self.e32 = nn.Conv2d(256, 256, kernel_size=3, padding=1) # output: 136x136x256
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 68x68x256
        
        
        

        # input: 68x68x256
        self.e41 = nn.Conv2d(256, 512, kernel_size=3, padding=1) # output: 66x66x512
        self.e42 = nn.Conv2d(512, 512, kernel_size=3, padding=1) # output: 64x64x512
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) # output: 32x32x512

        
         # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d((3)), nn.GELU())
        self.e51 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
        self.e52 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
        

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 8*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 8*n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 8*n_feat)
        
        self.upconv1 = nn.ConvTranspose2d(512, 512, kernel_size=4, stride=4)
        self.d11 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.d12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

       
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps   #in_channels, out_channels, kernel_size, stride=1, padding=0
            nn.GroupNorm(8, n_feat), # normalize
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """
        x : (batch, n_feat, h, w) : input image
        t : (batch, n_cfeat)      : time step
        c : (batch, n_classes)    : context label
        """
        # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on

        # pass the input image through the initial convolutional layer
        xe11 = relu(self.e11(x))
        xe12 = relu(self.e12(xe11))
        xp1 = self.pool1(xe12)
        print(xp1.shape)

        xe21 = relu(self.e21(xp1))
        xe22 = relu(self.e22(xe21))
        xp2 = self.pool2(xe22)
        print(xp2.shape)
        xe31 = relu(self.e31(xp2))
        xe32 = relu(self.e32(xe31))
        xp3 = self.pool3(xe32)
        print(xp3.shape)
        xe41 = relu(self.e41(xp3))
        xe42 = relu(self.e42(xe41))
        xp4 = self.pool4(xe42)
        xp5 = self.pool4(xp4)
        print(f"xp5 {xp5.shape}")
        hiddenvec = self.to_vec(xp5)
        print(f"hiddenvec {hiddenvec.shape}")
        # convert the feature maps to a vector and apply an activation
        up1 = self.upconv1(hiddenvec)
        print(f"up1 {up1.shape}")
        
       
        # mask out context if context_mask == 1
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
        print(f"c {self.contextembed1(c).shape}")
        # embed context and timestep
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 8, 1, 1)     # (batch, 2*n_feat, 1,1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 8, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")
        #print(cemb1*up1 + temb1)
        
        up2 = self.up1(cemb1*up1 + temb1, xp5)  # add and multiply embeddings
        up3 = self.up2(cemb2*up2 + temb2, xp4)
        out = self.out(torch.cat((up3, x), 1))
        return out


In [79]:
# hyperparameters

# diffusion hyperparameters
timesteps = 500
beta1 = 1e-4
beta2 = 0.02

# network hyperparameters
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 128 # 16x16 image
save_dir = 'weights/'

# training hyperparameters
batch_size = 32
n_epoch = 100
lrate=1e-3

In [80]:
# construct DDPM noise schedule
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()    
ab_t[0] = 1

In [81]:
# construct model
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)


input_dim 1, emb_dim 512 
input_dim 1, emb_dim 64 
input_dim 5, emb_dim 512 
input_dim 5, emb_dim 512 


In [82]:
batch_size=32
dataset = CustomDataset("./wind_366X366.npy", "./wind_label_366X366.npy", transform, null_context=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
optim = torch.optim.Adam(nn_model.parameters(), lr=lrate)

sprite shape: (70257, 128, 128, 3)
labels shape: (70257, 5)


In [83]:
def perturb_input(x, t, noise):
    return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise

In [84]:
nn_model.train()

for ep in range(n_epoch):
    print(f'epoch {ep}')
    
    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)
    pbar = tqdm(dataloader, mininterval=2 )
    for x, _ in pbar:   # x: images
        optim.zero_grad()
        x = x.float().to(device)
        
        
        # perturb data
        noise = torch.randn_like(x)
        t = torch.randint(1, timesteps + 1, (x.shape[0],)).to(device) 
        x_pert = perturb_input(x, t, noise)
        #print(noise.shape)
        
        # use network to recover noise
        pred_noise = nn_model(x_pert, t / timesteps)
        print(pred_noise.shape)
        

epoch 0


  0%|          | 0/2196 [00:00<?, ?it/s]

torch.Size([32, 64, 64, 64])
torch.Size([32, 128, 32, 32])
torch.Size([32, 256, 16, 16])
xp5 torch.Size([32, 512, 4, 4])
hiddenvec torch.Size([32, 512, 1, 1])
up1 torch.Size([32, 512, 4, 4])
c torch.Size([32, 512])
uunet forward: cemb1 torch.Size([32, 512, 1, 1]). temb1 torch.Size([32, 512, 1, 1]), cemb2 torch.Size([256, 64, 1, 1]). temb2 torch.Size([32, 64, 1, 1])
tensor([[[[ 0.2336,  0.2326,  0.2362,  0.2342],
          [ 0.2326,  0.2352,  0.2344,  0.2351],
          [ 0.2328,  0.2331,  0.2334,  0.2340],
          [ 0.2336,  0.2352,  0.2339,  0.2323]],

         [[ 0.2319,  0.2320,  0.2318,  0.2326],
          [ 0.2325,  0.2324,  0.2325,  0.2333],
          [ 0.2312,  0.2322,  0.2318,  0.2324],
          [ 0.2324,  0.2315,  0.2324,  0.2335]],

         [[-0.0753, -0.0713, -0.0712, -0.0753],
          [-0.0731, -0.0774, -0.0757, -0.0752],
          [-0.0737, -0.0723, -0.0753, -0.0761],
          [-0.0778, -0.0720, -0.0745, -0.0733]],

         ...,

         [[-0.3549, -0.3549, -0.354

  0%|          | 0/2196 [00:00<?, ?it/s]


RuntimeError: Given transposed=1, weight of size [256, 64, 2, 2], expected input[32, 1024, 4, 4] to have 256 channels, but got 1024 channels instead