# Run VAE experiments

In [5]:
import sys, subprocess, platform
print('Python executable:', sys.executable)
print('Python version:', platform.python_version())
print('Verifying torch in current kernel...')
subprocess.run([sys.executable, '-m', 'pip', 'show', 'torch'])

Python executable: /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python
Python version: 3.11.7
Verifying torch in current kernel...


CompletedProcess(args=['/Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python', '-m', 'pip', 'show', 'torch'], returncode=-9)

In [8]:
# Common arguments
DATASET = 'shapes'
BETA = 6
TCVAE = True
NUM_EPOCHS = 3
LOG_FREQ = 7
MWS_BATCH_SIZES = [32, 64, 256, 1024, 2048]

# W&B
USE_WANDB = True
WANDB_PROJECT = 'beta-tcvae'
WANDB_ENTITY = None
WANDB_MODE = 'online'

In [9]:
# single experiment
import sys, subprocess
bs = 64
cmd = [sys.executable, 'vae_quant.py',
       '--dataset', DATASET,
       '--beta', str(BETA),
       '--mws-batch-size', str(bs),
       '--num-epochs', str(NUM_EPOCHS),
       '--log_freq', str(LOG_FREQ)]
if TCVAE:
    cmd.append('--tcvae')
print('Running:', ' '.join(cmd))
subprocess.run(cmd, check=False)

Running: /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python vae_quant.py --dataset shapes --beta 6 --mws-batch-size 64 --num-epochs 3 --log_freq 7 --tcvae


CompletedProcess(args=['/Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python', 'vae_quant.py', '--dataset', 'shapes', '--beta', '6', '--mws-batch-size', '64', '--num-epochs', '3', '--log_freq', '7', '--tcvae'], returncode=-9)

In [10]:
# Sweep over batch sizes and save logs
import os, sys, subprocess
os.makedirs('runs', exist_ok=True)

for bs in MWS_BATCH_SIZES:
    cmd = [sys.executable, 'vae_quant.py',
           '--dataset', DATASET,
           '--beta', str(BETA),
           '--mws-batch-size', str(bs),
           '--num-epochs', str(NUM_EPOCHS),
           '--log_freq', str(LOG_FREQ)]
    if TCVAE:
        cmd.append('--tcvae')
    if USE_WANDB:
        cmd += ['--wandb', '--wandb_project', WANDB_PROJECT, '--wandb_mode', WANDB_MODE]
        if WANDB_ENTITY:
            cmd += ['--wandb_entity', WANDB_ENTITY]
        cmd += ['--wandb_run_name', f'bs{bs}']
    print('\n=== Running batch-size', bs, '===\n', ' '.join(cmd))
    log_path = f'runs/bs{bs}.log'
    with open(log_path, 'w') as f:
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        for line in proc.stdout:
            print(line, end='')
            f.write(line)
        proc.wait()
    print(f'[exit code {proc.returncode}] log saved to {log_path}')


=== Running batch-size 32 ===
 /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python vae_quant.py --dataset shapes --beta 6 --mws-batch-size 32 --num-epochs 3 --log_freq 7 --tcvae --wandb --wandb_project beta-tcvae --wandb_mode online --wandb_run_name bs32
[exit code -9] log saved to runs/bs32.log

=== Running batch-size 64 ===
 /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python vae_quant.py --dataset shapes --beta 6 --mws-batch-size 64 --num-epochs 3 --log_freq 7 --tcvae --wandb --wandb_project beta-tcvae --wandb_mode online --wandb_run_name bs64
[exit code -9] log saved to runs/bs64.log

=== Running batch-size 256 ===
 /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python vae_quant.py --dataset shapes --beta 6 --mws-batch-size 256 --num-epochs 3 --log_freq 7 --tcvae --wandb --wandb_project beta-tcvae --wandb_mode online --wandb_run_name bs256
[exit code -9] log saved to runs/bs256.log

=== Running batch-size 1024 ===
 /Users/zoe/Desktop/VAEs/beta-tcvae/.venv/bin/python vae_quant.py --d