# Module 5 Class Activity — Add VAE to Helper Library

This notebook trains a Variational Autoencoder (VAE) using your `helper_lib` package and visualizes generated samples.

## 0. Environment Check
Confirm that PyTorch, torchvision, and matplotlib can be imported. If any import fails, uncomment the install line in the cell below and run it again.

In [None]:
# Uncomment the next line only if one of the imports below fails:
# %pip install torch torchvision matplotlib
import torch
import torchvision
import matplotlib

# Report the installed PyTorch version and whether a CUDA device is usable.
print('Torch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())


## 1. Imports

In [None]:
from helper_lib.data_loader import get_data_loader
from helper_lib.trainer import train_model, train_vae_model
from helper_lib.model import get_model
from helper_lib.generator import generate_samples

import torch
import torch.nn as nn
import torch.optim as optim


## 2. Data Loaders
Load MNIST (or your course default dataset) via `helper_lib.data_loader`. Change batch size as you like.

In [None]:
# Single source of truth for the batch size, so the train and test loaders
# cannot drift apart when the value is tuned.
BATCH_SIZE = 128

train_loader = get_data_loader(root='data', batch_size=BATCH_SIZE, train=True)
test_loader = get_data_loader(root='data', batch_size=BATCH_SIZE, train=False)

# Displayed as the cell output. For torch DataLoaders these are the number of
# batches per split -- NOTE(review): confirm what helper_lib returns.
len(train_loader), len(test_loader)

## 3. Build VAE Model
We construct the VAE via `get_model("VAE", latent_dim=20)`.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG before the model is built so that weight
# initialisation is repeatable across Restart & Run All.
torch.manual_seed(42)

# Move the model to `device` explicitly, and do so before the optimizer is
# created so the optimizer references the parameters that are actually trained.
# Later cells send input tensors to `device`, so the model must live there too.
# Assumes get_model returns a torch.nn.Module (it is used via .parameters(),
# .eval() and as a callable elsewhere in this notebook); Module.to() returns
# the module itself and is harmless if it is already on `device`.
vae = get_model("VAE", latent_dim=20).to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-3)
vae  # last expression: show the model's repr as the cell output


## 4. Train VAE
`train_vae_model` computes the VAE loss internally: a binary cross-entropy (BCE) reconstruction term plus a KL-divergence (KLD) term. Adjust `epochs` and `beta` if needed.

In [None]:
# Run VAE training. The result is bound to `_` so that the cell ends in an
# assignment and nothing is echoed as cell output.
_ = train_vae_model(
    vae,
    train_loader,
    optimizer,
    device=device,
    epochs=5,
    beta=1.0,
)

## 5. Generate Samples
Sample random latent vectors, decode, and visualize a grid of generated images. Also save to file if you like.

In [None]:
# Show a grid of 16 generated samples (seed fixed at 42).
generate_samples(
    vae,
    device=device,
    num_samples=16,
    seed=42,
)

# To also write the grid to disk, run this variant instead:
# generate_samples(vae, device=device, num_samples=16, seed=42, save_path='vae_samples.png')


## 6. (Optional) Reconstruction vs Original
Visualize how well the model reconstructs input images.

In [None]:
import matplotlib.pyplot as plt

# Put the model in evaluation mode before the forward pass.
vae.eval()
images, _ = next(iter(test_loader))
images = images.to(device)
with torch.no_grad():
    recon, mu, logvar = vae(images)

# Show the first few originals (top row) and their reconstructions (bottom row).
# Cap the column count at the batch size so a small batch cannot raise IndexError.
n = min(8, images.size(0))
# squeeze=False keeps `axes` two-dimensional even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3), squeeze=False)
for i in range(n):
    # Fixed vmin/vmax put both rows on one grey scale. Without them imshow
    # rescales every panel to its own min/max, which hides brightness
    # differences between an original and its reconstruction.
    # Assumes pixel values lie in [0, 1], as a BCE target requires -- confirm
    # against the data loader's transforms.
    axes[0, i].imshow(images[i, 0].cpu(), cmap='gray', vmin=0, vmax=1)
    axes[0, i].axis('off')
    # No .detach() needed: recon was produced under torch.no_grad().
    axes[1, i].imshow(recon[i, 0].cpu().clamp(0, 1), cmap='gray', vmin=0, vmax=1)
    axes[1, i].axis('off')
axes[0, 0].set_title('Originals')
axes[1, 0].set_title('Reconstructions')
plt.tight_layout()
plt.show()