
shayan-kousha/SurVAE


EECS6322-project

This project reproduces the main results of two papers, SurVAE Flows: Surjections to Bridge the Gap between VAEs and Flows and Normalizing Flows with Multi-Scale Autoregressive Priors. As a stretch goal, it also implements the idea of ProNF. The original code for both papers is available in their original repositories and was written in PyTorch; this repository re-implements both in JAX.

Oral Presentation

The video below is an oral presentation that gives an overview of the project's scope and results.

oral_presentation.mp4

Dependencies

python3

pip install -r requirements.txt

JAX (GPU build for CUDA 10.0)

pip install jax==0.2.8
pip install jaxlib==0.1.56+cuda100 -f https://storage.googleapis.com/jax-releases/jax_releases.html
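The pinned jax and jaxlib versions are old, and a mismatched jax/jaxlib pair can fail at import time. A small sketch for confirming what is actually installed, using only the standard library (`importlib.metadata`, Python 3.8+). `check_pins` is a hypothetical helper, not part of this repository.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install commands above.
PINNED = {"jax": "0.2.8", "jaxlib": "0.1.56+cuda100"}


def check_pins(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != expected:
            mismatches[package] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINNED)
    if not problems:
        print("jax and jaxlib match the pinned versions")
    for package, (expected, installed) in problems.items():
        print(f"{package}: expected {expected}, found {installed}")
```

Once the versions match, `python -c "import jax; print(jax.devices())"` should list a GPU device if the CUDA build is being picked up.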

Experiment Commands

Toy Datasets (AbsFlow Experiment)

Command for checkerboard:

python experiments/toy/train_abs_unif.py --hidden_units "[200,100]" --dataset checkerboard --clim 0.05

Command for corners:

python experiments/toy/train_abs_flow.py --hidden_units "[200,100]" --dataset corners --clim 0.1 --scale_fn softplus

Command for eight Gaussians:

python experiments/toy/train_abs_flow.py --hidden_units "[200,100]" --dataset eight_gaussians --clim 0.15 --scale_fn softplus

Command for fourcircle:

python experiments/toy/train_abs_flow.py --hidden_units "[200,100]" --dataset fourcircle --clim 0.2 --scale_fn softplus
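In SurVAE Flows, the absolute-value surjection maps x to z = |x| deterministically in the inference direction and samples the sign s in the generative direction, so the log-likelihood gains a term log p(s | z). A minimal NumPy sketch, assuming a fixed symmetric sign distribution (p = 1/2 per dimension, i.e. -log 2 per dimension); the toy scripts may instead learn p(s | z), and the function names are hypothetical.

```python
import numpy as np


def abs_surjection_forward(x):
    """Inference direction: z = |x| and the per-example likelihood contribution.

    x has shape (batch, dims). With p(s | z) = 1/2 for every dimension,
    each dimension contributes log(1/2) = -log 2.
    """
    z = np.abs(x)
    ldj = -np.log(2.0) * x.shape[1] * np.ones(x.shape[0])
    return z, ldj


def abs_surjection_inverse(z, rng):
    """Generative direction: attach an independent random sign to each entry."""
    signs = rng.choice([-1.0, 1.0], size=z.shape)
    return signs * z


if __name__ == "__main__":
    x = np.array([[-1.0, 2.0], [3.0, -4.0]])
    z, ldj = abs_surjection_forward(x)
    x_sample = abs_surjection_inverse(z, np.random.default_rng(0))
    print(z, ldj, x_sample)
```

In a JAX version the array operations carry over to jax.numpy, while the sign sampling would use jax.random with an explicit key instead of a NumPy generator.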

Max Pooling Experiment

Command for no pooling (--pooling none):

python experiments/max_pooling/max_pooling_experiment.py --epochs 500 --batch_size 32 --optimizer adamax --lr 1e-4 --gamma 0.995 --eval_every 1 --check_every 10 --warmup 5000 --num_steps 12 --num_scales 2 --dequant flow --pooling none --dataset cifar10 --augmentation eta --name nonpool --model_dir ./experiments/max_pooling/checkpoints/

Command for max pooling (--pooling max):

python experiments/max_pooling/max_pooling_experiment.py --epochs 500 --batch_size 32 --optimizer adamax --lr 1e-4 --gamma 0.995 --eval_every 1 --check_every 10 --warmup 5000 --num_steps 12 --num_scales 2 --dequant flow --pooling max --dataset cifar10 --augmentation eta --name maxpool --model_dir ./experiments/max_pooling/checkpoints/
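In SurVAE Flows, max pooling acts as a surjection: the inference direction keeps only the maximum of each block, and the generative direction must re-sample which position held the maximum and how far below it the other entries lie. A minimal NumPy sketch of the resulting log-likelihood term, assuming a single-channel image, non-overlapping 2x2 blocks, a uniform distribution over the argmax position, and an Exponential distribution with a fixed `rate` for the gaps. These are illustrative choices (the experiment script may use a different or learned distribution), and the function name is hypothetical.

```python
import numpy as np


def maxpool_surjection_forward(x, rate=1.0):
    """2x2 max-pooling surjection on an (H, W) image with even H and W.

    Returns the pooled image z and the log-likelihood contribution:
      * -log 4 per block for the uniformly chosen argmax position;
      * log Exponential(z - x_i; rate) for each of the 3 non-max entries.
    """
    h, w = x.shape
    # Gather each non-overlapping 2x2 block into the last axis: (H/2, W/2, 4).
    blocks = x.reshape(h // 2, 2, w // 2, 2).transpose(0, 2, 1, 3)
    blocks = blocks.reshape(h // 2, w // 2, 4)
    z = blocks.max(axis=-1)
    gaps = z[..., None] - blocks  # all >= 0, exactly 0 at the max entry
    num_blocks = z.size
    # Exponential log-density is log(rate) - rate * gap for the 3 non-max
    # entries; the max entry has gap 0, so summing gaps over all 4 is the same.
    log_gaps = 3 * num_blocks * np.log(rate) - rate * gaps.sum()
    log_index = -num_blocks * np.log(4.0)
    return z, log_gaps + log_index


if __name__ == "__main__":
    z, ll = maxpool_surjection_forward(np.arange(16.0).reshape(4, 4))
    print(z, ll)
```

Compared with --pooling none, the pooled model works on a quarter as many values per channel after each such layer, at the price of this extra likelihood term.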

MSAR-SCF Experiment

python experiments/msar_scf/train_msar_scf.py --ckptdir "experiments/msar_scf/ckpt_sigmoid" --activation "sigmoid" --resume True --num_epochs 3000
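MSAR-SCF builds on split coupling flows, whose core layer is an affine coupling. Assuming (the command alone does not confirm this) that `--activation` selects how the coupling network's raw output is turned into a positive scale, the sketch below contrasts a Glow-style sigmoid scale with an exponential one. It is a minimal NumPy illustration, not the repository's layer; `coupling_forward`, `coupling_inverse`, and the toy `net` are hypothetical names.

```python
import numpy as np


def _scale(raw, activation):
    """Map raw network outputs to a strictly positive scale."""
    if activation == "sigmoid":
        # Glow-style: the +2 offset makes the scale start near 0.88, not 0.5.
        return 1.0 / (1.0 + np.exp(-(raw + 2.0)))
    if activation == "exp":
        return np.exp(raw)
    raise ValueError(f"unknown activation: {activation}")


def coupling_forward(x, net, activation="sigmoid"):
    """Affine coupling: x1 passes through, x2 is scaled and shifted given x1.

    x has shape (batch, 2 * d); net maps (batch, d) -> (batch, 2 * d).
    Returns y and log|det dy/dx| per example.
    """
    x1, x2 = np.split(x, 2, axis=-1)
    raw, shift = np.split(net(x1), 2, axis=-1)
    scale = _scale(raw, activation)
    y = np.concatenate([x1, x2 * scale + shift], axis=-1)
    return y, np.log(scale).sum(axis=-1)


def coupling_inverse(y, net, activation="sigmoid"):
    """Invert coupling_forward; y1 == x1, so scale and shift can be recomputed."""
    y1, y2 = np.split(y, 2, axis=-1)
    raw, shift = np.split(net(y1), 2, axis=-1)
    scale = _scale(raw, activation)
    return np.concatenate([y1, (y2 - shift) / scale], axis=-1)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    weights = rng.normal(size=(3, 6))
    net = lambda h: np.tanh(h @ weights)  # hypothetical stand-in for the coupling CNN
    x = rng.normal(size=(4, 6))
    y, logdet = coupling_forward(x, net)
    print(np.allclose(coupling_inverse(y, net), x))  # True
```

With the sigmoid, every scale lies in (0, 1), so each coupling can only contract x2 and its log-determinant is never positive; bounding the scale this way is a common motivation for it over an unbounded exp.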

Stretch Goal: First Approach

## 16x16 => 32x32
python experiments/pro_nf/train_pronf.py --ckptdir "experiments/pro_nf/ckpt_32" --resume True --warmup 50000 --ms
## 8x8 => 16x16
python experiments/pro_nf/train_pronf.py --ckptdir "experiments/pro_nf/ckpt_16" --resume True --warmup 50000 --input_res 16 --num_layers 2 --ms --learning_rate 1e-4 
## 4x4 => 8x8
python experiments/pro_nf/train_pronf.py --ckptdir "experiments/pro_nf/ckpt_8" --resume True --warmup 50000 --input_res 8 --num_layers 2 --ms --learning_rate 1e-4
## 4x4 unconditional
python experiments/pro_nf/train_pronf.py --ckptdir "experiments/pro_nf/ckpt_4" --resume True --warmup 50000 --input_res 4 
## chain-up
python experiments/pro_nf/merge.py --ckptdir "experiments/pro_nf" --resume True
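The first approach trains one model per resolution step (4x4 unconditional, then 4x4 to 8x8, 8x8 to 16x16, and 16x16 to 32x32) and chains them with merge.py. Each conditional stage maps a lower-resolution image to one twice as large, so training it presumably requires every image at both resolutions. A minimal NumPy sketch of building that pyramid, assuming 2x2 average pooling as the downsampling operator (the training script may downsample differently); `avg_pool_2x` and `resolution_pyramid` are hypothetical names.

```python
import numpy as np


def avg_pool_2x(images):
    """Halve the height and width of (batch, H, W, C) images by 2x2 averaging."""
    b, h, w, c = images.shape
    return images.reshape(b, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))


def resolution_pyramid(images, smallest=4):
    """Return {resolution: images} from the full size down to `smallest`."""
    pyramid = {images.shape[1]: images}
    while images.shape[1] > smallest:
        images = avg_pool_2x(images)
        pyramid[images.shape[1]] = images
    return pyramid


if __name__ == "__main__":
    batch = np.random.default_rng(0).random((2, 32, 32, 3))  # CIFAR-10-sized
    pyramid = resolution_pyramid(batch)
    # (conditioning resolution, target resolution) for each conditional stage
    for low, high in [(4, 8), (8, 16), (16, 32)]:
        print(low, high, pyramid[low].shape, pyramid[high].shape)
```

Average pooling keeps each image's mean intensity unchanged across levels, so every stage sees a consistent coarse version of the same image.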

Stretch Goal: Second Approach

python experiments/pro_nf_2/pro_nf.py --batch_size 32 --augmentation eta --dataset cifar10 --image_size 32
python experiments/pro_nf_2/pro_nf.py --batch_size 32 --augmentation eta --dataset cifar10 --image_size 16 --smallest
python experiments/pro_nf_2/pro_nf.py --batch_size 32 --augmentation eta --dataset cifar10 --image_size 32 16 --resume
