In [85]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [3]:
# Diffusion configuration.
T = 50  # number of diffusion steps; beta_schedule has T + 1 entries (indices 0..T)
# Fall back to CPU so the notebook still runs (Restart & Run All) without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): this linear schedule spans beta = 0 .. 1.0. beta_0 = 0 means step 0
# adds no noise, and beta_T = 1 gives alpha_T = 1 - beta_T = 0, so alpha_bar_T = 0
# and any reverse step dividing by sqrt(alpha_t) is undefined at t = T.
# DDPM (Ho et al., 2020) uses a linear schedule from 1e-4 to 0.02 — confirm intent.
beta_schedule = torch.linspace(0, 1.0, T + 1).to(DEVICE)

print(beta_schedule)

tensor([0.0000, 0.0200, 0.0400, 0.0600, 0.0800, 0.1000, 0.1200, 0.1400, 0.1600,
        0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.2800, 0.3000, 0.3200, 0.3400,
        0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
        0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
        0.7200, 0.7400, 0.7600, 0.7800, 0.8000, 0.8200, 0.8400, 0.8600, 0.8800,
        0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000], device='cuda:0')


In [90]:
def conv3x3(in_feat, out_ch):
    """Apply a freshly initialised 3x3 (stride 1, padding 1) Conv2d to ``in_feat``.

    Args:
        in_feat: (C, H, W) or (N, C, H, W) tensor; C is read from the tensor itself.
        out_ch: number of output channels.

    Returns:
        Tensor of shape (out_ch, H, W) or (N, out_ch, H, W) (padding keeps H, W).

    Raises:
        ValueError: if ``in_feat`` is not 3-D or 4-D.

    WARNING: a new ``nn.Conv2d`` with random weights is built on every call and is
    not registered on any module. Its weights are never trained, and repeated calls
    on the same input give different results. Use this only for quick shape checks;
    inside an ``nn.Module``, create the layer once in ``__init__``.
    """
    if in_feat.dim() not in (3, 4):
        raise ValueError(
            f"conv3x3 expects a 3-D or 4-D tensor, got shape {tuple(in_feat.shape)}"
        )
    in_ch = in_feat.shape[-3]  # channel axis for both (C, H, W) and (N, C, H, W)

    conv_layer = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)
    # Match the input's device and, for floating inputs, its dtype, so that
    # e.g. float64 tensors don't hit a float32-weight dtype mismatch.
    target = {'device': in_feat.device}
    if in_feat.is_floating_point():
        target['dtype'] = in_feat.dtype
    return conv_layer.to(**target)(in_feat)

def dense(in_feat, out_ch):
    """Apply a freshly initialised ``nn.Linear`` over the last axis of ``in_feat``.

    Args:
        in_feat: tensor of shape (*, in_features), with at least one dimension;
            in_features is read from the last axis.
        out_ch: number of output features.

    Returns:
        Tensor of shape (*, out_ch).

    Raises:
        ValueError: if ``in_feat`` is a 0-D scalar tensor.

    WARNING: a new ``nn.Linear`` with random weights is built on every call and is
    not registered on any module. Its weights are never trained, and repeated calls
    give different results. Use it only for shape checks; inside an ``nn.Module``,
    create the layer once in ``__init__``.
    """
    if in_feat.dim() == 0:
        raise ValueError("dense expects a tensor with at least one dimension")
    # nn.Linear accepts any number of leading dims, so only the last axis matters.
    in_features = in_feat.shape[-1]

    dense_layer = nn.Linear(in_features=in_features, out_features=out_ch)
    # Match the input's device and, for floating inputs, its dtype.
    target = {'device': in_feat.device}
    if in_feat.is_floating_point():
        target['dtype'] = in_feat.dtype
    return dense_layer.to(**target)(in_feat)

class time_embedding(nn.Module):
    """Time-conditioned residual 3x3-conv block.

    Forward pipeline:
        h   = ReLU(conv_in(x_img))                      -> (B, out_ch, H, W)
        s   = ReLU(time_dense(t)).view(B, out_ch, 1, 1)
        h   = h * s                                     # per-channel scale by time step
        out = ReLU(BatchNorm(conv_out(h) + h))

    All sub-layers are created once here and registered as submodules, so an
    optimizer trains their weights, ``.to()`` moves them, and they stay fixed
    between forward calls. ``conv_in`` is a ``LazyConv2d`` because the input channel
    count is only known at the first forward; its shape is inferred then.

    Args:
        out_ch: number of output channels (also the width of the time projection).
    """

    def __init__(self, out_ch):
        super().__init__()

        self.out_ch = out_ch
        self.relu = nn.ReLU()
        self.batchnorm = nn.BatchNorm2d(num_features=out_ch)
        self.conv_in = nn.LazyConv2d(out_channels=out_ch, kernel_size=3, stride=1, padding=1)
        # Each sample's time step is a single scalar, hence in_features=1.
        self.time_dense = nn.Linear(in_features=1, out_features=out_ch)
        self.conv_out = nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x_img, x_ts):
        """Condition image features on the diffusion time step.

        Args:
            x_img: image batch of shape (B, C, H, W).
            x_ts: B time steps (any shape with B elements; integer or float).

        Returns:
            Tensor of shape (B, out_ch, H, W), non-negative (final ReLU).
        """
        x_parameter = self.relu(self.conv_in(x_img))

        # (B,) steps -> (B, 1) float features -> (B, out_ch) -> (B, out_ch, 1, 1)
        x_ts = x_ts.reshape(-1, 1).to(dtype=self.time_dense.weight.dtype)
        time_parameter = self.relu(self.time_dense(x_ts))
        time_parameter = time_parameter.view(-1, self.out_ch, 1, 1)
        x_parameter = x_parameter * time_parameter

        # Residual connection around the second conv.
        x_out = self.conv_out(x_parameter) + x_parameter
        x_out = self.batchnorm(x_out)
        return self.relu(x_out)

In [91]:
class Unet(nn.Module):
    """Denoising U-Net — placeholder only.

    No layers and no ``forward`` are defined yet, so calling an instance raises
    ``NotImplementedError`` (nn.Module's default). All constructor arguments are
    forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, *args, **kwargs):
        # TODO: build the encoder / decoder layers here.
        super(Unet, self).__init__(*args, **kwargs)

In [92]:
def forward_diffusion(x0, t, betas=None):
    """Sample x_t from q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I),
    where alpha_bar_t = prod_{s <= t} (1 - beta_s).

    Args:
        x0: clean batch; dim 0 is the batch dimension, any number of trailing dims.
        t: step indices into ``betas`` — a 1-D integer tensor with one entry per
            sample (on the same device as ``betas``), or a single scalar step.
        betas: 1-D noise schedule. Defaults to the notebook-level ``beta_schedule``.

    Returns:
        (noisy_image, epsilon), both shaped like ``x0``.
    """
    if betas is None:
        betas = beta_schedule
    alpha_bars = torch.cumprod(1. - betas, dim=0)

    epsilon = torch.randn_like(x0)

    # One alpha_bar per sample, reshaped to (B, 1, ..., 1) so it broadcasts over x0
    # regardless of x0's rank (previously hardcoded to 4-D).
    alpha_bar_t = alpha_bars[t].reshape(-1, *([1] * (x0.dim() - 1)))

    noisy_image = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1. - alpha_bar_t) * epsilon
    return noisy_image, epsilon

In [93]:
# Smoke-test the forward process and conv3x3 on a random MNIST-shaped batch.
# Use the schedule's device: forward_diffusion indexes the schedule with ts, so the
# two must live on the same device (and this avoids a hardcoded 'cuda').
device = beta_schedule.device
x = torch.randn(size=[4, 1, 28, 28], device=device)

ts = torch.randint(0, T, size=(len(x),), device=device)
print(ts.shape)
x_t, epsilon = forward_diffusion(x, ts)
print(x_t.shape, epsilon.shape)
test1 = conv3x3(x, 192)  # already on x's device; no extra .to() needed
print(test1.shape)

assert x_t.shape == epsilon.shape == x.shape
assert test1.shape == (x.shape[0], 192, *x.shape[-2:])

torch.Size([4])
torch.Size([4, 1, 28, 28]) torch.Size([4, 1, 28, 28])
torch.Size([4, 192, 28, 28])


In [94]:
# Smoke-test the time-conditioned block on the batch `x` and steps `ts` from the previous cell.
time_embedding_layer = time_embedding(192).to(x.device)
# Call the module itself rather than .forward() so nn.Module hooks run.
time_test = time_embedding_layer(x, ts)
assert time_test.shape == (x.shape[0], 192, *x.shape[-2:])
print(time_test.shape)