In [1]:
import torch.nn.functional as F
from torch import nn
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import math
import UNet

In [2]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)


def quadratic_beta_schedule(timesteps, start=0.0001, end=0.02):
    # Generate a linear space from 0 to 1
    linear_space = torch.linspace(0, 1, timesteps)
    
    # Apply a quadratic transformation
    quadratic_space = linear_space ** 2
    
    # Scale and shift the quadratic space to start and end at the specified values
    beta_values = start + (end - start) * quadratic_space
    
    return beta_values

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

IMG_SIZE = 64
BATCH_SIZE = 256
EPOCHS = 5000

# Define beta schedule
T = 2000
betas = quadratic_beta_schedule(timesteps=T).to(device)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_recip_alphas = torch.sqrt(1.0/ alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

In [4]:
def load_transformed_dataset():
    data_transforms = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), # Scales data into [0, 1]
        transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1]
    ]

    data_transform = transforms.Compose(data_transforms)

    train = torchvision.datasets.FGVCAircraft(root=".", download=True, transform=data_transform)

    test = torchvision.datasets.FGVCAircraft(root=".", download=True, transform=data_transform, split='test')

    return torch.utils.data.ConcatDataset([train, test])

def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]
    plt.imshow(reverse_transforms(image))

data = load_transformed_dataset()
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, prefetch_factor=20, num_workers=4)


In [5]:
model = UNet.UNetModel(32).to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(weight_decay=1e-4, params=model.parameters(), lr=0.05)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

In [8]:

sum(p.numel() for p in model.parameters() if p.requires_grad)

22686

In [7]:
# Training algorithm
# x_0 ~ q(x_0)
# t ~ Uniform({1, ...., T})
# eps ~ N(0, I)
# Take graident descent step on MSE(eps - eps_theta(root(alpha_t_bar)*x_0 + root(1-alpha_t_bar)*eps, t))
# Until converged

for epoch in range(EPOCHS):
    # to keep track of accumulated accuracy
    epoch_loss = 0

    for step, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        t = torch.randint(1, T+1, (BATCH_SIZE,)).to(device)
        epsilon = torch.randn_like(images)

        noised_samples = (sqrt_alphas_cumprod[t - 1][(...,) + (None,) * 3] * images) + (sqrt_one_minus_alphas_cumprod[t - 1][(...,) + (None,) * 3] * epsilon)
        # [(...,) + (None,) * 3] is used so that I can turn the tensor of shape N to N 1 1 1, which allows me to multiply it with images and epsilon which have shape N C W H

        # Forward pass
        predicted_epsilon = model(noised_samples, t)
        loss = criterion(predicted_epsilon, epsilon)
        epoch_loss += loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (step+1) % 12 == 0:
            print(f'Epoch [{epoch+1}/{EPOCHS}], Step [{step+1}/{len(dataloader)}], Loss: {loss.item():.3f}')
    
    print(f'Epoch [{epoch+1}/{EPOCHS}], Accumulated_Loss: {(epoch_loss.item() / (len(dataloader))):.3f}')
    scheduler.step()


Epoch [1/5000], Step [12/39], Loss: 1.000
Epoch [1/5000], Step [24/39], Loss: 0.987
Epoch [1/5000], Step [36/39], Loss: 0.885
Epoch [1/5000], Accumulated_Loss: 0.969


KeyboardInterrupt: 

In [None]:
def forward_diffusion_sample(x_0, t, device="cpu"):
    noise = torch.randn(1, 3, 64, 64).to(device)
    xt = (sqrt_alphas_cumprod[t - 1][(...,) + (None,) * 3] * x_0) + (sqrt_one_minus_alphas_cumprod[t - 1][(...,) + (None,) * 3] * noise)

    return xt, noise

In [None]:
dataiter = iter(dataloader)

In [None]:
images, labels = next(dataiter)

t = 300

x_current = images[0].unsqueeze(0).to(device)
x_current, noise = forward_diffusion_sample(x_current, t, device)

show_tensor_image(x_current.to('cpu'))
show_tensor_image(x_current.to('cpu'))

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdeb53855e0>
Traceback (most recent call last):
  File "/home/yanisf/.pyenv/versions/3.9.8/envs/deep-learning-3.9.8/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/yanisf/.pyenv/versions/3.9.8/envs/deep-learning-3.9.8/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/yanisf/.pyenv/versions/3.9.8/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/home/yanisf/.pyenv/versions/3.9.8/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/home/yanisf/.pyenv/versions/3.9.8/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/home/yanisf/.pyenv/versions/3.9.8/lib/py

In [None]:
# Sampling algorithm
# for t = T, ... 1 do
# z ~ N(0, I) if t > 1, else z =0
# x_t-1 = 1/root(alpha_t) * (x_t - (1-alpha_t)/(root(1-alpha_t))*eps_theta(x_t, t)) + sigma_t*z
# end for
# return x_0

with torch.no_grad():
    x_current = torch.randn(1, 3, 64, 64).to(device)
    for i in range(T, 0, -1):
        z = torch.randn(1, 3, 64, 64).to(device) if i > 0 else torch.zeros(1, 3, 64, 64).to(device)
        epsilon_theta = model(x_current, torch.tensor([i - 1], device=device))
        x_prev = sqrt_recip_alphas[i - 1] * (x_current - (((1 - alphas[i - 1])/(sqrt_one_minus_alphas_cumprod[i-1]))*epsilon_theta))

    show_tensor_image(x_prev.cpu())

AttributeError: 'DataLoader' object has no attribute 'sapl'