In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torchdiffeq import odeint
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import json
from datetime import datetime

In [None]:
# ================== Experiment Configuration ==================
# Every artifact of a run (config, logs, sample grids, weights) is written under
# RESULTS_DIR. Change EXPERIMENT_NAME for each experiment: re-using a name
# overwrites the files already in that folder.
EXPERIMENT_NAME = "baseline"
RESULTS_DIR = "results/" + EXPERIMENT_NAME
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
# ================== Configuration Parameters ==================
# Data geometry for CIFAR-10: 32x32 RGB images across 10 classes.
image_size = 32
channels = 3
num_classes = 10
batch_size = 128
# Use the GPU when PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Destination of the trained weights (written once training finishes).
model_save_path = os.path.join(RESULTS_DIR, 'FMmodel.pth')

In [None]:
# ================== Hyperparameters ==================
# Every value below is also written to config.json so runs can be compared later.

# Optimizer learning rate (1e-3 was the other value tried).
lr = 1e-4

# Number of passes over the training set.
epochs = 100

# Model size: channel width of the first UNet stage; the three deeper stages
# use 2x, 4x and 8x this width.
BASE_CHANNELS = 128

# Output size of the time-embedding MLP.
TIME_EMBED_DIM = 128

# ODE solver used for sampling (torchdiffeq method name):
#   'dopri5' - adaptive Runge-Kutta (Dormand-Prince); step count is driven by rtol/atol
#   'euler'  - fixed-step explicit Euler
#   'rk4'    - fixed-step 4th-order Runge-Kutta (not Heun's method)
ODE_METHOD = 'dopri5'

# Number of integration steps for fixed-step solvers ('euler', 'rk4', ...).
# Adaptive solvers such as 'dopri5' ignore it. (100 was also tried.)
ODE_STEPS = 50

# Loss between the target velocity u^target and the prediction u^theta:
# 'mse' (mean squared error), 'l1' (mean absolute error) or 'huber'.
LOSS_TYPE = 'mse'

# Normalization layer: 'groupnorm', 'batchnorm' or 'layernorm'.
NORM_TYPE = 'groupnorm'

# Activation function: 'silu' (SiLU/Swish) or 'gelu'.
ACTIVATION = 'silu'

# Optimizer: 'adamw', 'adam' or 'sgd'.
OPTIMIZER_TYPE = 'adamw'

# Seed torch's CPU and CUDA generators so experiments that change a single
# setting start from the same weight init, data order and noise/t draws.
# cuDNN kernels may still introduce small run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Save experiment config
config = {
    'experiment_name': EXPERIMENT_NAME,
    'lr': lr,
    'epochs': epochs,
    'batch_size': batch_size,
    'base_channels': BASE_CHANNELS,
    'time_embed_dim': TIME_EMBED_DIM,
    'ode_method': ODE_METHOD,
    'ode_steps': ODE_STEPS,
    'loss_type': LOSS_TYPE,
    'norm_type': NORM_TYPE,
    'activation': ACTIVATION,
    'optimizer_type': OPTIMIZER_TYPE,
    'seed': SEED,
    'image_size': image_size,
    'channels': channels,
    'device': str(device)
}

with open(os.path.join(RESULTS_DIR, 'config.json'), 'w') as f:
    json.dump(config, f, indent=4)

In [None]:
# ================== Data Loading ==================
def normalize_img(x):
    """Rescale pixel values from [0, 1] to [-1, 1] using the affine map x -> 2x - 1.

    Works on anything supporting ``*`` and ``-`` with scalars (floats, tensors).
    """
    return x * 2 - 1

# ToTensor() turns a PIL image into floats in [0, 1]. Normalize with mean=std=0.5
# computes (x - 0.5) / 0.5, i.e. the same [0, 1] -> [-1, 1] map as 2*x - 1.
# A built-in transform is used rather than transforms.Lambda(normalize_img):
# DataLoader workers (num_workers > 0) receive the dataset and its transform by
# pickling, and with the "spawn" start method (default on macOS/Windows) a
# function defined in the notebook's __main__ cannot be unpickled in the worker.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * channels, std=[0.5] * channels),
])

# ================== Helper Functions ==================
def get_activation(name=None):
    """Return the element-wise activation used inside the conv blocks.

    Args:
        name: 'silu' or 'gelu'. Defaults to the ACTIVATION config value.

    Returns:
        The functional activation (``F.silu`` or ``F.gelu``), applied as ``act(tensor)``.

    Raises:
        ValueError: for any other name. 'swiglu' is rejected explicitly: no
            SwiGLU module exists in this notebook, and SwiGLU as usually defined
            splits its input into two halves (changing the channel count), which
            the fixed-width conv blocks do not account for. Unknown names raise
            instead of silently using SiLU, because config.json records the
            requested name and would otherwise mislabel the run.
    """
    if name is None:
        name = ACTIVATION
    activations = {'silu': F.silu, 'gelu': F.gelu}
    if name not in activations:
        raise ValueError(
            f"Unsupported ACTIVATION {name!r}; expected one of {sorted(activations)}")
    return activations[name]

def _largest_group_count(num_channels, max_groups=32):
    """Return the largest divisor of ``num_channels`` that is <= ``max_groups`` (at least 1)."""
    groups = max(1, min(max_groups, num_channels))
    while num_channels % groups:
        groups -= 1
    return groups


def get_norm_layer(num_channels, norm_type=None):
    """Build the normalization layer for a feature map with ``num_channels`` channels.

    Args:
        num_channels: channel count of the tensor being normalized.
        norm_type: 'groupnorm', 'batchnorm' or 'layernorm'. Defaults to NORM_TYPE.

    Returns:
        'groupnorm': ``nn.GroupNorm`` with at most 32 groups. GroupNorm requires
            the channel count to be divisible by the group count, so the largest
            divisor of ``num_channels`` not above 32 is used. For widths that are
            multiples of 32, or <= 32, this equals min(32, num_channels); other
            widths (e.g. 48, 80) now work instead of raising.
        'batchnorm': ``nn.BatchNorm2d``.
        'layernorm': ``nn.GroupNorm(1, num_channels)``, i.e. normalization over C, H, W.

    Raises:
        ValueError: for any other ``norm_type``. The previous silent GroupNorm
            fallback would leave config.json describing a layer that was not used.
    """
    if norm_type is None:
        norm_type = NORM_TYPE
    if norm_type == 'groupnorm':
        return nn.GroupNorm(_largest_group_count(num_channels), num_channels)
    if norm_type == 'batchnorm':
        return nn.BatchNorm2d(num_channels)
    if norm_type == 'layernorm':
        # A single group normalizes jointly over C, H and W for each sample.
        return nn.GroupNorm(1, num_channels)
    raise ValueError(
        f"Unsupported NORM_TYPE {norm_type!r}; expected 'groupnorm', 'batchnorm' or 'layernorm'")

In [None]:
# ================== Model Architecture ==================
class ConditionedDoubleConv(nn.Module):
    """Two conv -> norm -> activation stages with conditioning injected in between.

    The first 3x3 conv maps ``in_channels`` to ``out_channels``. The conditioning
    tensor (expected shape (B, cond_dim, 1, 1)) is then broadcast over the spatial
    grid and concatenated on the channel axis before the second 3x3 conv.
    Padding=1 keeps the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        # Attribute names and creation order are kept stable: they fix the
        # state_dict keys and the order of random weight initialisation.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = get_norm_layer(out_channels)
        self.conv2 = nn.Conv2d(out_channels + cond_dim, out_channels, kernel_size=3, padding=1)
        self.norm2 = get_norm_layer(out_channels)
        self.activation = get_activation()

    def forward(self, x, cond):
        hidden = self.activation(self.norm1(self.conv1(x)))
        height, width = hidden.size(2), hidden.size(3)
        # Tile the per-sample condition vector to every spatial location.
        cond_map = cond.expand(-1, -1, height, width)
        merged = torch.cat([hidden, cond_map], dim=1)
        return self.activation(self.norm2(self.conv2(merged)))


class Down(nn.Module):
    """Encoder stage: 2x2 max-pool (halving H and W), then a conditioned double conv."""

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.maxpool = nn.MaxPool2d(2)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x, cond):
        pooled = self.maxpool(x)
        return self.conv(pooled, cond)


class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, align with the skip tensor, concat, conv.

    ``in_channels`` must equal the skip tensor's channels plus the upsampled
    tensor's channels, since both are concatenated before the double conv.
    """

    def __init__(self, in_channels, out_channels, cond_dim):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConditionedDoubleConv(in_channels, out_channels, cond_dim)

    def forward(self, x1, x2, cond):
        upsampled = self.up(x1)
        # Zero-pad the upsampled map so its H/W match the skip tensor x2
        # (any odd leftover goes to the right/bottom edge).
        pad_h = x2.size(2) - upsampled.size(2)
        pad_w = x2.size(3) - upsampled.size(3)
        left = pad_w // 2
        top = pad_h // 2
        upsampled = F.pad(upsampled, [left, pad_w - left, top, pad_h - top])
        # Skip features first, then the upsampled ones, along the channel axis.
        return self.conv(torch.cat([x2, upsampled], dim=1), cond)


class ConditionalUNet(nn.Module):
    """Class- and time-conditioned UNet used as the velocity network.

    Three Down stages (widths w -> 2w -> 4w -> 8w with w = BASE_CHANNELS) and
    three Up stages with skip connections. With 32x32 inputs the bottleneck is
    4x4. Every block receives the same conditioning vector: the concatenation of
    a learned time embedding and a learned class embedding. The output has
    ``channels`` channels and the input's spatial size.
    """

    def __init__(self):
        super().__init__()
        # Conditioning layout: [time embedding (TIME_EMBED_DIM) | label embedding (label_dim)].
        self.label_dim = 32
        self.cond_dim = TIME_EMBED_DIM + self.label_dim
        width = BASE_CHANNELS
        cond_dim = self.cond_dim

        # Creation order and attribute names are kept stable: they define the
        # state_dict keys and the sequence of random initialisation draws.
        # Scalar t (shape (B, 1)) -> TIME_EMBED_DIM features through a small MLP.
        self.time_embed = nn.Sequential(
            nn.Linear(1, 2 * TIME_EMBED_DIM),
            nn.SiLU(),
            nn.Linear(2 * TIME_EMBED_DIM, TIME_EMBED_DIM),
        )
        self.label_embed = nn.Embedding(num_classes, self.label_dim)

        # Encoder: each Down halves the resolution and doubles the width.
        self.inc = ConditionedDoubleConv(channels, width, cond_dim)
        self.down1 = Down(width, 2 * width, cond_dim)
        self.down2 = Down(2 * width, 4 * width, cond_dim)
        self.down3 = Down(4 * width, 8 * width, cond_dim)

        # Decoder: inputs are (upsampled channels + skip channels).
        self.up1 = Up(8 * width + 4 * width, 4 * width, cond_dim)
        self.up2 = Up(4 * width + 2 * width, 2 * width, cond_dim)
        self.up3 = Up(2 * width + width, width, cond_dim)
        self.outc = nn.Conv2d(width, channels, kernel_size=1)

    def forward(self, x, t, labels):
        time_features = self.time_embed(t.view(-1, 1))
        label_features = self.label_embed(labels)
        # (B, cond_dim) -> (B, cond_dim, 1, 1) so blocks can broadcast it spatially.
        cond = torch.cat([time_features, label_features], dim=1)[:, :, None, None]

        skip1 = self.inc(x, cond)
        skip2 = self.down1(skip1, cond)
        skip3 = self.down2(skip2, cond)
        bottleneck = self.down3(skip3, cond)

        h = self.up1(bottleneck, skip3, cond)
        h = self.up2(h, skip2, cond)
        h = self.up3(h, skip1, cond)
        return self.outc(h)

In [None]:
# ================== Training and Generation ==================
# Optimizer constructors keyed by OPTIMIZER_TYPE. An unknown name raises instead of
# silently falling back to Adam, because config.json records the requested name
# and would otherwise mislabel the run. It is checked before the model is built.
_OPTIMIZER_BUILDERS = {
    'adam': lambda params: torch.optim.Adam(params, lr=lr),
    'adamw': lambda params: torch.optim.AdamW(params, lr=lr, weight_decay=0.01),
    'sgd': lambda params: torch.optim.SGD(params, lr=lr, momentum=0.9),
}
if OPTIMIZER_TYPE not in _OPTIMIZER_BUILDERS:
    raise ValueError(
        f"Unsupported OPTIMIZER_TYPE {OPTIMIZER_TYPE!r}; "
        f"expected one of {sorted(_OPTIMIZER_BUILDERS)}")

model = ConditionalUNet().to(device)
optimizer = _OPTIMIZER_BUILDERS[OPTIMIZER_TYPE](model.parameters())


def compute_loss(pred, target, loss_type=None):
    """Regression loss between the predicted and target velocity fields.

    Args:
        pred: model output.
        target: tensor with the same shape as ``pred``.
        loss_type: 'mse', 'l1' or 'huber' (PyTorch default delta=1.0).
            Defaults to LOSS_TYPE.

    Returns:
        Scalar tensor (mean reduction over all elements).

    Raises:
        ValueError: for an unknown ``loss_type``, rather than silently using MSE
            while config.json records a different loss.
    """
    if loss_type is None:
        loss_type = LOSS_TYPE
    loss_fns = {'mse': F.mse_loss, 'l1': F.l1_loss, 'huber': F.huber_loss}
    if loss_type not in loss_fns:
        raise ValueError(
            f"Unsupported LOSS_TYPE {loss_type!r}; expected one of {sorted(loss_fns)}")
    return loss_fns[loss_type](pred, target)


@torch.no_grad()
def generate_with_label(label, num_samples=16):
    """Sample images of class ``label`` by integrating the learned velocity field.

    Starts from standard Gaussian noise at t=0 and integrates
    dx/dt = model(x, t, label) up to t=1 with torchdiffeq's ``odeint`` using ODE_METHOD.

    Args:
        label: integer class index passed to the label embedding.
        num_samples: number of images to generate.

    Returns:
        CPU tensor of shape (num_samples, channels, image_size, image_size)
        with values clamped to [0, 1].
    """
    # torchdiffeq solvers that step only on the supplied time grid. They need the
    # ODE_STEPS + 1 grid points; given just [0, 1] they would take one single step.
    fixed_grid_methods = ('euler', 'midpoint', 'rk4', 'heun2', 'heun3',
                          'explicit_adams', 'implicit_adams', 'fixed_adams')

    was_training = model.training
    model.eval()
    try:
        x0 = torch.randn(num_samples, channels, image_size, image_size, device=device)
        labels = torch.full((num_samples,), label, device=device, dtype=torch.long)

        def ode_func(t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
            # odeint passes one scalar time; the model expects one time per sample.
            return model(x, t.expand(x.size(0)), labels)

        if ODE_METHOD in fixed_grid_methods:
            t_eval = torch.linspace(0.0, 1.0, ODE_STEPS + 1, device=device)
        else:
            # Adaptive solvers choose their own steps (via rtol/atol); endpoints suffice.
            t_eval = torch.tensor([0.0, 1.0], device=device)

        trajectory = odeint(
            ode_func,
            x0,
            t_eval,
            rtol=1e-5,
            atol=1e-5,
            method=ODE_METHOD
        )
    finally:
        # Restore the caller's train/eval mode even if sampling raises.
        model.train(was_training)

    # Final state at t=1, mapped from [-1, 1] back to [0, 1].
    images = (trajectory[-1].clamp(-1, 1) + 1) / 2
    return images.cpu()


def visualize_train(epoch):
    """Save a grid of fresh samples as ``epoch{epoch}.jpg`` in RESULTS_DIR.

    Column ``c`` holds 10 samples conditioned on class ``c``, titled with the
    class index.
    """
    print(f"Generating visualization for epoch {epoch}...")
    rows = 10
    plt.figure(figsize=(12, 12))
    plt.subplots_adjust(wspace=0.1, hspace=0.1)

    for label in range(num_classes):
        batch = generate_with_label(label=label, num_samples=rows)
        for row in range(rows):
            ax = plt.subplot(rows, num_classes, row * num_classes + label + 1)
            # Samples are CHW tensors in [0, 1]; imshow wants HWC.
            plt.imshow(batch[row].permute(1, 2, 0).numpy(), vmin=0, vmax=1)
            ax.axis('off')
            if row == 0:
                ax.set_title(str(label), fontsize=14, pad=5)

    plt.suptitle(f"Generated Samples - Epoch {epoch}", fontsize=18, y=0.98)
    plt.savefig(os.path.join(RESULTS_DIR, f"epoch{epoch}.jpg"), dpi=150, bbox_inches='tight')
    plt.close()


def generate_final_samples(num_sample=5):
    """Write ``num_sample`` independent sample grids as generated_grid{k}.jpg in RESULTS_DIR.

    Each grid has one column per class with 10 generated samples per column;
    ``k`` counts from 1.
    """
    print(f"Generating {num_sample} final sample grids...")
    rows = 10

    for grid_no in range(1, num_sample + 1):
        plt.figure(figsize=(12, 12))
        plt.subplots_adjust(wspace=0.1, hspace=0.1)

        print(f"Generating grid {grid_no}/{num_sample}...")
        for label in tqdm(range(num_classes), desc=f"Grid {grid_no}"):
            batch = generate_with_label(label=label, num_samples=rows)
            for row in range(rows):
                ax = plt.subplot(rows, num_classes, row * num_classes + label + 1)
                # CHW tensor in [0, 1] -> HWC array for imshow.
                plt.imshow(batch[row].permute(1, 2, 0).numpy(), vmin=0, vmax=1)
                ax.axis('off')
                if row == 0:
                    ax.set_title(str(label), fontsize=14, pad=5)

        plt.suptitle(f"Final Generated Samples (Grid {grid_no})", fontsize=18, y=0.98)
        save_path = os.path.join(RESULTS_DIR, f"generated_grid{grid_no}.jpg")
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved: {save_path}")
        plt.close()

In [None]:
def train(num_epochs=100):
    """Train ``model`` with the conditional flow-matching objective.

    For each batch: draw Gaussian noise and t ~ U(0, 1) per image, form
    x_t = (1 - t) * noise + t * images, and regress the model output onto the
    velocity ``images - noise`` with ``compute_loss``. Gradients are clipped to
    norm 1.0.

    Side effects: rewrites training_log.json after every epoch, saves a sample
    grid after epoch 1 and every 10th epoch, saves the final weights to
    model_save_path, and writes the loss curve.

    Args:
        num_epochs: number of passes over the module-level ``train_loader``.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches (the epoch average
            would otherwise divide by zero).
    """
    print(f"Starting training for experiment: {EXPERIMENT_NAME}")
    print(f"Results will be saved to: {RESULTS_DIR}")
    print(f"Configuration: {config}")

    training_log = []
    log_path = os.path.join(RESULTS_DIR, 'training_log.json')

    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
        # visualize_train restores the mode it found, but set it explicitly anyway.
        model.train()
        total_loss = 0.0
        num_batches = 0

        for images, labels in progress_bar:
            # non_blocking overlaps the copy with compute when the loader pins memory.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            noise = torch.randn_like(images)
            t = torch.rand(images.size(0), device=device)
            t_img = t.view(-1, 1, 1, 1)  # broadcast t over (C, H, W)
            xt = (1 - t_img) * noise + t_img * images

            vt_pred = model(xt, t, labels)
            loss = compute_loss(vt_pred, images - noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # .item() synchronises with the device; read the value once per step.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            progress_bar.set_postfix({"Loss": f"{loss_value:.4f}"})

        if num_batches == 0:
            raise RuntimeError("train_loader yielded no batches; cannot compute the epoch loss")

        avg_loss = total_loss / num_batches
        training_log.append({'epoch': epoch + 1, 'loss': avg_loss})
        print(f"Epoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

        # Rewrite the full loss log each epoch so progress survives an interrupted run.
        with open(log_path, 'w') as f:
            json.dump(training_log, f, indent=4)

        # Generate samples
        if (epoch + 1) % 10 == 0 or epoch == 0:
            visualize_train(epoch + 1)

    # Save model
    torch.save(model.state_dict(), model_save_path)
    print(f"Training complete. Model saved to: {model_save_path}")

    # Plot training curve
    plot_training_curve(training_log)


def plot_training_curve(training_log):
    """Plot the mean loss per epoch and save it as training_curve.png in RESULTS_DIR.

    Args:
        training_log: list of dicts with 'epoch' and 'loss' keys, as built by train().
    """
    curve_path = os.path.join(RESULTS_DIR, 'training_curve.png')
    epoch_numbers = [record['epoch'] for record in training_log]
    epoch_losses = [record['loss'] for record in training_log]

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(epoch_numbers, epoch_losses, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title(f'Training Loss - {EXPERIMENT_NAME}', fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.savefig(curve_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Training curve saved to: {curve_path}")

In [None]:
# Load CIFAR-10
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
# Page-locked (pinned) host memory only speeds up host-to-GPU copies, so request
# it only when training on CUDA; on CPU-only machines PyTorch just warns about it.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=(device.type == 'cuda'))

print(f"Dataset loaded: {len(train_dataset)} training images")
print(f"Image shape: {channels} x {image_size} x {image_size}")

# Train model
train(epochs)

# Generate final samples
generate_final_samples(num_sample=5)

print(f"\nExperiment '{EXPERIMENT_NAME}' completed!")
print(f"All results saved in: {RESULTS_DIR}")