In [23]:
import os
import yaml
import argparse
import numpy as np
from pathlib import Path
from models import *
from experiment import VAEXperiment
import torch.backends.cudnn as cudnn
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from dataset import VAEDataset
from pytorch_lightning.plugins import DDPPlugin
from torch.utils.data import DataLoader



# Model setup

In [24]:
print(torch.cuda.device_count())       # Number of GPUs available
print(torch.cuda.get_device_name(0))   # Name of GPU at index 0

1
NVIDIA GeForce RTX 4060 Ti


In [25]:
config = {
    'model_params': {'name': 'VanillaVAE',
                     'in_channels': 1, # for MNIST
                     #'in_channels': 3,#for CelebA
                     'latent_dim': 128},
    'data_params': {
        'data_path': "Data/",
        'train_batch_size': 64,
        'val_batch_size': 64,
        'num_workers': 4,        
    },
    'exp_params': {
        'LR': 0.005,
        'weight_decay': 0.0,
        'scheduler_gamma': 0.95,
        'kld_weight': 0.00025,
        'manual_seed': 1265
    },
    'trainer_params': {'max_epochs': 100},
    'logging_params': {'save_dir': 'logs/', 'name': 'VanillaVAE'}
}

In [26]:
# 2. Initialize logger
logger = TensorBoardLogger(save_dir=config['logging_params']['save_dir'],
                           name=config['model_params']['name'])

# 3. Set random seed for reproducibility
seed_everything(config['exp_params']['manual_seed'], True)


Global seed set to 1265


1265

## Loading data 
MNIST data is 1 channel 28\*28 pixels and celeA is 3 channels 64\*64 pixels, so the loading class is different


### run this cell when training MNIST

In [32]:
#for loading MNIST data
from dataset import MNISTDataModule

resize = False
input_res = 28
hidden_dims = [32, 64, 128, 256]
if resize:
    input_res = 64
    hidden_dims = None
data = MNISTDataModule(
    data_dir=config['data_params']['data_path'],
    train_batch_size=config['data_params']['train_batch_size'],
    val_batch_size=config['data_params']['val_batch_size'],
    num_workers=config['data_params']['num_workers'],
    pin_memory=torch.cuda.is_available(),
    resize_mnist = resize
)
data.setup()

# 4. Instantiate model using name key from config
model = VanillaVAE(in_channels=config['model_params']['in_channels'],latent_dim=config['model_params']['latent_dim'],
                   input_height=input_res,input_width=input_res,hidden_dims=hidden_dims)


### run this cell when training on CelebA
for celeA data, torch vision can not download it directly, we can download the data manually from: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
download alig & cropped Images.zip

into working directory: VAE_ModelCollapse/Data/celeba
the orgnization of the data directory will look like:

VAE_ModelCollapse/Data/celeba/img_align_celeba (all the face images are here)
VAE_ModelCollapse/Data/celeba/identity_CelebA.txt
VAE_ModelCollapse/Data/celeba/list_attr_celeba.txt
VAE_ModelCollapse/Data/celeba/list_attr_celeba.txt 
VAE_ModelCollapse/Data/celeba/list_eval_partition.txt          
VAE_ModelCollapse/Data/celeba/list_landmarks_celeba.txt
VAE_ModelCollapse/Data/celeba/list_bbox_celeba.txt
VAE_ModelCollapse/Data/celeba/list_landmarks_align_celeba.txt



In [30]:
#for loading celebA data

input_res = 64
hidden_dims = None
data = VAEDataset(
    data_path=config['data_params']['data_path'],
    train_batch_size=config['data_params']['train_batch_size'],
    val_batch_size=config['data_params']['val_batch_size'],
    patch_size=(input_res, input_res),
    num_workers=config['data_params']['num_workers'],
    pin_memory=torch.cuda.is_available()
)
data.setup()
# 4. Instantiate model using name key from config
model = VanillaVAE(in_channels=config['model_params']['in_channels'],latent_dim=config['model_params']['latent_dim'],
                   input_height=input_res,input_width=input_res,hidden_dims=hidden_dims)


In [33]:
# 5. Wrap in LightningModule for training logic
experiment = VAEXperiment(model, config['exp_params'])
# 6. Get a batch from the training dataloader
batch = next(iter(data.train_dataloader()))
images, labels = batch  # MNIST returns (images, labels)
print("Batch image shape:", images.shape)

Batch image shape: torch.Size([64, 1, 28, 28])


In [34]:
# 7. Define trainer and callbacks
trainer = Trainer(
    logger=logger,
    callbacks=[
        LearningRateMonitor(),
        ModelCheckpoint(save_top_k=2, 
                        dirpath=os.path.join(logger.log_dir, "checkpoints"), 
                        monitor="val_loss",
                        save_last=True)
    ],
    #strategy=DDPPlugin(find_unused_parameters=False),
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    # **config['trainer_params']
    max_epochs=config['trainer_params']['max_epochs'],
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [38]:
# 8. Create folders for outputs
Path(f"{logger.log_dir}/Samples").mkdir(exist_ok=True, parents=True)
Path(f"{logger.log_dir}/Reconstructions").mkdir(exist_ok=True, parents=True)
logger.log_dir

'logs/VanillaVAE/version_13'

# Training

In [37]:
# 9. Train the model
trainer.fit(experiment, datamodule=data)

# Generate Samples

In [20]:

model.eval()
row_num = 5
device = next(model.parameters()).device
with torch.no_grad():
    samples = model.sample(row_num*row_num,current_device=device)  # Generate 16 samples (returns tensor of shape [16, 1, 64, 64])
#imgs = (samples+1)/2

In [21]:
imgs.shape

torch.Size([25, 3, 64, 64])

In [42]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(row_num, row_num, figsize=(6, 6))
for i, ax in enumerate(axes.flat):
    # for celebA
    #ax.imshow(imgs[i].cpu().numpy().transpose(1, 2, 0), cmap='gray')    
    ax.imshow(imgs[i].cpu().numpy().squeeze(), cmap='gray')
    ax.axis('off')
plt.tight_layout()
plt.show()
