In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local `models` module) are re-imported before each cell executes
# and edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn

from torch.nn import functional as F
import matplotlib.pyplot as plt

from torchvision.datasets import CIFAR10
from torch.utils import data
from torch.utils.data import DataLoader, Subset


import torch.optim as optim

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils


In [None]:
from models import UNet, SinusoidalPositionEmbeddings

In [None]:
# ToTensor() yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 per
# channel then maps them to [-1, 1], the range the sampling code assumes.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

# download=True is skipped when the archive is already present under ./data,
# and keeps a fresh checkout from failing (the test split below already
# downloads on demand).
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Keep only one CIFAR-10 class for training.
class_idx = 3
# Read the labels from `targets` instead of iterating the dataset: indexing
# the dataset decodes and transforms every image just to look at its label.
target_indices = [i for i, label in enumerate(trainset.targets) if label == class_idx]

# Create a Subset using these indices
# filtered_dataset = Subset(trainset, target_indices[:64])
filtered_dataset = Subset(trainset, target_indices)

trainloader = torch.utils.data.DataLoader(filtered_dataset, batch_size=batch_size,
                                          shuffle=True, num_workers=24)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


In [None]:
# Function to add noise
def add_noise(x0, t, noise, cumprod=None):
    """Apply the forward diffusion step to ``x0`` at timestep(s) ``t``.

    Returns ``sqrt(a[t]) * x0 + sqrt(1 - a[t]) * noise`` where ``a`` is the
    cumulative product of the alphas.

    Args:
        x0: Clean input tensor.
        t: Integer index or index tensor into the schedule. Three trailing
            singleton axes are appended to the looked-up coefficients so that
            a 1-D ``t`` of per-sample timesteps broadcasts over the remaining
            dimensions of ``x0``.
        noise: Noise tensor, broadcastable with ``x0``.
        cumprod: Optional 1-D tensor of cumulative alpha products. When
            omitted, the module-level ``alphas_cumprod`` is used, which must
            have been defined before the call.

    Returns:
        The noised tensor.
    """
    if cumprod is None:
        cumprod = alphas_cumprod
    # (...,) -> (..., 1, 1, 1) so the coefficients broadcast against images.
    alpha_bar = cumprod[t][..., None, None, None]
    return torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise

# def generate_and_save_samples(epoch, model, device, writer):
#     model.eval()
#     with torch.no_grad():
#         # Start with a batch of random noise (generate 10 sample images for tracking)
#         sample_noise = torch.randn(10, 3, 64, 64).to(device)  # Generate 10 sample images
        
#         # List to store images at regular intervals
#         intermediate_images = []
        
#         # Gradually reverse the diffusion process
#         for t in reversed(range(num_timesteps)):
#             t_tensor = torch.tensor([t] * sample_noise.size(0), device=device).long()
#             predicted_noise = model(sample_noise, t_tensor)
#             alpha_t = torch.sqrt(alphas_cumprod[t])
#             beta_t = torch.sqrt(1 - alphas_cumprod[t])
#             sample_noise = (sample_noise - beta_t * predicted_noise) / alpha_t

#             # Optional: Clip or scale the output at each step to avoid extreme values
#             sample_noise = sample_noise.clamp(-1, 1)
            
#             # Save the images every 100 steps
#             if t % 100 == 0 or t == 0:
#                 img = (sample_noise + 1) / 2  # Convert [-1, 1] to [0, 1]
#                 img = img.clamp(0, 1)
#                 intermediate_images.append(img)

#         # Stack the saved steps to form a tensor of shape: (num_intervals, 10, 3, 64, 64)
#         # Each entry in the list is a batch of 10 images, so `torch.stack` creates this structure
#         all_steps = torch.stack(intermediate_images, dim=0)  # Shape: (num_intervals, 10, 3, 64, 64)
        
#         # Rearrange to have a single batch where each row will represent the progression of one image
#         all_steps = all_steps.permute(1, 0, 2, 3, 4)  # Shape: (10, num_intervals, 3, 64, 64)
#         all_steps = all_steps.reshape(-1, 3, 64, 64)  # Shape: (10 * num_intervals, 3, 64, 64)

#         # Create a grid where each row corresponds to the progression of one image across intervals
#         grid = vutils.make_grid(all_steps, nrow=len(intermediate_images))
#         writer.add_image(f'Progression at epoch {epoch + 1}', grid, epoch)


import torchvision.utils as vutils

def generate_and_save_samples(epoch, model, device, writer=None):
    """Sample 10 images by reverse diffusion and save a progression grid.

    Starting from Gaussian noise, each step estimates the clean image from the
    predicted noise, clips that estimate to [-1, 1], and draws the next sample
    from the diffusion posterior given the current sample and the estimate.
    A snapshot is kept every 100 steps; the snapshots form one row per image.

    Args:
        epoch: Zero-based epoch index. Used as ``epoch + 1`` in the
            TensorBoard tag and the output filename, and as the TensorBoard
            global step.
        model: Noise-prediction network, called as ``model(x_t, t)``.
        device: Device on which sampling is performed.
        writer: Optional TensorBoard writer; when given, the grid is logged
            to it in addition to being written to disk.

    Relies on the module-level ``num_timesteps`` and ``alphas_cumprod``.
    Leaves ``model`` in eval mode.
    """
    import os  # local import keeps this function self-contained

    num_samples = 10
    output_dir = "generated_samples"

    model.eval()
    with torch.no_grad():
        # Start with a batch of random noise; CIFAR-10 images are 32x32 RGB.
        x_t = torch.randn(num_samples, 3, 32, 32, device=device)

        # Snapshots of the batch at regular intervals.
        intermediate_images = []

        # Gradually reverse the diffusion process.
        for t in reversed(range(num_timesteps)):
            t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
            predicted_noise = model(x_t, t_tensor)

            alpha_bar_t = alphas_cumprod[t]
            # The cumulative product before the first step is 1 by definition.
            alpha_bar_prev = alphas_cumprod[t - 1] if t > 0 else torch.ones_like(alpha_bar_t)
            # Recover the per-step alpha/beta from the cumulative products.
            alpha_t = alpha_bar_t / alpha_bar_prev
            beta_t = 1 - alpha_t

            # Estimate the clean image and clip it to the data range; clipping
            # the estimate (rather than x_t itself) keeps the noisy sample's
            # variance intact at large t.
            x0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * predicted_noise) / torch.sqrt(alpha_bar_t)
            x0_pred = x0_pred.clamp(-1, 1)

            # Mean of the posterior q(x_{t-1} | x_t, x0).
            coef_x0 = torch.sqrt(alpha_bar_prev) * beta_t / (1 - alpha_bar_t)
            coef_xt = torch.sqrt(alpha_t) * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
            posterior_mean = coef_x0 * x0_pred + coef_xt * x_t

            if t > 0:
                posterior_var = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
                x_t = posterior_mean + torch.sqrt(posterior_var) * torch.randn_like(x_t)
            else:
                # No noise is added on the final step.
                x_t = posterior_mean

            # Save the images every 100 steps (this includes the final t == 0).
            if t % 100 == 0:
                img = ((x_t + 1) / 2).clamp(0, 1)  # Convert [-1, 1] to [0, 1]
                intermediate_images.append(img)

        # (num_intervals, 10, 3, 32, 32) -> (10 * num_intervals, 3, 32, 32),
        # ordered so that consecutive entries follow one image through time.
        all_steps = torch.stack(intermediate_images, dim=0)
        all_steps = all_steps.permute(1, 0, 2, 3, 4).reshape(-1, 3, 32, 32)

        # Each grid row is the progression of one image across intervals.
        grid = vutils.make_grid(all_steps, nrow=len(intermediate_images))

        # Save to TensorBoard
        if writer:
            writer.add_image(f'Progression at epoch {epoch + 1}', grid, epoch)

        # Always save to disk as well; create the folder so a fresh checkout
        # does not fail on the first call.
        os.makedirs(output_dir, exist_ok=True)
        vutils.save_image(grid, f"generated_samples/progression_epoch_{epoch + 1}.png")





In [None]:
# Build everything the training loop needs: model, device, diffusion
# schedule, optimizer, loss, and a fresh TensorBoard run directory.

# Modular UNet, placed on the first GPU when one is available.
model = UNet(in_channels=3, out_channels=3, time_emb_dim=256)
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
model.to(device)

# Hyperparameters
epochs = 5000
batch_size = 64
learning_rate = 1e-4
beta_start = 1e-4
beta_end = 0.02
num_timesteps = 1000

# Linear beta schedule and the derived alpha terms, all kept on `device`.
betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)

# Adam on all model parameters; MSE between predicted and true noise.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.MSELoss()

# TensorBoard writer: use the first logs/ddpm_training/run_<n> that does not
# exist yet, so each execution of this cell logs to its own directory.
import os

base_log_dir = "logs/ddpm_training"
run_number = 0
log_dir = os.path.join(base_log_dir, f"run_{run_number}")
while os.path.exists(log_dir):
    run_number += 1
    log_dir = os.path.join(base_log_dir, f"run_{run_number}")

writer = SummaryWriter(log_dir=log_dir)



# Training loop
# Checkpoints and sample grids are written into these folders; create them up
# front so the first save does not fail after a full epoch of work.
os.makedirs("weights", exist_ok=True)
os.makedirs("generated_samples", exist_ok=True)

try:
    for epoch in range(epochs):
        model.train()
        pbar = tqdm(trainloader, desc=f"Epoch {epoch + 1}/{epochs}")

        running_loss = 0.0  # Track loss for the epoch

        for images, _ in pbar:
            images = images.to(device)
            # Local name on purpose: the final batch may be smaller, and
            # reusing `batch_size` would overwrite the hyperparameter.
            current_batch_size = images.size(0)

            # Sample a random timestep for each image in the batch
            t = torch.randint(0, num_timesteps, (current_batch_size,), device=device).long()

            # Sample random noise (same shape, dtype and device as `images`)
            noise = torch.randn_like(images)

            # Get the noisy image for time step t
            noisy_images = add_noise(images, t, noise)

            # Forward pass through the model
            predicted_noise = model(noisy_images, t)

            # Compute loss
            loss = loss_fn(predicted_noise, noise)

            # Backpropagation, with the gradient norm clipped to 1.0
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Update the progress bar with the loss
            batch_loss = loss.item()
            running_loss += batch_loss
            pbar.set_postfix({"loss": batch_loss})

        # Log the average loss to TensorBoard (guard against an empty loader)
        avg_loss = running_loss / max(len(trainloader), 1)
        writer.add_scalar("Loss/train", avg_loss, epoch)

        # Generate and log sample images to TensorBoard
        if epoch % 10 == 0:
            generate_and_save_samples(epoch, model, device, writer)

        # Save model checkpoints every 200 epochs (including the first)
        if epoch % 200 == 0:
            torch.save(model.state_dict(), f"weights/ddpm_{run_number}_epoch_{epoch + 1}.pth")
finally:
    # Close the TensorBoard writer even if training is interrupted or fails,
    # so buffered events are flushed to disk.
    writer.close()

In [None]:
# Debug cell: rebuild the noise schedule and push a single training batch
# through the forward (noising) process so the tensors can be inspected in
# the cells below.
epochs = 5000
batch_size = 64
learning_rate = 1e-4
beta_start = 1e-4
beta_end = 0.02
num_timesteps = 1000

# Define the beta schedule (linear schedule in this example)
betas = torch.linspace(beta_start, beta_end, num_timesteps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Pull exactly one batch instead of entering the epoch/batch loops and
# breaking out of both straight away.
images, _ = next(iter(trainloader))
images = images.to(device)

# Sample a random timestep for each image in the batch. The batch's own size
# is used directly so the `batch_size` hyperparameter is not overwritten.
t = torch.randint(0, num_timesteps, (images.size(0),), device=device).long()

# Sample random noise (same shape, dtype and device as `images`)
noise = torch.randn_like(images)

# Get the noisy image for time step t
noisy_images = add_noise(images, t, noise)
        

In [None]:
# Dimensions of the first entry of the most recently sampled `noise` tensor.
noise[0].size()

In [None]:
# Raw values of index 1 along the second axis of the first `noise` entry
# (presumably the second colour channel of the first sample; verify).
noise[0, 1]

In [None]:
# NOTE(review): `combined` is only assigned in the forward-process
# visualisation cell further down, so on a fresh kernel this raises NameError
# unless that cell has been run first.
combined.shape

In [None]:
# NOTE(review): img1 / img2 / img3 are not assigned anywhere in the notebook
# as shown; this looks like a leftover from a removed cell and would raise
# NameError on a fresh run. Confirm and remove or restore the definitions.
img1.shape, img2.shape, img3.shape

In [None]:
# NOTE(review): repeats the earlier `combined.shape` check; `combined` is only
# assigned in the visualisation cell below, so this depends on that cell
# having been run beforehand.
combined.shape

In [None]:
# Visualise the forward process: for each of (up to) the first 10 images in
# the current batch, show the clean image followed by a noised version of it
# at every timestep of the schedule.
num_examples = min(10, images.size(0))
for index in range(num_examples):
    clean = images[index]
    img_list = [clean]

    for i in range(len(alphas_cumprod)):
        # Noise only the image being plotted rather than the whole batch;
        # the scalar timestep broadcasts over the image dimensions.
        step_noise = torch.randn_like(clean)
        img_list.append(add_noise(clean, i, step_noise))
    combined = torch.stack(img_list)

    # Map [-1, 1] back to [0, 1] for display and move to host memory:
    # `.numpy()` cannot be called on a CUDA tensor.
    display_batch = ((combined + 1) / 2).clamp(0, 1).cpu()

    # Use make_grid to create a single frame
    grid = vutils.make_grid(display_batch, nrow=25, padding=1)
    plt.figure(figsize=(30,60))
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.show()

In [None]:
from diffusers import UNet2DConditionModel

# Instantiate the UNet from the "unet" subfolder of the
# CompVis/stable-diffusion-v1-4 checkpoint.
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="unet",
)

# Dump the full module tree so the attention components can be located.
print(unet)
