# Run VAE experiments

In [None]:
# Environment sanity check: show which interpreter this kernel runs and ask
# that same interpreter's pip about torch (the pip return code is not checked).
import platform
import subprocess
import sys

for label, value in (('Python executable:', sys.executable),
                     ('Python version:', platform.python_version())):
    print(label, value)
print('Verifying torch in current kernel...')
pip_show_torch = [sys.executable, '-m', 'pip', 'show', 'torch']
subprocess.run(pip_show_torch)

In [None]:
# Common arguments forwarded to vae_quant.py by the experiment cells below.
DATASET = 'shapes'                 # --dataset
BETA = 6                           # --beta (stringified when the command is built)
TCVAE = True                       # when True, the --tcvae flag is appended
NUM_EPOCHS = 1                     # --num-epochs
LOG_FREQ = 50                      # --log_freq
BATCH_SIZES = [32, 64, 128, 256]   # the sweep cell launches one run per value

# Weights & Biases settings (only forwarded when USE_WANDB is True).
USE_WANDB = True                   # adds --wandb plus the options below
WANDB_PROJECT = 'beta-tcvae'       # --wandb_project
WANDB_ENTITY = None                # --wandb_entity is passed only if this is truthy
WANDB_MODE = 'online'              # --wandb_mode

In [None]:
# Single experiment: one vae_quant.py run at a fixed batch size, built from
# the common arguments and W&B settings defined above.
import subprocess
import sys

bs = 64
cmd = [sys.executable, 'vae_quant.py',
       '--dataset', DATASET,
       '--beta', str(BETA),
       '--batch-size', str(bs),
       '--num-epochs', str(NUM_EPOCHS),
       '--log_freq', str(LOG_FREQ)]
if TCVAE:
    cmd.append('--tcvae')
# Honor the W&B settings exactly as the batch-size sweep does, so USE_WANDB
# means the same thing for every experiment cell.
if USE_WANDB:
    cmd += ['--wandb', '--wandb_project', WANDB_PROJECT, '--wandb_mode', WANDB_MODE]
    if WANDB_ENTITY:
        cmd += ['--wandb_entity', WANDB_ENTITY]
print('Running:', ' '.join(cmd))
# check=False keeps a failed run from raising in the notebook, but the exit
# code is still reported rather than silently discarded.
result = subprocess.run(cmd, check=False)
print(f'[exit code {result.returncode}]')

In [None]:
# Sweep over batch sizes and save logs: run vae_quant.py once per entry in
# BATCH_SIZES, streaming each run's combined stdout/stderr to the notebook
# and to runs/bs<batch size>_<timestamp>.log.
import datetime
import os
import subprocess
import sys

os.makedirs('runs', exist_ok=True)
ts = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

# A Python child writing to a pipe block-buffers its stdout, so without this
# the "live" output below would arrive in large, delayed chunks.
child_env = {**os.environ, 'PYTHONUNBUFFERED': '1'}
exit_codes = {}

for bs in BATCH_SIZES:
    cmd = [sys.executable, 'vae_quant.py',
           '--dataset', DATASET,
           '--beta', str(BETA),
           '--batch-size', str(bs),
           '--num-epochs', str(NUM_EPOCHS),
           '--log_freq', str(LOG_FREQ)]
    if TCVAE:
        cmd.append('--tcvae')
    if USE_WANDB:
        cmd += ['--wandb', '--wandb_project', WANDB_PROJECT, '--wandb_mode', WANDB_MODE]
        if WANDB_ENTITY:
            cmd += ['--wandb_entity', WANDB_ENTITY]
        cmd += ['--wandb_run_name', f'bs{bs}_{ts}']
    print('\n=== Running batch-size', bs, '===\n', ' '.join(cmd))
    log_path = f'runs/bs{bs}_{ts}.log'
    with open(log_path, 'w', encoding='utf-8') as f:
        # Decode explicitly as UTF-8 and replace undecodable bytes so stray
        # output cannot abort the tee loop with a UnicodeDecodeError.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                              text=True, encoding='utf-8', errors='replace',
                              env=child_env) as proc:
            try:
                for line in proc.stdout:
                    print(line, end='')
                    f.write(line)
            except BaseException:
                # E.g. a kernel interrupt: stop the training run instead of
                # leaving it running orphaned in the background.
                proc.kill()
                proc.wait()
                raise
            proc.wait()
    exit_codes[bs] = proc.returncode
    print(f'[exit code {proc.returncode}] log saved to {log_path}')

failed = [bs for bs, code in exit_codes.items() if code != 0]
if failed:
    print('Failed batch sizes:', failed)