# Experiments

This file presents a Jupyter notebook interface to the experiments. Each section covers one experiment: notes on the output you will see when you execute the cell, followed by a cell containing the command that runs the experiment.

## table_1_minibatch_gradient_benchmark

(**CPU friendly**) This experiment can be run on a typical laptop CPU.

In [None]:
%run ../experiments/table_1_minibatch_gradient_benchmark/genjax_vae_overhead.py


## table_2_benchmark_timings

(**GPU required**) This should be run on a GPU.

This experiment runs `pytest`, which then displays a timing table (with timing statistics) for each training experiment.

In [None]:
! just table_2

poetry run pytest experiments/table_2_benchmark_timings --benchmark-disable-gc
The currently activated Python version 3.12.2 is not supported by the project (3.10.13).
Trying to find and use a compatible version.
Using python3.10 (3.10.13)
platform darwin -- Python 3.10.13, pytest-8.0.2, pluggy-1.4.0
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=True min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /Users/mccoybecker/research/programmable-vi-pldi-2024
plugins: jaxtyping-0.2.28, anyio-4.3.0, typeguard-2.13.3, benchmark-4.0.0
collected 5 items

experiments/table_2_benchmark_timings/test_genjax_enum_air_benchmark.py 
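As a rough illustration of the kind of timing statistics such a table reports, the sketch below times a function over several rounds with `time.perf_counter` (the timer named in the benchmark header above) while garbage collection is disabled. It is not the `pytest-benchmark` implementation; `time_rounds`, `summarize`, and `toy_training_step` are hypothetical names introduced only for this example.

```python
import gc
import statistics
import time


def time_rounds(fn, rounds=5):
    """Call fn `rounds` times and return per-round wall-clock durations in seconds."""
    gc_was_enabled = gc.isenabled()
    gc.disable()  # roughly what --benchmark-disable-gc asks for
    try:
        durations = []
        for _ in range(rounds):
            start = time.perf_counter()
            fn()
            durations.append(time.perf_counter() - start)
    finally:
        # Restore the collector to its previous state.
        if gc_was_enabled:
            gc.enable()
    return durations


def summarize(durations):
    """Reduce per-round durations to a few summary statistics."""
    return {
        "min": min(durations),
        "max": max(durations),
        "mean": statistics.mean(durations),
        "stddev": statistics.stdev(durations) if len(durations) > 1 else 0.0,
        "rounds": len(durations),
    }


def toy_training_step():
    # Hypothetical stand-in for one training run (an assumption, not from the repo).
    return sum(i * i for i in range(10_000))


stats = summarize(time_rounds(toy_training_step, rounds=5))
print(stats)
```

The real benchmark additionally handles calibration and warmup, which this sketch leaves out.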

## table_4_objective_values

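For each script and objective (ELBO, IWAE, HVI-ELBO, IWHVI, DIWHVI), the output is a pair: the first number is the mean over training trials, and the second is the standard deviation. The GenJAX output labels the last two objectives `HVIWAE(N = 5, K = 1)` and `HVIWAE(N = 5, K = 5)`, which presumably correspond to IWHVI and DIWHVI. The GenJAX values are negative while the NumPyro and Pyro values are positive with similar magnitude, which is consistent with those scripts reporting a loss (the negated objective).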

In [5]:
%run experiments/table_4_objective_values/genjax_cone.py
%run experiments/table_4_objective_values/genjax_cone_marginal.py
%run experiments/table_4_objective_values/numpyro_cone.py
%run experiments/table_4_objective_values/pyro_cone.py

ELBO:
(Array(-8.0759735, dtype=float32), Array(0.8189323, dtype=float32))
IWAE(K = 5):
(Array(-7.6744304, dtype=float32), Array(2.6599298, dtype=float32))
HVI-ELBO(N = 1):
(Array(-9.751298, dtype=float32), Array(0.9588627, dtype=float32))
HVIWAE(N = 5, K = 1):
(Array(-8.182691, dtype=float32), Array(0.9135368, dtype=float32))
HVIWAE(N = 5, K = 5):
(Array(-7.2983704, dtype=float32), Array(1.648208, dtype=float32))


100%|█| 6000/6000 [00:00<00:00, 6672.05it/s, init loss: 450.7343, avg. 


NumPyro TraceGraph ELBO:
(Array(8.087665, dtype=float32), Array(0.11515263, dtype=float32))


100%|█| 6000/6000 [00:00<00:00, 8173.14it/s, init loss: 71.4061, avg. l


NumPyro RenyiELBO(k = 5):
(Array(7.8700557, dtype=float32), Array(1.9371673, dtype=float32))


Guessed max_plate_nesting = 1
Guessed max_plate_nesting = 1


Pyro ELBO:
(tensor(8.0826), tensor(0.1097))
Pyro IWAE(K = 5):
(tensor(7.8315), tensor(2.4558))
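The (mean, standard deviation) pairs above can be thought of as a reduction over one final objective value per training trial. The sketch below shows that reduction with the Python standard library; the trial values are made up, `summarize_trials` is a hypothetical helper, and the choice of population standard deviation is an assumption, since the scripts' exact convention is not stated here.

```python
import statistics


def summarize_trials(final_objectives):
    """Return (mean, standard deviation) of one objective value per training trial."""
    mean = statistics.fmean(final_objectives)
    # Population standard deviation (ddof = 0); this is an assumption,
    # the experiment scripts may use a different convention.
    std = statistics.pstdev(final_objectives)
    return mean, std


# Made-up objective values from four hypothetical training trials.
trials = [-8.0, -7.5, -8.5, -8.0]
mean, std = summarize_trials(trials)
print((mean, std))
```

A large standard deviation relative to the mean, as in some of the IWAE rows, indicates that individual trials ended at noticeably different objective values.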


## fig_2_noisy_cone

Each number in the output stream is the mean loss, reported every 1000 iterations.

In [2]:
!python ./experiments/fig_2_noisy_cone/genjax_cone.py

-825.4227
-7.9878664
-8.03403
-8.206105
-7.992387
-1265.4695
-8.004253
-8.012936
-7.9295254
-8.150597
-3.8820481
-7.961704
-7.8041596
-6.529461
-7.2643843
-26.630383
-6.878557
-6.7059927
-7.016681
-6.9601326
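One way to produce such a stream is to average the loss over consecutive blocks of 1000 iterations. The sketch below does this for a made-up loss trace; `windowed_means` is a hypothetical helper, and whether the script averages over non-overlapping windows in exactly this way is an assumption.

```python
def windowed_means(losses, window=1000):
    """Yield the mean of each consecutive, non-overlapping block of `window` losses."""
    # A trailing partial block (fewer than `window` values) is skipped.
    for start in range(0, len(losses) - window + 1, window):
        block = losses[start:start + window]
        yield sum(block) / window


# Made-up loss trace covering 3000 iterations (not data from the experiment).
losses = [-10.0] * 1000 + [-9.0] * 1000 + [-8.0] * 1000
means = list(windowed_means(losses))
for m in means:
    print(m)
```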


## fig_7_air_estimator_evaluation