In [None]:
import sys
# Put the parent directory on the module search path so modules located one
# level above this notebook can be imported.
sys.path.append('../')

In [None]:
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
import torchvision
from diffusers import UNet2DModel

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Side length, in pixels, that every image is resized to (images are square).
IMG_SIZE = 256

In [None]:
# Transformations
#
# The individual steps are named module-level functions rather than lambdas so
# the composed transforms can be pickled (needed when a DataLoader uses worker
# processes on spawn-based platforms).


def _unit_to_signed(t):
    """Rescale a tensor from [0, 1] to [-1, 1]."""
    return (t * 2) - 1


def _signed_to_unit(t):
    """Rescale a tensor from [-1, 1] back to [0, 1]."""
    return (t + 1) / 2


def _channels_last(t):
    """Reorder a (C, H, W) tensor to (H, W, C), the layout matplotlib expects."""
    return t.permute(1, 2, 0)


def _to_uint8_array(t):
    """Convert a [0, 1] tensor into a uint8 NumPy array in [0, 255].

    Values are clamped first: casting an out-of-range float to uint8 wraps
    around (or is undefined), which would show up as speckle artefacts.
    """
    return (t.clamp(0.0, 1.0) * 255.0).numpy().astype(np.uint8)


# Image -> model input: resize to IMG_SIZE x IMG_SIZE, convert to a tensor and
# rescale the values to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(_unit_to_signed),
    ]
)

# Model output -> displayable image: inverse rescale, channels-last layout and
# uint8 conversion. Expects a single (C, H, W) CPU tensor.
undo_transform = transforms.Compose(
    [
        transforms.Lambda(_signed_to_unit),
        transforms.Lambda(_channels_last),
        transforms.Lambda(_to_uint8_array),
    ]
)

In [None]:
class diffusion:
    """Forward/reverse diffusion process with a linear beta schedule.

    All schedule tensors are precomputed once in ``__init__`` and are indexed
    by timestep. The arithmetic in ``noise_schedule`` and ``sample`` runs on
    the device that holds the schedule tensors (the CPU by default).
    """

    def __init__(self, timesteps, scheduler_type=None):
        """Precompute the schedule.

        Args:
            timesteps: Number of diffusion steps; every schedule tensor has
                this length.
            scheduler_type: Accepted for interface compatibility but currently
                ignored; the schedule is always linear from 1e-4 to 0.02.
        """
        betas = torch.linspace(0.0001, 0.02, timesteps)
        alphas = 1 - betas
        # Cumulative product of the alphas up to and including each step.
        alpha_hats = torch.cumprod(alphas, dim=0)

        self.betas = betas
        self.alphas = alphas
        self.sqrt_alphas = torch.sqrt(alphas)
        self.alpha_hats = alpha_hats
        self.sqrt_alpha_hats = torch.sqrt(alpha_hats)
        self.sqrt_one_minus_alpha_hats = torch.sqrt(1 - alpha_hats)

    @staticmethod
    def _expand(per_sample, ndim):
        """Shape per-sample values so they broadcast over an ``ndim``-D batch.

        A 0-dim tensor is returned unchanged (it already broadcasts); anything
        else becomes ``(B, 1, ..., 1)`` with ``ndim`` dimensions in total.
        """
        if per_sample.dim() == 0:
            return per_sample
        return per_sample.reshape(-1, *([1] * (ndim - 1)))

    def noise_schedule(self, x_0, t, device="cpu"):
        """Noise clean samples to timestep ``t`` in a single step.

        Args:
            x_0: Batch of clean samples, batch dimension first.
            t: Integer tensor with one timestep index per sample. It may live
                on any device; it is moved next to the schedule for indexing.
            device: Device the two returned tensors are moved to.

        Returns:
            Tuple ``(x_t, noise)``: the noised samples and the Gaussian noise
            that was mixed in, both on ``device``.
        """
        # The schedule can only be indexed by an index on a compatible device.
        t = t.to(self.betas.device)
        noise = torch.randn(size=x_0.shape)

        ndim = x_0.dim()
        signal_scale = self._expand(self.sqrt_alpha_hats[t], ndim)
        noise_scale = self._expand(self.sqrt_one_minus_alpha_hats[t], ndim)

        x_ts = signal_scale * x_0 + noise_scale * noise

        return x_ts.to(device), noise.to(device)

    @torch.no_grad()
    def sample(self, x_t, t, prediction):
        """Take one reverse step from ``x_t`` using the predicted noise.

        Args:
            x_t: Current noisy samples.
            t: Integer tensor of timestep indices, either a single index or
                one per sample in the batch.
            prediction: Noise predicted by the model for ``x_t`` at ``t``.

        Returns:
            The samples one step earlier, on the schedule's device. Fresh
            noise scaled by ``sqrt(beta_t)`` is added for every sample whose
            timestep is not 0; samples at timestep 0 get the mean only.
        """
        schedule_device = self.betas.device
        x_t = x_t.to(schedule_device)
        t = t.to(schedule_device)
        prediction = prediction.to(schedule_device)

        ndim = x_t.dim()
        betas_t = self._expand(self.betas[t], ndim)
        sqrt_alphas_t = self._expand(self.sqrt_alphas[t], ndim)
        sqrt_one_minus_alpha_hats_t = self._expand(
            self.sqrt_one_minus_alpha_hats[t], ndim
        )

        mean = (1 / sqrt_alphas_t) * (
            x_t - ((betas_t * prediction) / sqrt_one_minus_alpha_hats_t)
        )

        # Per-sample flag: True where noise still has to be added.
        noisy_step = self._expand(t != 0, ndim)
        if not bool(noisy_step.any()):
            # Every sample is at the final step; no random draw is made.
            return mean

        noise = torch.randn_like(x_t)
        return mean + torch.sqrt(betas_t) * noise * noisy_step.to(noise.dtype)

    def create_image(self, model, img_size=32, timesteps=300, num_channels=3):
        """Generate one image by running the full reverse process and plot it.

        Starts from Gaussian noise, walks the timesteps from ``timesteps - 1``
        down to 0, clamps the image to [-1, 1] after every step and shows up
        to 10 evenly spaced intermediate images in one matplotlib figure.
        Plotting relies on the module-level ``undo_transform``.

        Args:
            model: Noise-prediction model; called as ``model(img, t)`` and
                expected to return an object with a ``.sample`` attribute.
                The model's own device is used for its inputs.
            img_size: Height and width of the generated square image.
            timesteps: Number of reverse steps; must not exceed the length of
                the schedule built in ``__init__``.
            num_channels: Number of image channels.

        Returns:
            The final image as a ``(num_channels, img_size, img_size)`` CPU
            tensor with values in [-1, 1].

        Raises:
            ValueError: If ``timesteps`` exceeds the schedule length.
        """
        schedule_length = len(self.betas)
        if timesteps > schedule_length:
            raise ValueError(
                f"timesteps ({timesteps}) exceeds the schedule length "
                f"({schedule_length})"
            )

        # Feed the model on whatever device it lives on instead of relying on
        # a notebook-level global.
        first_param = next(model.parameters(), None)
        model_device = (
            first_param.device if first_param is not None else torch.device("cpu")
        )

        with torch.no_grad():
            img = torch.randn((1, num_channels, img_size, img_size))
            plt.figure(figsize=(15, 2))
            plt.axis("off")
            num_images = 10
            # At least 1 so the modulo below never divides by zero.
            stepsize = max(1, timesteps // num_images)

            for i in reversed(range(timesteps)):
                t = torch.full((1,), i, dtype=torch.long, device=model_device)

                prediction = model(img.to(model_device), t).sample

                img = self.sample(img, t, prediction)

                img = torch.clamp(img, -1.0, 1.0)

                # When timesteps is not a multiple of num_images the largest
                # multiples of stepsize would map past the last panel.
                panel = i // stepsize + 1
                if i % stepsize == 0 and panel <= num_images:
                    plt.subplot(1, num_images, panel)
                    plt.imshow(undo_transform(img.cpu()[0]))
            plt.show()
            return img.cpu()[0]

In [None]:
# FGVC-Aircraft dataset stored under ../data (downloaded when requested via
# download=True); each image goes through `transform` on access.
image_data = torchvision.datasets.FGVCAircraft(
    root="../data",
    transform=transform,
    download=True,
)

In [None]:
# Small shuffled loader, used here to pull a single batch for a visual check.
data_loader = torch.utils.data.DataLoader(
    dataset=image_data,
    batch_size=16,
    shuffle=True,
)
images, labels = next(iter(data_loader))

# Show the first image of that batch after undoing the normalisation.
plt.imshow(undo_transform(images[0]))

In [None]:
# Training hyperparameters.
BATCH_SIZE = 128  # samples per batch of the training DataLoader
NO_EPOCHS = 50  # number of passes over the dataset
TIME_STEPS = 500  # length of the diffusion schedule, also used when sampling

# Precomputes the noise schedule for TIME_STEPS steps.
diffuser = diffusion(timesteps=TIME_STEPS)

In [None]:
# Channel width of the first UNet block. Kept as its own constant so that the
# model width does not silently change when the image size is changed.
BASE_CHANNELS = 256

model = UNet2DModel(
    sample_size=IMG_SIZE,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(
        BASE_CHANNELS,
        BASE_CHANNELS,
        BASE_CHANNELS * 2,
        BASE_CHANNELS * 2,
        BASE_CHANNELS * 4,
        BASE_CHANNELS * 4,
    ),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

model.to(device)

dataloader = torch.utils.data.DataLoader(
    image_data, batch_size=BATCH_SIZE, shuffle=True
)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for epoch in range(NO_EPOCHS):
    for step, batch in enumerate(dataloader):
        clean_batch = batch[0]  # the labels in batch[1] are not used
        optimizer.zero_grad()

        # One random timestep per image in the batch.
        t = torch.randint(0, TIME_STEPS, (len(clean_batch),), device=device).long()

        # The schedule tensors live on the CPU, so index them with a CPU copy
        # of t; the noised batch and the noise come back on `device`.
        img_batch_noisy, noise_batch = diffuser.noise_schedule(
            x_0=clean_batch, t=t.cpu(), device=device
        )

        predicted_noise_batch = model(img_batch_noisy, t)

        # The model is trained to predict the noise that was mixed in.
        loss = loss_fn(predicted_noise_batch.sample, noise_batch)

        loss.backward()
        optimizer.step()

        # Once per epoch: report progress and draw a sample from the model.
        if step == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            if device.type == "cuda":
                # mem_get_info() returns (free, total) in bytes and is only
                # available with CUDA; report what is actually in use.
                free_bytes, total_bytes = torch.cuda.mem_get_info()
                print(f"Using {(total_bytes - free_bytes) / 1e9:.2f} GB")
            model.eval()
            diffuser.create_image(model, img_size=IMG_SIZE, timesteps=TIME_STEPS)
            model.train()

# `from_pt` is not a save_pretrained option, so it is not passed.
model.save_pretrained("model_FGVCAircraft_checkpoint")

In [None]:
# Reload the trained weights. from_pretrained is a classmethod, so it is called
# on the class directly instead of first building a throwaway default model.
model_new = UNet2DModel.from_pretrained('./model_FGVCAircraft_checkpoint/')
model_new = model_new.to(device)

In [None]:
# Run the full reverse process with the reloaded model; this plots the
# intermediate images and returns the final one.
img = diffuser.create_image(model_new, img_size=IMG_SIZE, timesteps=TIME_STEPS)

In [None]:
# Show the final generated image (the stray no-op `img.cpu()` statement that
# preceded this call had no effect and was dropped).
plt.imshow(undo_transform(img.cpu()))