In [4]:
import torch

# Environment sanity check: report the installed PyTorch build and whether a
# CUDA-capable GPU is visible to this kernel.
for label, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
):
    print(label, value)


Torch version: 2.4.1.post302
CUDA available: True


In [6]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from torch.utils.data import DataLoader

# Configuration: seed PyTorch's RNGs so data shuffling, noise/timestep sampling
# and generated samples repeat across runs (bit-exact reproducibility on GPU may
# additionally require deterministic kernels), then pick the compute device.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# 1. MNIST training split, rescaled for diffusion training.
#    ToTensor turns each image into a float tensor in [0, 1];
#    Normalize((0.5,), (0.5,)) then applies (x - 0.5) / 0.5, giving [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

mnist_train = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(mnist_train, shuffle=True, batch_size=64)

# 2. Noise-prediction network: a three-stage UNet for 28x28 single-channel
#    (grayscale) images. Self-attention is used only in the last down block and
#    the first up block, i.e. the lowest-resolution stages.
model = UNet2DModel(
    sample_size=28,
    in_channels=1,
    out_channels=1,
    block_out_channels=(64, 128, 256),
    layers_per_block=2,
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to(device)

# 3-4. DDPM noise schedule over 1000 training timesteps, plus AdamW over every
#      UNet parameter. The MSE objective itself is computed inside the
#      training loop.
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4)

# 5. Training loop: teach the UNet to predict the Gaussian noise the scheduler
#    added to each clean image at a randomly drawn timestep (epsilon-prediction),
#    using a plain MSE objective.
num_epochs = 5
log_every = 500  # print the current batch loss every `log_every` steps

model.train()  # explicit training mode, in case the model was left in eval mode
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)

        # One noise tensor per image (randn_like already allocates on images'
        # device) and one random timestep per image (randint yields int64).
        noise = torch.randn_like(images)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (images.shape[0],), device=device
        )
        noisy_images = scheduler.add_noise(images, noise, timesteps)

        # Predict the added noise and score it against the true noise.
        noise_pred = model(noisy_images, timesteps).sample
        loss = nn.functional.mse_loss(noise_pred, noise)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % log_every == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Step {i}, Loss: {loss.item():.4f}")

print("Training complete!")

# 6. Generate new digits by running the DDPM reverse (denoising) process with the
#    trained UNet and the same scheduler object used during training.
num_samples = 4  # number of images to sample and plot

pipeline = DDPMPipeline(unet=model, scheduler=scheduler)
pipeline.to(device)

# Sample in eval mode, then restore whatever mode the model was in so that any
# later training cell is unaffected.
was_training = model.training
model.eval()
try:
    # With the pipeline's default output_type these are PIL images.
    generated_images = pipeline(batch_size=num_samples).images
finally:
    model.train(was_training)

# Show the samples side by side. squeeze=False keeps `axes` 2-D even when
# num_samples == 1; the width scales with the number of panels (10 in at 4).
fig, axes = plt.subplots(1, num_samples, figsize=(2.5 * num_samples, 5), squeeze=False)
for ax, img in zip(axes[0], generated_images):
    ax.imshow(img, cmap="gray")
    ax.axis("off")
fig.suptitle("DDPM samples")
plt.show()


Using device: cuda
Epoch 1/5, Step 0, Loss: 1.1872
Epoch 1/5, Step 500, Loss: 0.0411
Epoch 2/5, Step 0, Loss: 0.0289
Epoch 2/5, Step 500, Loss: 0.0240


KeyboardInterrupt: 

In [5]:
import time

# Continue training the existing `model` / `optimizer` (created in the setup cell)
# with per-batch timing and per-epoch loss summaries. Re-running this cell resumes
# from the current weights; it does not start from scratch.
num_epochs = 5
log_every = 500
total_steps = len(train_loader)  # number of batches per epoch
on_cuda = torch.device(device).type == "cuda"

model.train()
for epoch in range(num_epochs):
    start_time = time.time()  # track epoch start time
    # Sum of per-image losses, kept on the device so we avoid a GPU->CPU sync
    # (loss.item()) on every step. Each batch's mean loss is weighted by its size,
    # so the smaller final batch does not skew the epoch average.
    epoch_loss_sum = torch.zeros((), dtype=torch.float64, device=device)
    num_images = 0

    for i, (images, _) in enumerate(train_loader):
        batch_start_time = time.time()  # excludes time spent waiting on the DataLoader

        images = images.to(device)
        batch_size = images.shape[0]

        # Sample noise and per-image timesteps, then create the noisy images.
        noise = torch.randn_like(images)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (batch_size,), device=device
        )
        noisy_images = scheduler.add_noise(images, noise, timesteps)

        # Predict the noise and compute the MSE loss.
        noise_pred = model(noisy_images, timesteps).sample
        loss = nn.functional.mse_loss(noise_pred, noise)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss_sum += loss.detach() * batch_size
        num_images += batch_size

        # Print progress every `log_every` steps and on the last step of the epoch.
        if i % log_every == 0 or i == total_steps - 1:
            if on_cuda:
                # CUDA kernels run asynchronously; wait for this step's backward and
                # optimizer kernels so the measured time covers the whole step.
                torch.cuda.synchronize()
            batch_time = time.time() - batch_start_time
            print(f"Epoch {epoch+1}/{num_epochs}, Step {i}/{total_steps}, Loss: {loss.item():.4f}, Batch Time: {batch_time:.2f} sec")

    # End-of-epoch summary. `.item()` synchronizes with the device, so compute it
    # before reading the clock; max(..., 1) guards against an empty loader.
    avg_loss = (epoch_loss_sum / max(num_images, 1)).item()
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1} completed in {epoch_time:.2f} sec - Average Loss: {avg_loss:.4f}\n")


Epoch 1/5, Step 0/938, Loss: 0.0348, Batch Time: 0.71 sec


KeyboardInterrupt: 