<a href="https://colab.research.google.com/github/verammaz/KMeans-VAE/blob/main/run_in_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE Training Pipeline on Google Colab

This notebook provides a complete workflow for:
1. Setting up the environment
2. Generating synthetic datasets (Gaussian & Bernoulli)
3. Training a Variational Autoencoder (VAE)
4. Loading the trained model and downloading the results
5. Comparing different beta values

## Setup & Installation

In [None]:
from google.colab import auth
auth.authenticate_user()

In [None]:
# Install required packages
! pip install -q torch tqdm matplotlib wandb

print("Packages installed successfully!")

In [None]:
# Check available device
import torch
import sys

print(f"Python version: {sys.version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

## Clone GitHub Repo


In [None]:
! git clone https://github.com/verammaz/KMeans-VAE.git
%cd KMeans-VAE

## W&B Setup

To log experiments to Weights & Biases, keep `USE_WANDB = True` below; set it to `False` to skip logging:

In [40]:
USE_WANDB = True 

if USE_WANDB:
    import wandb
    wandb.login()
    print("W&B configured!")
else:
    print("W&B logging disabled. Set USE_WANDB=True to enable.")

W&B configured!


## Step 1: Generate Synthetic Datasets

Generate both Gaussian and Bernoulli mixture datasets.
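
For orientation, here is a minimal sketch of what sampling from the two mixture types can look like. It is not the code in `data/make_datasets.py`: the function name `sample_mixtures`, the scale of the centers, and the range of the Bernoulli probabilities are assumptions chosen only for illustration.

```python
import numpy as np

def sample_mixtures(k=5, dims=64, n_per=4, seed=42):
    """Draw n_per points per component from a Gaussian and a Bernoulli mixture.

    Hypothetical helper: component parameters are drawn at random here,
    which may differ from how the repository's script picks them.
    """
    rng = np.random.default_rng(seed)
    y = np.repeat(np.arange(k), n_per)               # component label per sample

    # Gaussian mixture: one center per component, unit-variance noise around it
    centers = rng.normal(0.0, 3.0, size=(k, dims))
    X_gauss = centers[y] + rng.normal(0.0, 1.0, size=(k * n_per, dims))

    # Bernoulli mixture: one vector of per-dimension probabilities per component
    probs = rng.uniform(0.05, 0.95, size=(k, dims))
    X_bern = (rng.random((k * n_per, dims)) < probs[y]).astype(np.float32)

    return X_gauss.astype(np.float32), X_bern, y

X_gauss, X_bern, y = sample_mixtures()
print(X_gauss.shape, X_bern.shape, np.bincount(y))
```

Each sample carries the index of the component it was drawn from, so the labels can later be compared with clusters found in the latent space.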

In [41]:
# Specify parameters for data generation
DATA_DIR = './datasets'
k = 5
dims = 64

In [42]:
# Generate datasets
! python data/make_datasets.py \
    --k {k} \
    --dims {dims} \
    --target-mb 256 \
    --seed 42 \
    --outroot {DATA_DIR}

Per-dataset per-component counts: [102475, 102475, 102475, 102475, 102475] (dims=64)
Wrote ./datasets/gaussian_raw  (~127.05 MB)
Wrote ./datasets/bernoulli_raw   (~127.05 MB)
Total on disk ≈ 254.10 MB (target 256.00 MB)
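
The reported size can be sanity-checked with a little arithmetic. The sketch below assumes float32 features, 4-byte labels, and "MB" meaning 2**20 bytes. None of these are confirmed by the script's output, but under those assumptions the estimate lands close to the reported ~127.05 MB per dataset.

```python
# Back-of-envelope check of the per-dataset size.
# Assumptions (not confirmed by the script): float32 features, 4-byte labels,
# and "MB" meaning 2**20 bytes. Small files such as metadata are ignored.
n_classes, n_dims, n_per_class = 5, 64, 102475   # values from the output above

n_samples = n_classes * n_per_class
bytes_per_sample = n_dims * 4 + 4                # features plus one label
size_mb = n_samples * bytes_per_sample / 2**20

print(f"{n_samples} samples, ~{size_mb:.2f} MB per dataset")
```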


In [43]:
# Verify data generation
import os
import json

for dataset in ['gaussian_raw', 'bernoulli_raw']:
    path = f'{DATA_DIR}/{dataset}'
    if os.path.exists(path):
        with open(os.path.join(path, 'metadata.json')) as f:
            meta = json.load(f)
        print(f"\n{dataset}:")
        print(f"  Type: {meta['type']}")
        print(f"  Classes: {meta['k']}")
        print(f"  Dimensions: {meta['dims']}")
        print(f"  Samples per class: {meta['n_per']}")


gaussian_raw:
  Type: gaussian
  Classes: 5
  Dimensions: 64
  Samples per class: [102475, 102475, 102475, 102475, 102475]

bernoulli_raw:
  Type: bernoulli
  Classes: 5
  Dimensions: 64
  Samples per class: [102475, 102475, 102475, 102475, 102475]


## Step 2: Train VAE

Choose your configuration and train the model.

### Configuration Options

In [44]:
# Training configuration
CONFIG = {
    # Dataset
    'dataset': 'gaussian',  # 'gaussian' or 'bernoulli'

    # Model 
    'model_name': None,   # Auto-generated if None
    'latent_dim': 10,
    'hidden_dims': [128, 64],
    'kl_beta': 1.0,  # 1.0 = standard VAE, >1.0 = beta-VAE
    'activation': 'LeakyReLU',

    # Training
    'epochs': 1,
    'batch_size': 128,
    'lr': 3e-4,
    'optimizer': 'adam',

    # System
    'seed': 3407,
    'device': 'auto',  # 'auto', 'cuda', 'cpu'

    # W&B (if enabled)
    'use_wandb': USE_WANDB,
    'wandb_project': 'vae-colab-experiments',
    'wandb_name': None,  # Auto-generated if None
}
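
The `kl_beta` setting weights the KL term of the VAE objective. As a minimal sketch (plain Python, not the repository's loss code), the per-sample loss for a diagonal Gaussian posterior can be written as reconstruction loss plus `kl_beta` times the closed-form KL divergence to a standard normal prior. The function names are hypothetical.

```python
import math

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for a single sample."""
    return 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mu, logvar))

def beta_vae_loss(recon_loss, mu, logvar, kl_beta=1.0):
    """Reconstruction loss plus the beta-weighted KL term."""
    return recon_loss + kl_beta * kl_to_standard_normal(mu, logvar)

# Same encoder output, two different weights on the KL term
mu, logvar = [0.5, -0.5], [0.0, 0.0]
for kl_beta in (1.0, 4.0):
    print(kl_beta, beta_vae_loss(2.0, mu, logvar, kl_beta=kl_beta))
```

With `kl_beta=1.0` this is the standard negative ELBO; larger values penalize deviation from the prior more strongly, usually at the cost of reconstruction quality.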

### Run Training

In [46]:
# Build command line arguments
data_dir = f"{DATA_DIR.replace('./', '')}/{CONFIG['dataset']}_raw"
hidden_dims_str = f"'{str(CONFIG['hidden_dims']).replace(' ', '')}'"

cmd = f"""
python -m vae.main \
    --data.data_dir={data_dir} \
    --model.latent_dim={CONFIG['latent_dim']} \
    --model.hidden_dims={hidden_dims_str} \
    --model.kl_beta={CONFIG['kl_beta']} \
    --model.activation={CONFIG['activation']} \
    --trainer.epochs={CONFIG['epochs']} \
    --trainer.batch_size={CONFIG['batch_size']} \
    --trainer.lr={CONFIG['lr']} \
    --trainer.optimizer={CONFIG['optimizer']} \
    --trainer.device={CONFIG['device']} \
    --system.seed={CONFIG['seed']} \
"""

# Add W&B flags if enabled
if CONFIG['use_wandb']:
    cmd += f" \
    --wandb.enabled=True \
    --wandb.project={CONFIG['wandb_project']}"
    if CONFIG['wandb_name']:
        cmd += f" \
    --wandb.name={CONFIG['wandb_name']}"

print("Training command:")
print(cmd)
print("\n" + "="*60)
print("Starting training...")
print("="*60 + "\n")

# Run training
!{cmd}

Training command:

python -m vae.main     --data.data_dir=datasets/gaussian_raw     --model.name=None
    --model.latent_dim=10     --model.hidden_dims='[128,64]'     --model.kl_beta=1.0     --model.activation=LeakyReLU     --trainer.epochs=1     --trainer.batch_size=128     --trainer.lr=0.0003     --trainer.optimizer=adam     --trainer.device=auto     --system.seed=3407      --wandb.enabled=True     --wandb.project=vae-colab-experiments

Starting training...

command line overwriting config attribute data.data_dir with datasets/gaussian_raw
command line overwriting config attribute model.name with None
Configuration:
system:
    seed: 3407
    out_dir: ./out
wandb:
    enabled: False
    project: kmeans-vae
    entity: None
    name: None
    tags: []
    notes: 
    log_freq: 10
data:
    data_dir: datasets/gaussian_raw
model:
    name: None
    latent_dim: 10
    hidden_dims: [128, 64]
    likelihood: gaussian
    kl_beta: 1.0
    seed: 42
    activation: LeakyReLU
trainer:
    devi
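
Building the command by string concatenation makes quoting easy to get wrong, for example around the brackets in `hidden_dims`. One possible alternative, sketched below, collects the overrides in a dict and lets `shlex.join` handle the quoting. The helper `build_train_command` is hypothetical; the flag names are the ones used in the cell above.

```python
import shlex

def build_train_command(overrides):
    """Turn {'section.key': value} into an argv list for `python -m vae.main`."""
    argv = ["python", "-m", "vae.main"]
    argv += [f"--{key}={value}" for key, value in overrides.items()]
    return argv

argv = build_train_command({
    "data.data_dir": "datasets/gaussian_raw",
    "model.latent_dim": 10,
    "model.hidden_dims": "[128,64]",
    "model.kl_beta": 1.0,
    "trainer.epochs": 1,
})

# shlex.join quotes any argument containing shell-special characters
cmd_line = shlex.join(argv)
print(cmd_line)
```

The resulting string can be passed to `!{cmd_line}` in a notebook cell, or the `argv` list can be handed directly to `subprocess.run`.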

In [37]:
if CONFIG['model_name'] is None:
    dataset = 'gaus' if CONFIG['dataset'] == 'gaussian' else 'ber'
    MODEL_NAME = f"vae_{dataset}_i{dims}_k{k}_z{CONFIG['latent_dim']}_beta{CONFIG['kl_beta']}"
else:
    MODEL_NAME = CONFIG['model_name']

if USE_WANDB:
    RUN_NAME = MODEL_NAME if CONFIG['wandb_name'] is None else CONFIG['wandb_name']


print("Model name: ", MODEL_NAME)

Model name:  vae_gaus_i64_k5_z10_beta1.0


## Step 3: Analyze Results

Load the trained model.

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from vae.model import VAE
from data.data_io import load_and_split

# Load trained model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint = torch.load(os.path.join('./out', f'{MODEL_NAME}/model.pt'), map_location=device)
config = checkpoint['config']

print("Model configuration:")
print(f"  Latent dim: {config['model']['latent_dim']}")
print(f"  Hidden dims: {config['model']['hidden_dims']}")
print(f"  Beta: {config['model']['kl_beta']}")
print(f"\nTest statistics:")
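# Avoid k as the loop variable: it would overwrite the global k (number of classes)
for stat_name, value in checkpoint['test_stats'].items():
    print(f"  {stat_name}: {value:.4f}")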

In [None]:
# Recreate model
model_config = config['model']
input_dim = checkpoint['model_state_dict']['mean.weight'].shape[1]

model = VAE(
    input_dim=input_dim,
    latent_dim=model_config['latent_dim'],
    hidden_dims=model_config['hidden_dims'],
    likelihood=model_config['likelihood'],
    beta=model_config['kl_beta'],
    activation=model_config.get('activation', 'LeakyReLU')
)

model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

print("Model loaded successfully!")

In [None]:
# Load test data
data_dir = config['data']['data_dir']
data = load_and_split(data_dir, normalize=True)

X_test = torch.tensor(data['X_test'], dtype=torch.float32).to(device)
y_test = torch.tensor(data['y_test'], dtype=torch.long)

print(f"Test set: {X_test.shape[0]} samples, {X_test.shape[1]} features")

## Download Results

### To Local Machine

In [None]:
# Zip the output directory
!zip -r vae_results.zip out/

print("Results zipped!")
print("Download 'vae_results.zip' from the file browser on the left.")

# Zip the data directory
!zip -r data.zip {DATA_DIR}

print("Data zipped!")
print("Download 'data.zip' from the file browser on the left.")

### To Google Drive

In [None]:
# Mount Google Drive, then copy the results and datasets into it
from google.colab import drive
drive.mount('/content/drive')
!cp -r out/ /content/drive/MyDrive/vae/

!cp -r {DATA_DIR} /content/drive/MyDrive/data/

### From W&B

In [39]:
api = wandb.Api()

# Replace these:
entity = "vmm2146-columbia-university"      # your wandb username or team
project = CONFIG['wandb_project']           # your wandb project name
run_id = "f10w6e5f"                         # your run id

run = api.run(f"{entity}/{project}/runs/{run_id}")

# List files stored in this run
for f in run.files():
    print(f.name)

# Download a specific file (e.g., model checkpoint)
run.file(f"out/{MODEL_NAME}/model.pt").download(root=".", replace=True)

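# List and download all dataset files.
# Run file names are relative paths without a leading './', so strip it from
# DATA_DIR; adjust the prefix if the names printed by the loop above use a
# different folder.
data_prefix = DATA_DIR.removeprefix('./')
for f in run.files():
    if f.name.startswith(data_prefix):
        print(f"Downloading {f.name} ...")
        f.download(root=".", replace=True)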


artifact/2148009639/wandb_manifest.json
config.yaml
data/data_set/gaussian_raw/centers.npy
data/data_set/gaussian_raw/gauss_c0_X.npy
data/data_set/gaussian_raw/gauss_c0_y.npy
data/data_set/gaussian_raw/gauss_c1_X.npy
data/data_set/gaussian_raw/gauss_c1_y.npy
data/data_set/gaussian_raw/gauss_c2_X.npy
data/data_set/gaussian_raw/gauss_c2_y.npy
data/data_set/gaussian_raw/gauss_c3_X.npy
data/data_set/gaussian_raw/gauss_c3_y.npy
data/data_set/gaussian_raw/gauss_c4_X.npy
data/data_set/gaussian_raw/gauss_c4_y.npy
data/data_set/gaussian_raw/metadata.json
data/data_set/gaussian_raw/splits.json
data/data_set/gaussian_raw/variances.npy
out/vae_gaus_i64_k5_z10_beta1.0/model.pt
output.log
requirements.txt
wandb-metadata.json
wandb-summary.json


<_io.TextIOWrapper name='./out/vae_gaus_i64_k5_z10_beta1.0/model.pt' mode='r' encoding='UTF-8'>
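
The file names in the listing are plain relative paths, so a prefix that starts with `./` matches none of them. The sketch below shows one way to normalize a prefix before filtering. `matching_files` is a hypothetical helper, and the names are copied from the listing above.

```python
def matching_files(names, prefix):
    """Return the names that live under `prefix`, ignoring a leading './'."""
    normalized = prefix.removeprefix("./").rstrip("/") + "/"
    return [name for name in names if name.startswith(normalized)]

names = [
    "config.yaml",
    "data/data_set/gaussian_raw/centers.npy",
    "data/data_set/gaussian_raw/metadata.json",
    "out/vae_gaus_i64_k5_z10_beta1.0/model.pt",
]

print(matching_files(names, "./data/data_set"))   # the two dataset files
print(matching_files(names, "./datasets"))        # nothing stored under this folder
```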

## Experiment: Compare Different Beta Values

Run a quick sweep to see the effect of different beta values.

In [None]:
# Sweep different beta values
import os

beta_values = [0.5, 1.0, 2.0, 4.0]
# Use the same entry point and dataset location as the training cell above
sweep_data_dir = f"{DATA_DIR.replace('./', '')}/gaussian_raw"

for beta in beta_values:
    print(f"\n{'='*60}")
    print(f"Training with beta = {beta}")
    print(f"{'='*60}\n")

    cmd = f"""python -m vae.main \\
        --data.data_dir={sweep_data_dir} \\
        --model.kl_beta={beta} \\
        --model.latent_dim=10 \\
        --trainer.epochs=30 \\
        --trainer.batch_size=128 \\
        --system.out_dir=./out/vae_beta_{beta}"""

    if USE_WANDB:
        cmd += f" \\\n        --wandb.enabled=True \\\n        --wandb.name=beta_{beta}"

    # Join the continuation lines into a single command and run it
    os.system(cmd.replace('\\\n', ' '))

print("\nBeta sweep complete! Check out/vae_beta_* directories for results.")
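
`os.system` returns an exit status, but the loop above does not look at it, so a failed run passes silently. A minimal sketch of a sweep that records exit codes with `subprocess.run` follows. The helper `run_sweep` is hypothetical, and the commands are harmless stand-ins (a Python one-liner per beta value) so the example runs without the repository.

```python
import subprocess
import sys

def run_sweep(commands):
    """Run each argv list in turn and return {label: exit code}."""
    results = {}
    for label, argv in commands.items():
        completed = subprocess.run(argv)
        results[label] = completed.returncode
    return results

# Stand-in commands; in the notebook each argv would be the training command
commands = {
    f"beta_{beta}": [sys.executable, "-c", f"print('would train with kl_beta={beta}')"]
    for beta in (0.5, 1.0)
}

codes = run_sweep(commands)
failed = [label for label, code in codes.items() if code != 0]
print("failed runs:", failed)
```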