In [1]:
import torch.nn.functional as F
from torch import nn
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import math
import UNet

In [2]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

IMG_SIZE = 64
BATCH_SIZE = 128
EPOCHS = 1000

# Define beta schedule
T = 200
betas = linear_beta_schedule(timesteps=T).to(device)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_recip_alphas = torch.sqrt(1.0/ alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

In [4]:
def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0, 1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]

    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.FGVCAircraft(root=".", download=True, transform=data_transform)

    test = torchvision.datasets.FGVCAircraft(root=".", download=True, transform=data_transform, split='test')

    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)


In [5]:
def forward_diffusion_sample(image, t):
    noise = torch.randn(3, 64, 64)
    image_noised = (sqrt_alphas_cumprod[t] * image) + (sqrt_one_minus_alphas_cumprod[t] * noise)
    return image_noised, noise

In [6]:
model = UNet.UNetModel(32).to(device)

In [7]:
# Training algorithm
# x_0 ~ q(x_0)
# t ~ Uniform({1, ...., T})
# eps ~ N(0, I)
# Take graident descent step on MSE(eps - eps_theta(root(alpha_t_bar)*x_0 + root(1-alpha_t_bar)*eps, t))
# Until converged

# Sampling algorithm
# for t = T, ... t do
# z ~ N(0, I) if t > 1, else z =0
# x_t-1 = 1/root(alpha_t) * (x_t - (1-alpha_t)/(root(1-alpha_t))*eps_theta(x_t, t)) + sigma_t*z
# end for
# return x_0




for epoch in range(EPOCHS):
    for step, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        t = torch.randint(1, T+1, (BATCH_SIZE,)).to(device)
        epsilon = torch.randn_like(images)


        noised_samples = (sqrt_alphas_cumprod[t - 1][(...,) + (None,) * 3] * images) + (sqrt_one_minus_alphas_cumprod[t - 1][(...,) + (None,) * 3] * epsilon)
        # [(...,) + (None,) * 3] is used so that I can turn the tensor of shape N to N 1 1 1, which allows me to multiply it with images and epsilon which have shape N C W H

        predicted_epsilon = model(noised_samples, t)

        print(predicted_epsilon)
        


tensor([[0.0966],
        [0.0582],
        [0.0769],
        [0.0525],
        [0.0956],
        [0.0905],
        [0.0465],
        [0.0651],
        [0.0974],
        [0.1098],
        [0.0455],
        [0.0799],
        [0.0649],
        [0.0809],
        [0.0834],
        [0.0164],
        [0.0891],
        [0.0856],
        [0.0473],
        [0.0940],
        [0.0479],
        [0.0828],
        [0.1096],
        [0.0506],
        [0.0792],
        [0.0746],
        [0.0451],
        [0.0250],
        [0.0952],
        [0.0423],
        [0.0658],
        [0.0594],
        [0.0381],
        [0.0531],
        [0.0652],
        [0.0567],
        [0.0443],
        [0.0915],
        [0.0796],
        [0.0439],
        [0.0858],
        [0.1133],
        [0.0568],
        [0.0422],
        [0.0795],
        [0.0400],
        [0.0359],
        [0.0194],
        [0.1007],
        [0.0968],
        [0.0621],
        [0.0686],
        [0.0590],
        [0.0509],
        [0.0531],
        [0

KeyboardInterrupt: 