In [1]:
!pip install datasets


Collecting datasets
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.6.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.5/491.5 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py311-none-any.whl (143 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading 

In [2]:
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
image_size = 128  # side length in pixels; images are resized to (image_size, image_size)
batch_size = 32  # number of images per batch handed out by the DataLoader

In [4]:
# Pikachu dataset

# Please upload your kaggle.json file
from google.colab import files
files.upload()

! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

# Download dataset and unzip
# !kaggle datasets list -s pikachu
!kaggle datasets download -d hal0samuel/pikachu-classification-dataset
!unzip -q pikachu-classification-dataset.zip
!rm -r pikachu_dataset/train/not_pikachu    # remove non-pikachu image

# Load dataset
from torch.utils.data import Dataset  # only needed by the wrapper defined in this cell

dataset_name = "./pikachu_dataset"
images_raw = load_dataset(dataset_name, split="train")
images_rgb = [image.convert("RGB") for image in images_raw["image"]]

# Preprocess image
# Full single-image pipeline, kept under its original name for any code that
# wants to preprocess one PIL image in a single call.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Deterministic steps: executed once per image and cached as tensors.
base_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Random step: has to be re-drawn on every access. Applying it while building
# the list (as a one-off) would freeze a single flip decision per image for the
# whole training run, i.e. no augmentation at all after the first epoch.
random_flip = transforms.RandomHorizontalFlip()


class FlipAugmentedImages(Dataset):
    """Map-style dataset over pre-processed image tensors with a per-access random flip.

    Args:
        tensors: list of already resized/normalized image tensors.
        augment: callable applied to a tensor each time it is fetched.
    """

    def __init__(self, tensors, augment):
        self.tensors = tensors
        self.augment = augment

    def __len__(self):
        return len(self.tensors)

    def __getitem__(self, index):
        # A horizontal flip commutes with the per-channel normalization, so
        # flipping the cached tensor matches flipping the PIL image first.
        return self.augment(self.tensors[index])


# Apply preprocess
dataset = FlipAugmentedImages([base_transform(image) for image in images_rgb], random_flip)

# Generate train data
train_data = DataLoader(dataset, batch_size=batch_size, shuffle=True)

Saving kaggle.json to kaggle.json
Dataset URL: https://www.kaggle.com/datasets/hal0samuel/pikachu-classification-dataset
License(s): unknown
Downloading pikachu-classification-dataset.zip to /content
100% 64.9M/64.9M [00:02<00:00, 33.2MB/s]
100% 64.9M/64.9M [00:02<00:00, 22.9MB/s]


Resolving data files:   0%|          | 0/387 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/258 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/258 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/387 [00:00<?, ?files/s]

Downloading data:   0%|          | 0/258 [00:00<?, ?files/s]

Downloading data:   0%|          | 0/258 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [5]:
# Implementation of Latent Diffusion Model for Limited Dataset
# Based on the project proposal for ECE285, UCSD

import os
import numpy as np
import math
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.utils import save_image, make_grid

from datasets import load_dataset
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Hyper-parameters shared by the model and training code below
channels = 3  # colour channels of the input images
num_epochs = 80  # epoch budget taken from the project proposal
learning_rate = 1e-4  # learning rate given to the Adam optimizers
latent_dim = 32  # channel count of the VAE latent tensor

# Data preparation (building `train_data`) is done in the earlier cells;
# the model implementation starts here.

# 1. VAE Implementation (For encoding images to latent space)
class ResidualBlock(nn.Module):
    """Two 3x3 conv + GroupNorm layers wrapped by a residual connection.

    The skip path is the identity when the channel count is unchanged and a
    1x1 convolution otherwise. Spatial size is preserved (padding=1).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned feature map; GroupNorm uses 8
            groups, so this must be divisible by 8.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(8, out_channels)
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act = nn.SiLU()

        # Only project the skip path when the channel counts disagree.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Return act(norm2(conv2(act(norm1(conv1(x))))) + shortcut(x))."""
        skip = self.shortcut(x)

        out = self.act(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        return self.act(out + skip)

class Encoder(nn.Module):
    """Convolutional VAE encoder producing a spatial mean / log-variance map.

    The input goes through a 3x3 stem, three down-sampling stages (each one is
    two residual blocks followed by a stride-2 convolution, so height and width
    are halved three times), a two-block bottleneck, and finally two 1x1
    convolutions that emit `mu` and `log_var` with `latent_dim` channels each.

    Args:
        in_channels: channels of the input image.
        latent_dim: channels of the latent mean / log-variance maps.
    """

    def __init__(self, in_channels=3, latent_dim=32):
        super().__init__()

        def down_stage(width, next_width):
            # Two residual blocks at `width`, then a stride-2 conv to `next_width`.
            return nn.Sequential(
                ResidualBlock(width, width),
                ResidualBlock(width, width),
                nn.Conv2d(width, next_width, kernel_size=3, stride=2, padding=1),
            )

        self.init_conv = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)

        # Each stage halves the spatial resolution and doubles the channels.
        self.down1 = down_stage(64, 128)
        self.down2 = down_stage(128, 256)
        self.down3 = down_stage(256, 512)

        self.bottleneck = nn.Sequential(
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
        )

        # 1x1 heads for the parameters of the latent Gaussian.
        self.fc_mu = nn.Conv2d(512, latent_dim, kernel_size=1)
        self.fc_var = nn.Conv2d(512, latent_dim, kernel_size=1)

    def forward(self, x):
        """Return the tuple (mu, log_var) for input batch `x`."""
        features = self.init_conv(x)
        for stage in (self.down1, self.down2, self.down3, self.bottleneck):
            features = stage(features)

        return self.fc_mu(features), self.fc_var(features)

    def reparameterize(self, mu, log_var):
        """Draw z = mu + eps * exp(0.5 * log_var) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

class Decoder(nn.Module):
    """Convolutional VAE decoder mapping a latent map back to image space.

    A 1x1 convolution lifts the latent to 512 channels, a two-block bottleneck
    follows, then three up-sampling stages (nearest-neighbour x2, 3x3 conv, two
    residual blocks) double height and width three times. The result is passed
    through a 3x3 convolution and `tanh`, so outputs lie in (-1, 1).

    Args:
        latent_dim: channels of the latent input.
        out_channels: channels of the produced image.
    """

    def __init__(self, latent_dim=32, out_channels=3):
        super().__init__()

        def up_stage(width, next_width):
            # x2 nearest up-sampling, channel reduction, then two residual blocks.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(width, next_width, kernel_size=3, padding=1),
                ResidualBlock(next_width, next_width),
                ResidualBlock(next_width, next_width),
            )

        # Lift the latent channels to the bottleneck width.
        self.init_conv = nn.Conv2d(latent_dim, 512, kernel_size=1)

        self.bottleneck = nn.Sequential(
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
        )

        self.up1 = up_stage(512, 256)
        self.up2 = up_stage(256, 128)
        self.up3 = up_stage(128, 64)

        # Projection to image channels.
        self.final = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode latent batch `z` into an image batch squashed by tanh."""
        features = self.init_conv(z)
        for stage in (self.bottleneck, self.up1, self.up2, self.up3):
            features = stage(features)

        return torch.tanh(self.final(features))

class VAE(nn.Module):
    """Variational auto-encoder built from `Encoder` and `Decoder`.

    Args:
        in_channels: channels of the input (and reconstructed) image.
        latent_dim: channels of the latent map; also stored as `self.latent_dim`.
    """

    def __init__(self, in_channels=3, latent_dim=32):
        super().__init__()
        self.encoder = Encoder(in_channels, latent_dim)
        self.decoder = Decoder(latent_dim, in_channels)
        self.latent_dim = latent_dim

    def forward(self, x):
        """Return (reconstruction, mu, log_var) for input batch `x`."""
        mu, log_var = self.encoder(x)
        latent = self.encoder.reparameterize(mu, log_var)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, log_var

    def encode(self, x):
        """Return a latent sampled from the encoder's Gaussian for `x`."""
        return self.encoder.reparameterize(*self.encoder(x))

    def decode(self, z):
        """Map latent batch `z` back to image space."""
        return self.decoder(z)

# 2. U-Net Implementation for Diffusion Model
class SinusoidalPositionEmbeddings(nn.Module):
    """Sinusoidal embedding of a batch of scalar timesteps.

    For `half = dim // 2` the frequencies are `exp(-k * log(10000) / (half - 1))`
    for `k = 0 .. half - 1`; the output concatenates the sines and then the
    cosines of `time * frequency`.

    Args:
        dim: width of the returned embedding.

    Returns (from `forward`):
        Tensor of shape `(len(time), dim)`. For odd `dim` the last column is
        zero padding, so the width always equals `dim`.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3 (half_dim == 1);
        # for every larger dim the divisor is unchanged.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=time.device) * -log_step)
        angles = time[:, None] * frequencies[None, :]
        embeddings = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos halves only cover dim - 1 columns; pad so layers that
            # expect exactly `dim` input features keep working.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class Block(nn.Module):
    """Time-conditioned conv block that ends with a x2 down- or up-sampling.

    Args:
        in_ch: input channels (the first conv expects `2 * in_ch` when `up=True`).
        out_ch: output channels.
        time_emb_dim: width of the time embedding passed to `forward`.
        up: when True the final transform is a transposed conv (x2 up-sampling),
            otherwise a stride-2 conv (x2 down-sampling).
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        # Up blocks take twice the channels on their first conv.
        first_in = 2 * in_ch if up else in_ch
        self.conv1 = nn.Conv2d(first_in, out_ch, 3, padding=1)
        resampler = nn.ConvTranspose2d if up else nn.Conv2d
        self.transform = resampler(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        """Apply conv -> add time embedding -> conv -> resample."""
        hidden = self.bnorm1(self.relu(self.conv1(x)))

        # Project the time embedding to out_ch and broadcast it over H and W.
        time_bias = self.relu(self.time_mlp(t))[..., None, None]
        hidden = hidden + time_bias

        hidden = self.bnorm2(self.relu(self.conv2(hidden)))
        return self.transform(hidden)

class SimpleUNet(nn.Module):
    """Small time-conditioned encoder/decoder network without skip connections.

    Layout: stem -> two stride-2 down stages -> three bottleneck convs -> two
    transposed-conv up stages -> 3x3 output conv back to `in_channels`. A
    projected time embedding is added after `down1`, `down2` and `bot1`.

    Args:
        in_channels: channels of the input and of the output.
        hidden_dim: base width; every GroupNorm uses 8 groups, so it must be
            divisible by 8.
    """

    def __init__(self, in_channels=32, hidden_dim=64):
        super().__init__()

        # Channel widths used throughout the network.
        w1, w2, w4, w8 = hidden_dim, hidden_dim * 2, hidden_dim * 4, hidden_dim * 8
        time_dim = w4

        def conv_stage(c_in, c_out, stride=1):
            # 3x3 conv -> GroupNorm(8) -> SiLU, as a list so stages can be chained.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1),
                nn.GroupNorm(8, c_out),
                nn.SiLU(),
            ]

        def upsample_stage(c_in, c_out):
            # 4x4 transposed conv with stride 2 -> GroupNorm(8) -> SiLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(8, c_out),
                nn.SiLU(),
            ]

        def time_projection(c_out):
            # Maps the shared time embedding to the width of one feature map.
            return nn.Sequential(nn.SiLU(), nn.Linear(time_dim, c_out))

        # Sinusoidal timestep features followed by a two-layer MLP.
        self.time_embedding = nn.Sequential(
            SinusoidalPositionEmbeddings(w1),
            nn.Linear(w1, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        # Encoder
        self.inc = nn.Sequential(*conv_stage(in_channels, w1))
        self.down1 = nn.Sequential(*conv_stage(w1, w2, stride=2), *conv_stage(w2, w2))
        self.down2 = nn.Sequential(*conv_stage(w2, w4, stride=2), *conv_stage(w4, w4))

        # Bottleneck
        self.bot1 = nn.Sequential(*conv_stage(w4, w8))
        self.bot2 = nn.Sequential(*conv_stage(w8, w8))
        self.bot3 = nn.Sequential(*conv_stage(w8, w4))

        # Per-stage projections of the time embedding
        self.time_emb1 = time_projection(w2)
        self.time_emb2 = time_projection(w4)
        self.time_emb3 = time_projection(w8)

        # Decoder
        self.up1 = nn.Sequential(*upsample_stage(w4, w2), *conv_stage(w2, w2))
        self.up2 = nn.Sequential(*upsample_stage(w2, w1), *conv_stage(w1, w1))

        self.outc = nn.Conv2d(w1, in_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Return a tensor shaped like `x` for noisy input `x` at timesteps `t`."""
        t_emb = self.time_embedding(t)

        def add_time(features, projector):
            # Broadcast the projected embedding over the two spatial axes.
            return features + projector(t_emb)[:, :, None, None]

        hidden = self.inc(x)
        hidden = add_time(self.down1(hidden), self.time_emb1)
        hidden = add_time(self.down2(hidden), self.time_emb2)
        hidden = add_time(self.bot1(hidden), self.time_emb3)
        hidden = self.bot3(self.bot2(hidden))

        # Decoder (no skip connections to avoid dimension issues)
        hidden = self.up2(self.up1(hidden))

        return self.outc(hidden)

# 3. Diffusion Model implementation
class LatentDiffusion(nn.Module):
    """DDPM-style diffusion model that works in the latent space of a VAE.

    The VAE encodes images to latents and decodes sampled latents; `unet` is
    the noise predictor called as `unet(x_noisy, t)`. A linear beta schedule
    with `n_steps` entries is stored in buffers.

    Args:
        vae: module exposing `encode(x)` and `decode(z)`; its `latent_dim`
            attribute (32 if absent) sets the number of latent channels.
        unet: noise-prediction network. When omitted, a `SimpleUNet` sized to
            the latent channels is created.
        n_steps: number of diffusion timesteps.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
    """

    # Latent side length used by sample() before any batch has been seen.
    _DEFAULT_LATENT_SIZE = 8

    def __init__(self, vae, unet=None, n_steps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.vae = vae
        # Read the channel count from the VAE instead of a module-level global.
        self.latent_channels = getattr(vae, "latent_dim", 32)

        if unet is None:
            # The previous default referenced a class that is not defined in
            # this file and therefore always raised NameError.
            unet = SimpleUNet(in_channels=self.latent_channels)
        self.unet = unet

        self.n_steps = n_steps
        # (H, W) of the latents seen by the most recent forward() call; lets
        # sample() generate at the resolution the model was trained on.
        self._latent_hw = None

        # Define beta schedule
        self.register_buffer('betas', torch.linspace(beta_start, beta_end, n_steps))
        self.register_buffer('alphas', 1. - self.betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.register_buffer('sqrt_recip_alphas', torch.sqrt(1.0 / self.alphas))
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - self.alphas_cumprod))

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        # (kept so existing checkpoints load; the sampler below uses beta_t as variance)
        self.register_buffer('posterior_variance', self.betas * (1. - alphas_cumprod_prev) / (1. - self.alphas_cumprod))

    @property
    def device(self):
        """Device the schedule buffers currently live on (follows `.to(...)`)."""
        return self.betas.device

    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion: return x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * noise.

        Args:
            x_start: clean latents, one row per entry of `t`.
            t: long tensor of timestep indices.
            noise: optional noise tensor shaped like `x_start`; drawn from
                N(0, I) when omitted.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Reshape for proper broadcasting over (C, H, W)
        signal_scale = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        return signal_scale * x_start + noise_scale * noise

    def p_sample(self, model_output, x, t, betas):
        """One reverse step: sample x_{t-1} from x_t and the predicted noise.

        Args:
            model_output: noise predicted by the network for `x` at `t`.
            x: current noisy latents x_t.
            t: timestep index; a Python int shared by the batch or a long
                tensor with one entry per sample.
            betas: beta schedule to index with `t`.

        Returns:
            The mean of the reverse step plus sqrt(beta_t) noise; no noise is
            added for entries with t == 0.
        """
        t_idx = torch.as_tensor(t, dtype=torch.long, device=x.device).reshape(-1)

        # Reshape for proper broadcasting
        alpha = self.alphas[t_idx].view(-1, 1, 1, 1)
        alpha_bar = self.alphas_cumprod[t_idx].view(-1, 1, 1, 1)
        beta = betas[t_idx].view(-1, 1, 1, 1)

        # Formula for sampling
        eps_coef = beta / torch.sqrt(1 - alpha_bar)
        mean = (1 / torch.sqrt(alpha)) * (x - eps_coef * model_output)

        # Per-sample mask instead of `if t > 0`, which only works for scalar t.
        add_noise = (t_idx > 0).to(x.dtype).view(-1, 1, 1, 1)
        return mean + add_noise * torch.sqrt(beta) * torch.randn_like(x)

    def p_sample_loop(self, shape):
        """Run the full reverse process from pure noise and return the latents.

        Args:
            shape: (batch, channels, height, width) of the latents to generate.
        """
        device = self.betas.device
        b = shape[0]

        # Start from pure noise
        img = torch.randn(shape, device=device)

        with torch.no_grad():
            for i in tqdm(reversed(range(0, self.n_steps)), desc='Sampling', total=self.n_steps):
                t = torch.full((b,), i, device=device, dtype=torch.long)
                pred_noise = self.unet(img, t)
                img = self.p_sample(pred_noise, img, i, self.betas)

        return img

    def forward(self, x, encode=True):
        """Training step: noise the latents at random timesteps and predict the noise.

        Args:
            x: image batch, or latents when `encode=False`.
            encode: when True, `x` is first encoded by the VAE without gradients.

        Returns:
            Tuple (noise, noise_pred): the noise that was added and the
            network's prediction of it.
        """
        if encode:
            with torch.no_grad():
                x = self.vae.encode(x)

        # Remember the latent resolution so sample() can default to it.
        self._latent_hw = tuple(x.shape[-2:])

        batch_size = x.shape[0]
        device = x.device

        # Sample timesteps uniformly
        t = torch.randint(0, self.n_steps, (batch_size,), device=device).long()

        # Sample noise
        noise = torch.randn_like(x)

        # Apply forward diffusion
        x_noisy = self.q_sample(x, t, noise=noise)

        # Predict noise
        noise_pred = self.unet(x_noisy, t)

        return noise, noise_pred

    def _latent_shape(self, batch_size, img_size):
        """Resolve the (B, C, H, W) latent shape requested from sample()."""
        if img_size is None:
            if self._latent_hw is not None:
                height, width = self._latent_hw
            else:
                height = width = self._DEFAULT_LATENT_SIZE
        elif isinstance(img_size, int):
            height = width = img_size
        else:
            height, width = img_size
        return (batch_size, self.latent_channels, height, width)

    def sample(self, batch_size=1, img_size=None):
        """Generate decoded samples.

        Args:
            batch_size: number of samples to draw.
            img_size: latent side length (int) or (height, width). When
                omitted, the latent size seen by the last `forward()` call is
                used, falling back to 8 if the model has not processed a
                batch yet.

        Returns:
            The VAE decoding of the sampled latents.
        """
        shape = self._latent_shape(batch_size, img_size)
        latent_samples = self.p_sample_loop(shape)

        # Decode the latent samples
        with torch.no_grad():
            samples = self.vae.decode(latent_samples)

        return samples

# VAE Loss Function
def vae_loss_function(recon_x, x, mu, log_var, kld_weight=0.00025):
    """Weighted sum of reconstruction MSE and KL divergence for the VAE.

    Args:
        recon_x: reconstructed batch.
        x: target batch.
        mu: latent means, 4-D (batch, C, H, W).
        log_var: latent log-variances, same shape as `mu`.
        kld_weight: multiplier applied to the KL term.

    Returns:
        Tuple (loss, recon_loss, kld_loss) of scalar tensors, where
        loss = recon_loss + kld_weight * kld_loss.
    """
    # Mean squared error over every element of the batch.
    recon_loss = F.mse_loss(recon_x, x, reduction='mean')

    # KL terms are averaged over (C, H, W) per sample, then over the batch.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    per_sample = torch.mean(kl_terms, dim=[1, 2, 3])
    kld_loss = -0.5 * torch.mean(per_sample)

    return recon_loss + kld_weight * kld_loss, recon_loss, kld_loss

# Diffusion Loss Function (simple MSE loss between predicted noise and actual noise)
def diffusion_loss_function(noise, noise_pred):
    """Mean squared error between the predicted noise and the true noise.

    Args:
        noise: noise that was added in the forward diffusion step.
        noise_pred: the network's estimate of that noise.

    Returns:
        Scalar tensor averaged over every element.
    """
    return F.mse_loss(input=noise_pred, target=noise, reduction='mean')

# Training functions
def train_vae(vae, dataloader, optimizer, num_epochs=30, device=device, save_dir='vae_checkpoints'):
    """Train the VAE and periodically write checkpoints and reconstructions.

    Args:
        vae: model returning (reconstruction, mu, log_var) when called.
        dataloader: iterable yielding image batches (tensors).
        optimizer: optimizer over the VAE parameters.
        num_epochs: number of passes over `dataloader`.
        device: device each batch is moved to.
        save_dir: directory for checkpoints and reconstruction grids; created
            if missing.

    Returns:
        The trained `vae`.
    """
    os.makedirs(save_dir, exist_ok=True)

    vae.train()
    for epoch in range(num_epochs):
        # Per-batch values (each one is already a batch mean).
        losses, recon_losses, kld_losses = [], [], []

        progress = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        for batch in progress:
            optimizer.zero_grad()

            x = batch.to(device)
            recon_x, mu, log_var = vae(x)
            loss, recon_loss, kld_loss = vae_loss_function(recon_x, x, mu, log_var)

            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            recon_losses.append(recon_loss.item())
            kld_losses.append(kld_loss.item())

        # Epoch averages across batches
        num_batches = len(losses)
        avg_loss = sum(losses) / num_batches
        avg_recon_loss = sum(recon_losses) / num_batches
        avg_kld_loss = sum(kld_losses) / num_batches

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Recon Loss: {avg_recon_loss:.4f}, KLD Loss: {avg_kld_loss:.4f}')

        every_fifth = (epoch + 1) % 5 == 0
        final_epoch = epoch == num_epochs - 1

        # Checkpoint every 5 epochs and always after the last one
        if every_fifth or final_epoch:
            torch.save(vae.state_dict(), f'{save_dir}/vae_epoch_{epoch+1}.pt')

        # Reconstruction grid every 5 epochs
        if every_fifth:
            generate_reconstructions(vae, dataloader, epoch, save_dir)

    return vae

def train_diffusion(ldm, dataloader, optimizer, num_epochs=50, device=device, save_dir='ldm_checkpoints'):
    """Train the latent diffusion model and periodically save checkpoints and samples.

    Args:
        ldm: model returning (noise, noise_pred) when called on an image batch.
        dataloader: iterable yielding image batches (tensors).
        optimizer: optimizer over the parameters to be trained.
        num_epochs: number of passes over `dataloader`.
        device: device each batch is moved to.
        save_dir: directory for checkpoints and sample grids; created if missing.

    Returns:
        The trained `ldm`.
    """
    os.makedirs(save_dir, exist_ok=True)

    ldm.train()
    for epoch in range(num_epochs):
        # Per-batch losses (each one is already a batch mean).
        losses = []

        progress = tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
        for batch in progress:
            optimizer.zero_grad()

            noise, noise_pred = ldm(batch.to(device))
            loss = diffusion_loss_function(noise, noise_pred)

            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        # Epoch average across batches
        avg_loss = sum(losses) / len(losses)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}')

        every_tenth = (epoch + 1) % 10 == 0
        final_epoch = epoch == num_epochs - 1

        # Checkpoint every 10 epochs and always after the last one
        if every_tenth or final_epoch:
            torch.save(ldm.state_dict(), f'{save_dir}/ldm_epoch_{epoch+1}.pt')

        # Sample grid every 10 epochs
        if every_tenth:
            generate_samples(ldm, epoch, save_dir)

    return ldm

def generate_reconstructions(vae, dataloader, epoch, save_dir):
    """Save a grid comparing up to 8 inputs (top row) with their reconstructions.

    Args:
        vae: model returning (reconstruction, mu, log_var) when called.
        dataloader: iterable yielding image batches; its first batch is used.
        epoch: zero-based epoch index; the file name uses `epoch + 1`.
        save_dir: output directory (created if missing).

    The model is evaluated in eval mode without gradients and is returned to
    whatever train/eval mode it was in when the function was called.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Follow the model's own device rather than the module-level default, so
    # this also works when training was started on another device.
    first_param = next(vae.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    was_training = vae.training
    vae.eval()
    try:
        with torch.no_grad():
            # Get a batch of data
            x = next(iter(dataloader)).to(target_device)

            # Generate reconstructions
            recon_x, _, _ = vae(x)

            # Create a grid of images: originals first, reconstructions second
            n = min(8, x.size(0))
            comparison = torch.cat([x[:n], recon_x[:n]])

            # Save grid
            save_image(comparison.cpu(), f'{save_dir}/reconstruction_epoch_{epoch+1}.png', nrow=n, normalize=True)
    finally:
        # Restore the caller's mode instead of unconditionally enabling training.
        vae.train(was_training)

def generate_samples(ldm, epoch, save_dir):
    """Draw 8 samples from the diffusion model and save them as a 4-per-row grid.

    Args:
        ldm: model exposing `sample(batch_size=...)`.
        epoch: zero-based epoch index; the file name uses `epoch + 1`.
        save_dir: output directory (created if missing).

    The model is put in eval mode for sampling and is returned to whatever
    train/eval mode it was in when the function was called.
    """
    os.makedirs(save_dir, exist_ok=True)

    was_training = ldm.training
    ldm.eval()
    try:
        with torch.no_grad():
            # Generate samples
            samples = ldm.sample(batch_size=8)

            # Save images
            save_image(samples.cpu(), f'{save_dir}/samples_epoch_{epoch+1}.png', nrow=4, normalize=True)
    finally:
        # Restore the caller's mode instead of unconditionally enabling training.
        ldm.train(was_training)

# Main execution
# Trains the VAE first, then the latent diffusion model on top of the frozen VAE
def main():
    """Train the VAE, then a latent diffusion model on the frozen VAE, and save samples.

    Reads the notebook-level settings `channels`, `latent_dim`, `learning_rate`,
    `image_size`, `device` and the `train_data` loader. Checkpoints and
    intermediate images are written by the training helpers; the final grid
    goes to `final_samples/final_samples.png`.
    """
    # Number of diffusion steps; shared by the diffusion schedule and the
    # timestep normalization inside the noise predictor so they cannot drift apart.
    n_steps = 1000

    # Initialize models
    vae = VAE(in_channels=channels, latent_dim=latent_dim).to(device)

    # First, train the VAE
    print("Training VAE...")
    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)
    vae = train_vae(vae, train_data, vae_optimizer, num_epochs=30, device=device)  # Reduced epochs for testing

    # Then, freeze VAE and train the diffusion model
    print("Training Diffusion Model...")
    for param in vae.parameters():
        param.requires_grad = False

    # Define a simpler noise predictor network
    class NoisePredictor(nn.Module):
        """Plain conv stack that receives the normalized timestep as an extra channel."""

        def __init__(self, latent_dim=32, n_steps=1000):
            super().__init__()
            # Used to scale integer timesteps into [0, 1)
            self.n_steps = n_steps
            # Simple network with time embedding as a channel
            self.net = nn.Sequential(
                nn.Conv2d(latent_dim + 1, 64, 3, padding=1),
                nn.GroupNorm(8, 64),
                nn.SiLU(),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.GroupNorm(8, 128),
                nn.SiLU(),
                nn.Conv2d(128, 128, 3, padding=1),
                nn.GroupNorm(8, 128),
                nn.SiLU(),
                nn.Conv2d(128, 64, 3, padding=1),
                nn.GroupNorm(8, 64),
                nn.SiLU(),
                nn.Conv2d(64, latent_dim, 3, padding=1)
            )

        def forward(self, x, t):
            # Normalize timestep by the diffusion length and add as channel
            t_emb = t.float() / float(self.n_steps)
            t_emb = t_emb.view(-1, 1, 1, 1).expand(-1, 1, x.shape[2], x.shape[3])
            x_input = torch.cat([x, t_emb], dim=1)  # Add timestep as channel
            return self.net(x_input)

    # Create the noise predictor
    noise_pred = NoisePredictor(latent_dim=latent_dim, n_steps=n_steps).to(device)

    # Create latent diffusion model with the trained VAE and noise predictor
    ldm = LatentDiffusion(vae, noise_pred, n_steps=n_steps).to(device)
    ldm_optimizer = torch.optim.Adam(ldm.unet.parameters(), lr=learning_rate)

    ldm = train_diffusion(ldm, train_data, ldm_optimizer, num_epochs=50, device=device)

    print("Training completed!")

    # Measure the latent side length the VAE produces for training-sized images,
    # so the final samples are drawn at the resolution the diffusion model was
    # trained on (only the encoder's mean is used, so no random numbers are drawn).
    with torch.no_grad():
        probe = torch.zeros(1, channels, image_size, image_size, device=device)
        latent_size = vae.encoder(probe)[0].shape[-1]

    # Generate final samples
    print("Generating final samples...")
    os.makedirs('final_samples', exist_ok=True)
    ldm.eval()
    with torch.no_grad():
        samples = ldm.sample(batch_size=16, img_size=latent_size)
        save_image(samples.cpu(), 'final_samples/final_samples.png', nrow=4, normalize=True)

    print("Done!")

if __name__ == "__main__":
    main()

Using device: cuda
Training VAE...




Epoch 1/30, Loss: 0.4137, Recon Loss: 0.4092, KLD Loss: 17.8637




Epoch 2/30, Loss: 0.3589, Recon Loss: 0.3573, KLD Loss: 6.0168




Epoch 3/30, Loss: 0.3110, Recon Loss: 0.3096, KLD Loss: 5.4383




Epoch 4/30, Loss: 0.2249, Recon Loss: 0.2241, KLD Loss: 3.4221




Epoch 5/30, Loss: 0.1888, Recon Loss: 0.1875, KLD Loss: 5.1745




Epoch 6/30, Loss: 0.1484, Recon Loss: 0.1473, KLD Loss: 4.5026




Epoch 7/30, Loss: 0.1363, Recon Loss: 0.1352, KLD Loss: 4.3195




Epoch 8/30, Loss: 0.1228, Recon Loss: 0.1219, KLD Loss: 3.6752




Epoch 9/30, Loss: 0.1098, Recon Loss: 0.1088, KLD Loss: 3.9143




Epoch 10/30, Loss: 0.1062, Recon Loss: 0.1052, KLD Loss: 3.8800




Epoch 11/30, Loss: 0.0985, Recon Loss: 0.0974, KLD Loss: 4.1380




Epoch 12/30, Loss: 0.0938, Recon Loss: 0.0927, KLD Loss: 4.3620




Epoch 13/30, Loss: 0.0785, Recon Loss: 0.0775, KLD Loss: 4.0544




Epoch 14/30, Loss: 0.0732, Recon Loss: 0.0722, KLD Loss: 4.0969




Epoch 15/30, Loss: 0.0693, Recon Loss: 0.0682, KLD Loss: 4.4365




Epoch 16/30, Loss: 0.0585, Recon Loss: 0.0574, KLD Loss: 4.3152




Epoch 17/30, Loss: 0.0561, Recon Loss: 0.0551, KLD Loss: 3.9679




Epoch 18/30, Loss: 0.0520, Recon Loss: 0.0510, KLD Loss: 3.8361




Epoch 19/30, Loss: 0.0527, Recon Loss: 0.0517, KLD Loss: 3.7732




Epoch 20/30, Loss: 0.0477, Recon Loss: 0.0467, KLD Loss: 3.7472




Epoch 21/30, Loss: 0.0479, Recon Loss: 0.0469, KLD Loss: 3.7840




Epoch 22/30, Loss: 0.0488, Recon Loss: 0.0478, KLD Loss: 3.9233




Epoch 23/30, Loss: 0.0489, Recon Loss: 0.0479, KLD Loss: 4.1170




Epoch 24/30, Loss: 0.0442, Recon Loss: 0.0432, KLD Loss: 4.3614




Epoch 25/30, Loss: 0.0426, Recon Loss: 0.0416, KLD Loss: 4.0799




Epoch 26/30, Loss: 0.0406, Recon Loss: 0.0396, KLD Loss: 3.9734




Epoch 27/30, Loss: 0.0397, Recon Loss: 0.0387, KLD Loss: 3.8391




Epoch 28/30, Loss: 0.0367, Recon Loss: 0.0358, KLD Loss: 3.6792




Epoch 29/30, Loss: 0.0370, Recon Loss: 0.0361, KLD Loss: 3.6136




Epoch 30/30, Loss: 0.0373, Recon Loss: 0.0363, KLD Loss: 3.7345
Training Diffusion Model...




Epoch 1/50, Loss: 1.0411




Epoch 2/50, Loss: 0.9931




Epoch 3/50, Loss: 0.9750




Epoch 4/50, Loss: 0.9476




Epoch 5/50, Loss: 0.9173




Epoch 6/50, Loss: 0.8833




Epoch 7/50, Loss: 0.8608




Epoch 8/50, Loss: 0.8270




Epoch 9/50, Loss: 0.8052




Epoch 10/50, Loss: 0.7799


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 874.09it/s]


Epoch 11/50, Loss: 0.7609




Epoch 12/50, Loss: 0.7418




Epoch 13/50, Loss: 0.7253




Epoch 14/50, Loss: 0.7107




Epoch 15/50, Loss: 0.6916




Epoch 16/50, Loss: 0.6989




Epoch 17/50, Loss: 0.6747




Epoch 18/50, Loss: 0.6532




Epoch 19/50, Loss: 0.6443




Epoch 20/50, Loss: 0.6357


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 954.56it/s]


Epoch 21/50, Loss: 0.6238




Epoch 22/50, Loss: 0.6290




Epoch 23/50, Loss: 0.6341




Epoch 24/50, Loss: 0.6011




Epoch 25/50, Loss: 0.5968




Epoch 26/50, Loss: 0.5779




Epoch 27/50, Loss: 0.5872




Epoch 28/50, Loss: 0.5730




Epoch 29/50, Loss: 0.5690




Epoch 30/50, Loss: 0.5727


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 921.37it/s]


Epoch 31/50, Loss: 0.5487




Epoch 32/50, Loss: 0.5455




Epoch 33/50, Loss: 0.5421




Epoch 34/50, Loss: 0.5413




Epoch 35/50, Loss: 0.5416




Epoch 36/50, Loss: 0.5358




Epoch 37/50, Loss: 0.5276




Epoch 38/50, Loss: 0.5333




Epoch 39/50, Loss: 0.5241




Epoch 40/50, Loss: 0.5251


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 931.38it/s]


Epoch 41/50, Loss: 0.5016




Epoch 42/50, Loss: 0.4907




Epoch 43/50, Loss: 0.4997




Epoch 44/50, Loss: 0.4949




Epoch 45/50, Loss: 0.5096




Epoch 46/50, Loss: 0.4760




Epoch 47/50, Loss: 0.4901




Epoch 48/50, Loss: 0.4761




Epoch 49/50, Loss: 0.4930




Epoch 50/50, Loss: 0.4918


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 926.17it/s]


Training completed!
Generating final samples...


Sampling: 100%|██████████| 1000/1000 [00:01<00:00, 912.79it/s]


Done!
