In [1]:
#Import statements
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools
import torchvision

from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.nn.modules.activation import ReLU

import matplotlib.pyplot as plt

from PIL import Image

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
IMG_SIZE = 28
batch_size = 2048

#A transform to resize, randomly flip, and scale images and convert them to tensors
transform = transforms.Compose([#transforms.Resize((IMG_SIZE, IMG_SIZE)),
                                transforms.RandomHorizontalFlip(), #Maybe don't do this?
                                transforms.ToTensor(), #Convert to tensor
                                transforms.Lambda(lambda t: (t * 2) - 1) #Scale between [-1, 1]]
])

In [None]:
# from torchvision import datasets, models, transforms

# dset = datasets.ImageFolder(f"/content/drive/MyDrive/landscapes", transform)
# dataloader = DataLoader(dset, batch_size=batch_size, shuffle=True)

In [4]:
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw





In [7]:
class Diffusion:
    def __init__(self):
        self.noise_steps = 300
        self.beta_start = 0.0001
        self.beta_end = 0.02
        #self.device = device

        self.beta = self.prepare_noise_schedule().to(DEVICE)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    def prepare_noise_schedule(self):
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None].to(DEVICE)
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None].to(DEVICE)
        Ɛ = torch.randn_like(x).to(DEVICE)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n):
        model.eval()
        with torch.no_grad():
            #x = torch.randn((n, 3, IMG_SIZE, IMG_SIZE)).to(DEVICE)
            x = torch.randn((n, 1, IMG_SIZE, IMG_SIZE)).to(DEVICE)

            for i in reversed(range(1, self.noise_steps)): #why reverse?
                t = (torch.ones(n) * i).long().to(DEVICE)
                predicted_noise = model(x, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x

In [8]:
#@title Helper function for plotting forward diffusion
#Plots tensor images and undos transforms
def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :] 

    plt.imshow(reverse_transforms(image))

In [None]:
#@title Plotting forward diffusion
DEVICE = "cpu"
diffusion = Diffusion()

# Simulate forward diffusion
image = next(iter(dataloader))[0]
image.to(DEVICE)

plt.figure(figsize=(100,100))
plt.axis('off')
num_images = 10
stepsize = int(300/num_images)

for idx in range(0, 300, stepsize):
    t = torch.Tensor([idx]).type(torch.int64).to(DEVICE)
    plt.subplot(1, num_images+1, (idx/stepsize) + 1)
    image, noise = diffusion.noise_images(image, t)
    show_tensor_image(image)


In [9]:
# #Crop function 
# def img_crop(tensor, target_tensor):
#   tensor_size = tensor.size()[2]
#   target_size = target_tensor.size()[2]
#   delta = tensor_size - target_size
#   #delta = delta // 2
#   print(delta)

#   return tensor[:,:, delta:tensor_size, delta:tensor_size]

class Double_Conv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same'),
        nn.ReLU()
    )
  
  def forward(self, x):
    return self.conv(x)

class Up(nn.Module):
  def __init__(self, in_channels, out_channels, emb_dim=256):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    #self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
    self.conv = Double_Conv(in_channels, out_channels)
    self.emb_layer = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels)
    )

  def forward(self, x, skip_x, t):
    x = self.up(x)
    #skip_x = img_crop(skip_x, x)
    x = torch.cat([x, skip_x], dim = 1)
    x = self.conv(x)
    emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
    return x + emb

class Down(nn.Module):
  def __init__(self, in_channels, out_channels, emb_dim=256):
    super().__init__()
    self.conv = nn.Sequential(
        nn.MaxPool2d(kernel_size=2, stride=2),
        Double_Conv(in_channels, out_channels)
    )

    self.emb_layer = nn.Sequential(
        nn.SiLU(),
        nn.Linear(emb_dim, out_channels)
    )

  def forward(self, x, t):
    x = self.conv(x)
    emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
    return x + emb

class UNet(nn.Module):
  def __init__(self, time_dim=256):
    super(UNet, self).__init__()
    self.time_dim = time_dim

    #First conv layer
    #self.first = Double_Conv(3, 32)
    self.first = Double_Conv(1, 32)


    #Down blocks
    self.down1 = Down(32, 64)
    self.down2 = Down(64, 128)
    #self.down3 = Down(128, 256)

    #Bottleneck
    #self.middle = Double_Conv(256, 256)
    self.middle = Double_Conv(128, 128)

    #Up blocks
    #self.up1 = Up(256, 128)
    self.up1 = Up(128, 64)
    self.up2 = Up(64, 32)

    #Final conv layer
    #self.last = Double_Conv(32, 3)
    self.last = Double_Conv(32, 1)

  def pos_encoding(self, t, channels):
    inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)).to(DEVICE)
    pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq).to(DEVICE)
    pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq).to(DEVICE)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1).to(DEVICE)

    return pos_enc

  def forward(self, img, t):

    #Handle positional encoding
    t = t.unsqueeze(-1)
    t = self.pos_encoding(t, self.time_dim).to(DEVICE)
    
    x1 = self.first(img)
    x2 = self.down1(x1, t)
    x3 = self.down2(x2, t)
    #x4 = self.down3(x3, t)

    #print(x4.shape)

    x4 = self.middle(x3)
    #print(x5.shape)
    #print(x3.shape)

    #x = self.up1(x5, x3, t)
    x = self.up1(x4, x2, t)
    x = self.up2(x, x1, t)
    #print(x)
    x = self.last(x)

    return(x)

In [None]:
# def save_images(images, path, **kwargs):
#     grid = torchvision.utils.make_grid(images, **kwargs)
#     ndarr = grid.permute(1, 2, 0).to('cpu').numpy()
#     im = Image.fromarray(ndarr)
#     im.save(path)

In [17]:
from torch.optim import Adam

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#DEVICE = "cpu"


model = UNet()
model.to(DEVICE)
optimizer = Adam(model.parameters(), lr=0.0001)
epochs = 100
diffusion = Diffusion()
loss = nn.MSELoss()

all_samples = []

def train():
  for epoch in range(epochs):
      total_loss = 0
      for i, images in enumerate(dataloader):
          images = images[0]
          images = images.to(DEVICE)
          t = diffusion.sample_timesteps(images.shape[0]).to(DEVICE)
          x_t, noise = diffusion.noise_images(images, t)
          x_t.to(DEVICE)
          predicted_noise = model(x_t, t)
          fit = loss(noise, predicted_noise)
          total_loss += fit
          optimizer.zero_grad()
          fit.backward()
          optimizer.step()
          #print(f'Batch #{i} completed')

      print(f"epoch: {epoch} / loss: {total_loss}")

      sampled_images = diffusion.sample(model, n=images.shape[1])
      all_samples.append(sampled_images)
      #save_images(sampled_images, os.path.join("results", args.run_name, f"{epoch}.jpg"))
    

In [18]:
model = UNet()
train()
#model(x_t, t)



RuntimeError: ignored

In [16]:
all_samples = []
sampled_images = diffusion.sample(model, n=images.shape[1])
all_samples.append(sampled_images)
plt.imshow(all_samples[0].squeeze().cpu(), cmap='gray')

NameError: ignored

In [None]:
model = UNet()
print(sum(p.numel() for p in model.parameters()))

836631


In [15]:
# def plot_images(images):
#     plt.figure(figsize=(32, 32))
#     plt.imshow(torch.cat([
#         torch.cat([i.squeeze() for i in images], dim=-1),
#     ], dim=-2).cpu())
#     plt.show()

# plot_images(all_samples)

# plt.imshow(all_samples[0].squeeze().cpu())



for i in all_samples:
    plt.figure(figsize=(5,5))
    plt.imshow(i.squeeze().cpu(), cmap='gray')

NameError: ignored