In [1]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
from PIL import Image

In [2]:
# Hyperparameters
SEED = 42  # Fixed RNG seed so that a Restart & Run All is repeatable
IMAGE_SIZE = 64
CHANNELS = 3
BATCH_SIZE = 128
EPOCHS = 100
LR = 2e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
NUM_TIMESTEPS = 1000  # Number of diffusion steps

# Training (random timesteps, Gaussian noise) and sampling are stochastic;
# seed PyTorch's random number generators once, up front.
torch.manual_seed(SEED)

In [3]:
import pathlib

# Root directory of the training images, handed to ImageFolder below.
# The original machine-specific path stays as the default; set the
# FLOWER_DATA_DIR environment variable to point at the dataset on another
# machine without editing the notebook.
data_folder = pathlib.Path(
    os.environ.get(
        "FLOWER_DATA_DIR",
        r"C:\Users\amrul\programming\deep_learning\dl_projects\Generative_Deep_Learning_2nd_Edition\data\flower\flower_data\flower_data\train",
    )
)

def scale_to_neg_one_to_one(x):
    """Return ``x * 2 - 1``, the affine map that sends [0, 1] onto [-1, 1]."""
    doubled = x * 2
    return doubled - 1


# Transformations for the dataset: resize, convert to a [0, 1] tensor, then
# rescale to [-1, 1].
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 == 2 * x - 1, i.e. the
# same mapping as scale_to_neg_one_to_one, but it is a torchvision object
# rather than a function defined inside the notebook. DataLoader worker
# processes started with the "spawn" method (the Windows default) have to
# unpickle the dataset and cannot look up notebook-defined functions, which
# makes iteration with num_workers > 0 fail or stall.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),  # Converts image to [0,1]
    transforms.Normalize(mean=(0.5,) * CHANNELS, std=(0.5,) * CHANNELS),  # Scale to [-1, 1]
])

# ImageFolder reads the images found under `data_folder`; the class labels it
# yields are ignored by the training loop.
dataset = datasets.ImageFolder(root=str(data_folder), transform=transform)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=(DEVICE.type == "cuda"),
)

In [4]:
class SimpleUNet(nn.Module):
    """Small U-Net-style network that predicts noise from a noisy image and a timestep.

    Layout:
      * four stride-2 convolution blocks (``down1``..``down4``), each halving
        the spatial size;
      * a resolution-preserving bottleneck (``mid``) that receives the time
        embedding added to the deepest encoder features;
      * four stride-2 transposed convolutions (``up4``..``up1``), each doubling
        the spatial size, with additive skip connections from ``d3``, ``d2``
        and ``d1``.

    The input height and width must be divisible by 16 so that the skip
    connections line up.

    Args:
        channels: number of channels of the input and of the output.
        base_channels: width of the first encoder block; deeper blocks use
            2x, 4x and 8x this value.

    Note:
        ``mid`` uses a 3x3 kernel, so checkpoints saved from the earlier 4x4
        stride-2 definition of ``mid`` do not load into this class.
    """

    def __init__(self, channels=3, base_channels=64):
        super().__init__()
        # Remembered so forward() can reshape the time embedding for any width.
        self.base_channels = base_channels

        # Encoder
        self.down1 = self.conv_block(channels, base_channels)
        self.down2 = self.conv_block(base_channels, base_channels*2)
        self.down3 = self.conv_block(base_channels*2, base_channels*4)
        self.down4 = self.conv_block(base_channels*4, base_channels*8)

        # Bottleneck. It must NOT change the spatial size: the decoder has
        # exactly one upsampling step per encoder step, so a strided bottleneck
        # would leave up4's output at half the size of d3 and the skip addition
        # in forward() would raise a shape error.
        self.mid = nn.Sequential(
            nn.Conv2d(base_channels*8, base_channels*8, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(base_channels*8)
        )

        # Decoder
        self.up4 = self.conv_transpose_block(base_channels*8, base_channels*4)
        self.up3 = self.conv_transpose_block(base_channels*4, base_channels*2)
        self.up2 = self.conv_transpose_block(base_channels*2, base_channels)
        # Linear output head (no Tanh): in this notebook the network is trained
        # with an MSE loss against torch.randn noise, whose values are not
        # confined to [-1, 1], so a bounded activation could never fit them.
        # Kept as a Sequential so the parameter names stay ``up1.0.*``.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels, channels, kernel_size=4, stride=2, padding=1)
        )

        # Maps a scalar timestep to one value per bottleneck channel.
        self.time_embed = nn.Sequential(
            nn.Linear(1, base_channels*4),
            nn.ReLU(),
            nn.Linear(base_channels*4, base_channels*8)
        )

    def conv_block(self, in_channels, out_channels):
        """Stride-2 4x4 convolution + ReLU + BatchNorm (halves even spatial sizes)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def conv_transpose_block(self, in_channels, out_channels):
        """Stride-2 4x4 transposed convolution + ReLU + BatchNorm (doubles the spatial size)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, t):
        """Predict the noise contained in ``x``.

        Args:
            x: image batch of shape ``[B, channels, H, W]`` with ``H`` and ``W``
                divisible by 16.
            t: floating-point tensor holding one timestep value per sample
                (``B`` elements in total).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Embed time: [B] -> [B, 1] -> [B, base_channels*8] -> [B, base_channels*8, 1, 1]
        t_embed = self.time_embed(t.reshape(-1, 1))
        t_embed = t_embed.view(-1, self.base_channels * 8, 1, 1)

        # Encoder
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)

        # Middle: the time embedding broadcasts over the spatial dimensions.
        mid = self.mid(d4 + t_embed)

        # Decoder with additive skip connections
        u4 = self.up4(mid) + d3
        u3 = self.up3(u4) + d2
        u2 = self.up2(u3) + d1
        out = self.up1(u2)
        return out

In [5]:
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return a linearly spaced noise (beta) schedule.

    Args:
        timesteps: number of diffusion steps, i.e. the length of the result.
        beta_start: value of the first beta (default 1e-4).
        beta_end: value of the last beta (default 0.02).

    Returns:
        1-D tensor with ``timesteps`` values evenly spaced from ``beta_start``
        to ``beta_end``, both endpoints included.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

# Precomputed diffusion quantities, one entry per timestep, all kept on DEVICE.
betas = linear_beta_schedule(NUM_TIMESTEPS).to(device=DEVICE)
alphas = 1.0 - betas
# Running product alpha_0 * ... * alpha_t.
alphas_cumprod = alphas.cumprod(dim=0)
# Same sequence shifted right by one step, with 1.0 in front.
alphas_cumprod_prev = torch.cat([alphas_cumprod.new_ones(1), alphas_cumprod[:-1]])
# Coefficients of the clean image and of the noise in the forward process.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1 - alphas_cumprod).sqrt()

In [6]:
def noise_images(x0, t, noise=None):
    """
    Adds noise to the images at timestep t.

    Computes ``sqrt_alphas_cumprod[t] * x0 + sqrt_one_minus_alphas_cumprod[t] * noise``
    using the module-level schedule tensors.

    Args:
        x0: batch of clean inputs; the first dimension is the batch.
        t: integer index tensor with one timestep per sample, used to index the
            schedule tensors (so it must live on the same device as they do).
        noise: optional tensor of the same shape as ``x0`` to use instead of
            freshly drawn Gaussian noise (handy for reproducible checks).

    Returns:
        Tuple ``(x_noisy, noise)``: the noised batch and the noise that was used.
    """
    if noise is None:
        # randn_like already matches x0's device and dtype, so no extra
        # transfer is needed.
        noise = torch.randn_like(x0)

    # One coefficient per sample, broadcast over every non-batch dimension.
    coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
    sqrt_alphas_cumprod_t = sqrt_alphas_cumprod[t].view(coeff_shape)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)

    x_noisy = sqrt_alphas_cumprod_t * x0 + sqrt_one_minus_alphas_cumprod_t * noise
    return x_noisy, noise

In [7]:
# Network, optimiser and objective (the network regresses the injected noise).
model = SimpleUNet(channels=CHANNELS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.MSELoss()

for epoch in range(EPOCHS):
    model.train()
    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for x, _ in pbar:  # class labels are not used
        x = x.to(DEVICE)
        batch_size = x.size(0)

        # One uniformly random timestep per image in the batch.
        t = torch.randint(0, NUM_TIMESTEPS, (batch_size,)).to(DEVICE)

        # Forward diffusion: corrupt the clean images and keep the noise used.
        x_noisy, noise = noise_images(x, t)

        # The network is conditioned on the timestep rescaled to [0, 1).
        t_scaled = t.float() / NUM_TIMESTEPS
        noise_pred = model(x_noisy, t_scaled)

        # Standard optimisation step on the noise-prediction error.
        loss = criterion(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=loss.item())

    # Write the weights to disk after every epoch.
    torch.save(model.state_dict(), f"ddm_epoch_{epoch+1}.pth")

Epoch 1/100:   0%|          | 0/52 [00:00<?, ?it/s]

In [None]:
@torch.no_grad()
def sample(model, num_samples, image_size):
    """Generate images by running the reverse diffusion process from pure noise.

    Starts from Gaussian noise and walks the timesteps from
    ``NUM_TIMESTEPS - 1`` down to 0, using the module-level schedule tensors
    (``betas``, ``alphas``, ``sqrt_one_minus_alphas_cumprod``).

    Args:
        model: noise-prediction network called as ``model(x, t)`` where ``t``
            is the timestep divided by ``NUM_TIMESTEPS``.
        num_samples: number of images to generate.
        image_size: height and width of the generated (square) images.

    Returns:
        Tensor of shape ``[num_samples, CHANNELS, image_size, image_size]``
        with values in [0, 1].

    The model is switched to eval mode while sampling and returned to whatever
    mode it was in before the call.
    """
    was_training = model.training
    model.eval()
    try:
        # Allocate directly on the target device instead of copying from the CPU.
        x = torch.randn(num_samples, CHANNELS, image_size, image_size, device=DEVICE)
        for t in reversed(range(NUM_TIMESTEPS)):
            t_tensor = torch.full((num_samples,), t, dtype=torch.long, device=DEVICE)
            # Predict noise, conditioning on the timestep scaled by NUM_TIMESTEPS
            noise_pred = model(x, t_tensor.float() / NUM_TIMESTEPS)

            beta_t = betas[t]
            sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alphas_cumprod[t]
            sqrt_recip_alpha_t = 1.0 / torch.sqrt(alphas[t])

            # Compute posterior mean
            x_mean = sqrt_recip_alpha_t * (x - beta_t / sqrt_one_minus_alpha_cumprod_t * noise_pred)

            if t > 0:
                # Re-inject noise with variance beta_t on every step but the last.
                noise = torch.randn_like(x)
                sigma_t = torch.sqrt(beta_t)
                x = x_mean + sigma_t * noise
            else:
                x = x_mean
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        model.train(was_training)
    return (x.clamp(-1, 1) + 1) / 2  # Scale back to [0,1]

# Generate and save images
import torchvision.utils as vutils

# Draw 16 samples, tile them four per row, and write the grid to disk.
generated_images = sample(model, 16, IMAGE_SIZE)
grid = vutils.make_grid(tensor=generated_images, nrow=4)
vutils.save_image(tensor=grid, fp='generated_flowers.png')