# DDPM Training on Google Colab with Real-time Visualization

This notebook trains a DDPM (denoising diffusion probabilistic model) on MNIST using GPU acceleration, with live progress tracking, loss curves, and sample visualization after every epoch.


In [None]:
# Install dependencies
!pip install denoising_diffusion_pytorch torch torchvision matplotlib tqdm seaborn plotly

# Check GPU availability
import torch
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import clear_output, display
import time

# Report CUDA visibility first, then either warn about the CPU fallback or
# describe device 0 (the only GPU that is queried).
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("⚠️ No GPU detected! Training will be slow on CPU.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # total_memory is reported in bytes; dividing by 1e9 prints decimal gigabytes.
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


In [None]:
# Clone your repository and make it the working directory for the rest of the
# notebook (all later relative paths, e.g. `colab_outputs/`, resolve inside it).
# NOTE(review): this cell is presumably not idempotent — on a second run the
# clone target already exists and `%cd` is evaluated relative to the directory
# entered by the first run. Restart the runtime before re-running; confirm on Colab.
!git clone https://github.com/yatuzhang/ddpm-mnist.git
%cd ddpm-mnist

# List files to verify
!ls -la

# Create directories for outputs
import os
# Root folder first, then one sub-folder each for sample images and checkpoints.
# exist_ok=True makes the calls safe when the folders are already there.
for output_dir in ('colab_outputs', 'colab_outputs/samples', 'colab_outputs/checkpoints'):
    os.makedirs(output_dir, exist_ok=True)


In [None]:
# Enhanced training script with real-time visualization
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
from IPython.display import clear_output, display
import seaborn as sns

# Set style for better plots.
# NOTE(review): 'seaborn-v0_8' is presumably only available in newer matplotlib
# releases (older ones name the style 'seaborn') — confirm against the runtime.
plt.style.use('seaborn-v0_8')
# Select seaborn's "husl" palette; the plots below that pass explicit
# `color=` arguments override it.
sns.set_palette("husl")

def get_data_loaders(batch_size=128, num_workers=2):
    """Build the MNIST training and test data loaders.

    Every image is converted to a tensor and normalized with mean 0.5 and
    std 0.5. The dataset is stored under ``./data`` and downloaded on demand.

    Args:
        batch_size: Batch size used by both loaders.
        num_workers: Number of worker processes used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its data.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    def _loader_for(is_train):
        # The train flag selects the split and doubles as the shuffle switch.
        split = datasets.MNIST(
            root='./data', train=is_train, download=True, transform=preprocess
        )
        return DataLoader(
            split, batch_size=batch_size, shuffle=is_train, num_workers=num_workers
        )

    return _loader_for(True), _loader_for(False)

def create_model(device):
    """Create the U-Net based Gaussian diffusion model for 28x28 grayscale images.

    Args:
        device: Device the diffusion wrapper (and its U-Net) is moved to.

    Returns:
        A ``GaussianDiffusion`` instance on ``device`` configured with 1000
        training timesteps, 250 sampling timesteps and the ``pred_v`` objective.
    """
    model = Unet(
        dim=64,
        channels=1,
        dim_mults=(1, 2, 4),
        flash_attn=True
    )

    diffusion = GaussianDiffusion(
        model,
        image_size=28,
        timesteps=1000,
        sampling_timesteps=250,
        objective='pred_v',
        # The MNIST transform in this notebook already maps pixels to [-1, 1]
        # (Normalize((0.5,), (0.5,))) and save_samples_grid maps samples back
        # from [-1, 1]. The library default (auto_normalize=True) would apply
        # a second [0, 1] -> [-1, 1] rescale to every training batch and
        # return samples in [0, 1], so the two conventions would not match.
        auto_normalize=False
    ).to(device)

    return diffusion

def save_samples_grid(samples, title="Generated Samples", nrow=4, ncol=4):
    """Render a batch of images as an ``nrow`` x ``ncol`` grayscale grid.

    Despite the name, nothing is written to disk: the figure is returned so
    the caller can display it or call ``savefig`` on it.

    Args:
        samples: Batch of image tensors. Values are mapped with ``(x + 1) / 2``
            and clamped to [0, 1] before plotting.
        title: Figure title.
        nrow: Number of grid rows.
        ncol: Number of grid columns.

    Returns:
        The matplotlib figure. At most ``nrow * ncol`` images are drawn; grid
        cells without an image are hidden.
    """
    images = torch.clamp((samples + 1) / 2, 0, 1)

    # squeeze=False always yields a 2-D axes array, so 1x1, 1xN and Nx1 grids
    # need no special-casing before flattening.
    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * 2, nrow * 2), squeeze=False)
    axes = axes.ravel()

    n_shown = min(len(axes), len(images))
    for ax, image in zip(axes, images[:n_shown]):
        # Fixed intensity range: the data was clamped to [0, 1] above, so
        # every cell shares the same gray scale.
        ax.imshow(image.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')

    # Hide the cells left over when there are fewer images than grid slots.
    for ax in axes[n_shown:]:
        ax.set_visible(False)

    fig.suptitle(title, fontsize=16)
    fig.tight_layout()
    return fig

def visualize_denoising_process(diffusion, original_image, device, save_path=None):
    """Noise an image to a random timestep, then plot the reverse (denoising) chain.

    The image is diffused forward to a randomly drawn timestep ``t`` with
    ``q_sample`` and then stepped back with one ``p_sample`` call per timestep
    from ``t`` down to 0, so the cost grows with the drawn ``t``. Up to six
    evenly spaced intermediate results are plotted next to the original and
    the noised image.

    Args:
        diffusion: Diffusion model exposing ``eval``, ``num_timesteps``,
            ``q_sample`` and ``p_sample``. It is left in eval mode.
        original_image: Batch of clean images; only the first one is plotted.
        device: Device on which the timestep tensors are created.
        save_path: Optional path; when given the figure is also written there.

    Returns:
        The matplotlib figure (a 2x4 grid; unused panels are hidden).
    """
    diffusion.eval()

    n_snapshots = 6
    snapshots = []

    with torch.no_grad():
        # Forward process: pick one random noise level for the whole batch.
        t_start = int(torch.randint(0, diffusion.num_timesteps, (1,), device=device).item())
        t = torch.full((original_image.shape[0],), t_start, device=device, dtype=torch.long)
        noise = torch.randn_like(original_image)
        noisy_image = diffusion.q_sample(original_image, t, noise)

        # Timesteps after which a snapshot is kept; always contains 0 (the
        # final result) and collapses to fewer entries when t_start is small.
        capture = {round(t_start * k / n_snapshots) for k in range(n_snapshots)}

        # Reverse process: walk the chain downwards from the noised timestep.
        x = noisy_image
        for step in range(t_start, -1, -1):
            out = diffusion.p_sample(x, step)
            # p_sample returns (image, predicted x_start) in current releases
            # of the library; accept a bare tensor as well.
            x = out[0] if isinstance(out, (tuple, list)) else out
            if step in capture:
                snapshots.append((step, x))

    panels = [('Original', original_image), (f'Noisy (t = {t_start})', noisy_image)]
    panels += [
        ('Final' if step == 0 else f'After step t = {step}', image)
        for step, image in snapshots
    ]

    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    axes = axes.ravel()

    for ax, (label, image) in zip(axes, panels):
        ax.imshow(image[0].squeeze().cpu().numpy(), cmap='gray')
        ax.set_title(label)
        ax.axis('off')

    # Fewer than eight panels are produced when the drawn timestep is small.
    for ax in axes[len(panels):]:
        ax.set_visible(False)

    fig.suptitle('Denoising Process Visualization', fontsize=16)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')

    return fig

def train_with_visualization(epochs=50, batch_size=128):
    """Train the diffusion model on MNIST and refresh the plots after every epoch.

    Each epoch runs one pass over the training set, evaluates the loss on the
    test set, draws 16 samples, and then shows three figures: the loss curves,
    the sample grid and the denoising visualization.

    Side effects:
        * Writes ``epoch_XXX_samples.png`` and ``epoch_XXX_denoising.png`` to
          ``colab_outputs/samples/``.
        * Writes a checkpoint to ``colab_outputs/checkpoints/`` every 10th
          epoch and after the last epoch.
        Both directories must already exist.

    Args:
        epochs: Number of training epochs.
        batch_size: Batch size for the train and test loaders.

    Returns:
        ``(diffusion, train_losses, test_losses)`` where the two lists hold
        one average loss per epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Get data loaders
    train_loader, test_loader = get_data_loaders(batch_size)

    # Create model
    diffusion = create_model(device)
    print(f"Model created with {sum(p.numel() for p in diffusion.parameters())} parameters")

    # Create optimizer
    optimizer = optim.AdamW(diffusion.parameters(), lr=2e-4, weight_decay=1e-4)

    # Training history (one entry per epoch)
    train_losses = []
    test_losses = []

    # One fixed test image, reused every epoch for the denoising visualization.
    sample_batch = next(iter(test_loader))[0][:1].to(device)

    print("Starting training with real-time visualization...")

    for epoch in range(epochs):
        # Training
        diffusion.train()
        total_loss = 0.0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch}')
        for data, _ in pbar:
            data = data.to(device)

            optimizer.zero_grad()
            loss = diffusion(data)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        avg_train_loss = total_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Testing
        diffusion.eval()
        test_loss = 0.0
        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(device)
                test_loss += diffusion(data).item()

        avg_test_loss = test_loss / len(test_loader)
        test_losses.append(avg_test_loss)

        # Generate samples
        with torch.no_grad():
            samples = diffusion.sample(batch_size=16)

        # Replace the previous epoch's output once the new one is ready.
        clear_output(wait=True)

        # The two helpers below each create their own figure, so the loss
        # curves get a dedicated figure instead of a shared subplot grid.
        loss_fig, loss_ax = plt.subplots(figsize=(8, 5))
        loss_ax.plot(train_losses, label='Train Loss', color='blue')
        loss_ax.plot(test_losses, label='Test Loss', color='red')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training Progress')
        loss_ax.legend()
        loss_ax.grid(True)
        loss_fig.tight_layout()

        samples_grid = save_samples_grid(samples, f'Generated Samples - Epoch {epoch}')
        denoising_fig = visualize_denoising_process(diffusion, sample_batch, device)

        # Save before showing, so the files are written while the figures
        # are still open.
        samples_grid.savefig(f'colab_outputs/samples/epoch_{epoch:03d}_samples.png',
                             dpi=150, bbox_inches='tight')
        denoising_fig.savefig(f'colab_outputs/samples/epoch_{epoch:03d}_denoising.png',
                              dpi=150, bbox_inches='tight')

        # A single show() renders all three open figures.
        plt.show()

        # Save checkpoint
        if epoch % 10 == 0 or epoch == epochs - 1:
            torch.save({
                'epoch': epoch,
                'model_state_dict': diffusion.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'test_losses': test_losses,
            }, f'colab_outputs/checkpoints/checkpoint_epoch_{epoch:03d}.pth')

        print(f'Epoch {epoch:3d}: Train Loss = {avg_train_loss:.4f}, Test Loss = {avg_test_loss:.4f}')

        # Close figures to free memory
        plt.close('all')

    print("Training completed!")
    return diffusion, train_losses, test_losses

# Start training: 50 epochs at batch size 128.
# The returned model and loss histories are bound at notebook scope because
# the later summary cell reads `diffusion`, `train_losses` and `test_losses`.
diffusion, train_losses, test_losses = train_with_visualization(epochs=50, batch_size=128)


In [None]:
# Generate final samples and create summary
print("Generating final samples...")

# Draw a larger batch (64 images) for the final evaluation; no gradients needed.
with torch.no_grad():
    final_samples = diffusion.sample(batch_size=64)

# Lay the samples out row by row on an 8x8 grid.
fig, axes = plt.subplots(8, 8, figsize=(16, 16))
for ax, sample in zip(axes.flat, final_samples):
    ax.imshow(sample.squeeze().cpu().numpy(), cmap='gray')
    ax.axis('off')

fig.suptitle('Final Generated Samples (64 images)', fontsize=20)
fig.tight_layout()
fig.savefig('colab_outputs/final_samples_64.png', dpi=150, bbox_inches='tight')
plt.show()

# Create training summary: loss curves on the left, epoch-to-epoch change on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# (losses, curve label, change label, colour) — train is drawn before test on both axes.
summary_curves = [
    (train_losses, 'Train Loss', 'Train Loss Change', 'blue'),
    (test_losses, 'Test Loss', 'Test Loss Change', 'red'),
]
for losses, curve_label, change_label, colour in summary_curves:
    ax1.plot(losses, label=curve_label, color=colour, linewidth=2)
    # np.diff gives the difference between consecutive epochs.
    ax2.plot(np.diff(losses), label=change_label, color=colour, alpha=0.7)

ax1.set(xlabel='Epoch', ylabel='Loss', title='Training Progress')
ax2.set(xlabel='Epoch', ylabel='Loss Change', title='Loss Improvement Over Time')
for ax in (ax1, ax2):
    ax.legend()
    ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig('colab_outputs/training_summary.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Final training loss: {train_losses[-1]:.4f}")
print(f"Final test loss: {test_losses[-1]:.4f}")
print(f"Total improvement: {train_losses[0] - train_losses[-1]:.4f}")


In [None]:
# Download results
from google.colab import files
import zipfile
import os

# Collect every file under colab_outputs/ once; the same list drives both the
# archive and the printed summary.
# The loop names deliberately avoid `files`: that name is the google.colab
# module imported above and is still needed for the download call below.
result_paths = sorted(
    os.path.join(root, filename)
    for root, _dirs, filenames in os.walk('colab_outputs')
    for filename in filenames
)

# Create zip file with all results. final_samples_64.png and
# training_summary.png live inside colab_outputs/, so the walk already
# includes them and each file is written exactly once.
with zipfile.ZipFile('ddpm_training_results.zip', 'w') as zipf:
    for file_path in result_paths:
        zipf.write(file_path, file_path)

print("Files to download:")
for file_path in result_paths:
    print(f"  - {file_path}")

# Download the zip file
files.download('ddpm_training_results.zip')

print("\n🎉 Training completed successfully!")
print("📁 All results have been saved and downloaded.")
print("📊 Check the visualizations above to see your model's progress!")
