# Variational Autoencoder (VAE) for Image Generation

**Author:** Molla Samser  
**Website:** https://rskworld.in  
**Email:** help@rskworld.in, support@rskworld.in  
**Phone:** +91 93305 39277  
**Designer & Tester:** Rima Khatun

This notebook demonstrates how to train and use a Variational Autoencoder for image generation.


## 1. Import Libraries


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
import os
from vae_model import VAE, vae_loss
from utils import save_reconstructions, visualize_latent_space

# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')


## 2. Load and Prepare Data


In [None]:
# Preprocessing applied to every image, in order: resize to 64x64, convert
# to a tensor, then normalise the single (grayscale) channel with mean 0.5
# and std 0.5.
# For an RGB dataset such as CIFAR10, replace the last step with
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).
preprocessing_steps = [
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST train and test splits, stored under ./data (downloaded if missing).
train_dataset = datasets.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST('./data', train=False, transform=transform, download=True)

# Batches of 128; only the training split is shuffled.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=128)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=128)

# Report the size of each split.
print(f'Training samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')


## 3. Initialize Model


In [None]:
# Settings handed to the VAE constructor.
input_channels = 1  # 1 for MNIST, 3 for CIFAR10
latent_dim = 128

# Build the network and move it to the selected device.
model = VAE(latent_dim=latent_dim, input_channels=input_channels).to(device)

# Report the model size (element count summed over every parameter tensor)
# followed by the module's printed representation.
total_params = sum(map(torch.numel, model.parameters()))
print(f'Total parameters: {total_params:,}')
print(f'Model architecture:')
print(model)


## 4. Training Function


In [None]:
def train_epoch(model, dataloader, optimizer, device, beta=1.0):
    """
    Run one optimisation pass over ``dataloader``.

    Every batch is moved to ``device``, passed through ``model`` and scored
    with ``vae_loss``; the first value returned by ``vae_loss`` is
    back-propagated and ``optimizer`` takes one step per batch.

    Args:
        model: Module whose forward call returns a 4-tuple; the first three
            items (reconstruction, mu, logvar) are handed to ``vae_loss``
            and the fourth is ignored here.
        dataloader: Iterable of ``(inputs, labels)`` pairs that also exposes
            a sized ``dataset`` attribute. The labels are not used.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device the input batches are moved to.
        beta: Forwarded unchanged to ``vae_loss`` as its last argument.

    Returns:
        ``(avg_loss, avg_recon, avg_kl)``: the per-batch values of the three
        ``vae_loss`` outputs summed over the epoch, each divided by
        ``len(dataloader.dataset)``.

    Note:
        These are per-sample averages only if ``vae_loss`` sums over the
        batch rather than averaging -- TODO confirm in ``vae_model``.

    Credits: Molla Samser (https://rskworld.in); designer & tester:
    Rima Khatun.
    """
    model.train()

    # Running sums of the scalar loss terms across all batches.
    loss_sum = 0
    recon_sum = 0
    kl_sum = 0

    for inputs, _labels in dataloader:
        inputs = inputs.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass; the fourth output is not needed for the loss.
        recon_batch, mu, logvar, _latent = model(inputs)
        batch_loss, batch_recon, batch_kl = vae_loss(recon_batch, inputs, mu, logvar, beta)

        # Backward pass and parameter update.
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        recon_sum += batch_recon.item()
        kl_sum += batch_kl.item()

    # Normalise by the dataset size, not by the number of batches.
    sample_count = len(dataloader.dataset)
    return loss_sum / sample_count, recon_sum / sample_count, kl_sum / sample_count


## 5. Train the Model


In [None]:
# Hyperparameters for this run
num_epochs = 50
learning_rate = 1e-3
beta = 1.0  # KL divergence weight, handed to train_epoch

# Adam over every parameter of the model
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Per-epoch history of the three values returned by train_epoch
train_losses, train_recon_losses, train_kl_losses = [], [], []

# Sample and reconstruction grids are written into this directory
os.makedirs('outputs/samples', exist_ok=True)

for epoch in range(num_epochs):
    print(f'\nEpoch {epoch+1}/{num_epochs}')
    print('-' * 50)

    # One pass over the training data; record the averaged metrics
    train_loss, train_recon, train_kl = train_epoch(model, train_loader, optimizer, device, beta)
    train_losses.append(train_loss)
    train_recon_losses.append(train_recon)
    train_kl_losses.append(train_kl)

    print(f'Train Loss: {train_loss:.4f}, Recon: {train_recon:.4f}, KL: {train_kl:.4f}')

    # Image snapshots are only written on every fifth epoch
    if (epoch + 1) % 5 != 0:
        continue

    # Eval mode for the snapshot; train_epoch re-enables train mode on the
    # next pass.
    model.eval()
    with torch.no_grad():
        # 64 generated images, saved as an 8-per-row grid
        samples = model.generate(num_samples=64, device=device)
        save_image(samples, f'outputs/samples/epoch_{epoch+1}.png', nrow=8, normalize=True)

        # First 8 images of the first test batch, saved together with their
        # reconstructions (originals first, reconstructions after)
        test_data, _ = next(iter(test_loader))
        test_data = test_data[:8].to(device)
        recon_data, _, _, _ = model(test_data)
        comparison = torch.cat([test_data, recon_data], dim=0)
        save_image(comparison, f'outputs/samples/recon_epoch_{epoch+1}.png', nrow=8, normalize=True)

    print(f'Samples saved for epoch {epoch+1}')

print('\nTraining completed!')


## 6. Visualize Training Progress


In [None]:
# One figure with three side-by-side panels, one per recorded loss history.
plt.figure(figsize=(15, 5))

panels = (
    (train_losses, 'Total Loss'),
    (train_recon_losses, 'Reconstruction Loss'),
    (train_kl_losses, 'KL Divergence Loss'),
)
for position, (history, panel_title) in enumerate(panels, start=1):
    # subplot positions are 1-based
    plt.subplot(1, 3, position)
    plt.plot(history)
    plt.title(panel_title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

# Write the figure to disk before displaying it.
plt.tight_layout()
plt.savefig('outputs/training_curves.png', dpi=150)
plt.show()


## 7. Generate New Images


In [None]:
# Ask the model for 64 new images, in eval mode and without tracking
# gradients (model.generate presumably samples random latent vectors and
# decodes them -- see vae_model).
model.eval()
with torch.no_grad():
    generated = model.generate(num_samples=64, device=device)

# Tile the batch into a grid with 8 images per row.
grid = make_grid(generated, nrow=8, normalize=True)

# Move the first axis last and bring the data to the CPU for imshow.
grid_image = grid.permute(1, 2, 0).cpu()

plt.figure(figsize=(12, 12))
plt.imshow(grid_image)
plt.axis('off')
plt.title('Generated Images from Random Latent Vectors')
plt.tight_layout()
plt.savefig('outputs/generated_images.png', dpi=150, bbox_inches='tight')
plt.show()


## 8. Image Reconstruction


In [None]:
# Take the first 16 images of the first test batch and run them through the
# model in eval mode.
model.eval()
first_batch = next(iter(test_loader))
test_data, _ = first_batch
test_data = test_data[:16].to(device)

# Only the reconstruction is displayed below; mu, logvar and z are kept as
# notebook variables.
with torch.no_grad():
    reconstructed, mu, logvar, z = model(test_data)

# Stack the originals ahead of the reconstructions and tile them 8 per row,
# so the originals fill the upper rows of the grid.
comparison = torch.cat([test_data, reconstructed], dim=0)
grid = make_grid(comparison, nrow=8, normalize=True)

# Move the first axis last and bring the data to the CPU for imshow.
grid_image = grid.permute(1, 2, 0).cpu()

plt.figure(figsize=(15, 8))
plt.imshow(grid_image)
plt.axis('off')
plt.title('Original (top) vs Reconstructed (bottom)')
plt.tight_layout()
plt.savefig('outputs/reconstruction_comparison.png', dpi=150, bbox_inches='tight')
plt.show()


## 9. Latent Space Interpolation


In [None]:
# Decode evenly spaced points on the straight line between two random
# latent vectors.
model.eval()
num_steps = 10

# Two endpoints drawn from a standard normal, then moved to the device.
z1 = torch.randn(1, latent_dim).to(device)
z2 = torch.randn(1, latent_dim).to(device)

# Blend weights running from 0 (pure z1) to 1 (pure z2).
alphas = torch.linspace(0, 1, num_steps).to(device)

# NOTE(review): assumes model.decoder accepts a (1, latent_dim) tensor
# directly -- confirm against vae_model.
with torch.no_grad():
    interpolated_images = [
        model.decoder((1 - alpha) * z1 + alpha * z2)
        for alpha in alphas
    ]

# Concatenate the decoded images and show them on a single row.
result = torch.cat(interpolated_images, dim=0)
grid = make_grid(result, nrow=num_steps, normalize=True)

# Move the first axis last and bring the data to the CPU for imshow.
grid_image = grid.permute(1, 2, 0).cpu()

plt.figure(figsize=(15, 2))
plt.imshow(grid_image)
plt.axis('off')
plt.title('Latent Space Interpolation')
plt.tight_layout()
plt.savefig('outputs/interpolation.png', dpi=150, bbox_inches='tight')
plt.show()


## 10. Save Model


In [None]:
# Persist only the weights (the state dict), not the whole module object.
trained_weights = model.state_dict()
torch.save(trained_weights, 'vae_model.pth')
print('Model saved to vae_model.pth')
