
Multivariate Diffusion Models

This package implements multivariate diffusion models (MDMs), a broader class of forward (inference) diffusion processes for diffusion-based generative models.

This repository implements the following paper:

Where to Diffuse, How to Diffuse, and How to Get Back: Automated Learning for Multivariate Diffusions [openreview].


Quickstart

To get started, run the following command in the repository's root directory:

pip install -e .

This installs the MDM source code as an editable Python package.

We use Weights and Biases (WandB) for monitoring experiments. To start, create a WandB account and log in locally with your API key.
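Logging in is a single command with the wandb CLI, which is installed alongside the wandb Python package:

wandb login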

Note: This codebase relies on PyTorch Lightning.


Training

To train a model, use mdm/bin/train.py. The command options are:

python mdm/bin/train.py 
  --config_path: Path to the yaml file containing the experiment configuration
  --debug_mode: Use only a fraction of the data (for debugging)
  --offline: Run without wandb logging (for debugging)
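For example, a quick debug run might look like the following; the config path here is hypothetical (point it at your own yaml file), and --debug_mode and --offline are assumed to be boolean switches:

python mdm/bin/train.py --config_path configs/cifar.yaml --debug_mode --offline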

Grid Search

To run a grid search over hyperparameters, use mdm/bin/grid.py. Select the hyperparameter values to sweep in that file, then start the search as follows:

python mdm/bin/grid.py 
  --config_path: Path to the yaml file containing the experiment configuration
  --index: Which point in the hyperparameter grid to run
  --debug_mode: Use only a fraction of the data (for debugging)
  --offline: Run without wandb logging (for debugging)
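Each value of --index selects one point in the grid, so a full sweep launches one run per index. A minimal sketch, assuming a 12-point grid and a hypothetical config path:

for i in $(seq 0 11); do
  python mdm/bin/grid.py --config_path configs/grid.yaml --index $i
done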

New Diffusion Process

To define a new diffusion process, we need to specify two matrices, Q and D, which govern the diffusion dynamics. For instance, to train with critically-damped Langevin diffusion (CLD), we only need to specify its Q and D matrices; there is no need to derive its transition mean and covariance by hand.

import torch
from .mdm import MDM


class CLD(MDM):
    """Critically-damped Langevin diffusion expressed as a multivariate diffusion model."""

    def __init__(self, config):
        self.cld_gamma = config.cld_gamma

        # Q is skew-symmetric: it couples the data variable to the
        # auxiliary velocity variable without injecting noise.
        q = torch.tensor([[0.0, -1.0], [1.0, 0.0]])
        # D is positive semi-definite: noise and friction (cld_gamma)
        # act only on the velocity coordinate.
        d = torch.tensor([[0.0, 0.0], [0.0, self.cld_gamma]])

        super().__init__(config, q_fixed=q, d_fixed=d)
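To train with the new process, select it in the config. A minimal sketch, assuming the sde_type string cld is registered to the class above and that cld_gamma is read from the config exactly as in __init__:

# hypothetical config excerpt
sde_type: cld
n_vars: 2
cld_gamma: 4.0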


Configs

The config file contains the settings for your run. Currently, you can choose between Learned-2, Learned-3, MALDA, ALDA, CLD, and VP-SDE as the inference diffusion process.

seed: 0
wandb_project: PROJECT_NAME
wandb_entity: ENTITY_NAME

save_dir: DIRECTORY_TO_SAVE_CHECKPOINTS
data_dir: DATA_DIRECTORY

# checkpoint
resume_checkpoint: False
checkpoint_path: PATH_TO_CHECKPOINT

# dataset configs
dataset: cifar
is_image: True
in_channels: 3
out_channels: 3
height: 32
width: 32

# monte carlo samples
elbo_mc_samples_eval: 1
elbo_offset_train: 1
elbo_mc_samples_train: 1
elbo_offset_eval: True
hutch_mc: 10

# sde 
sde_type: learned_2
n_vars: 2
stationary_aux_var: 1.0
stationary_x_var: 1.0
init_aux_var: 0.01

# beta's for SDE 
beta_0: 0.1
beta_1: 10
beta_fn_type: inhom

# score_model type
score_parameterization: noise_pred
score_model_type: HoUNet
dropout: 0.0

# data transform 
transform: logit

# batch size
batch_size: 128
test_batch_size: 128
accumulate_grad_batches: 1

# optimization
n_epochs: 300
lr: 0.0002
lr_scheduling: False
lr_sched_max_iters: 50000
grad_clip_val: 2.
optim_type: adam
warmup_iters: 5000
weight_decay: 0.0
use_ema: False
ema_decay: 0.9999

# objectives
val_loss_type: dsm_elbo
train_loss_type: dsm_elbo
imp_weight_train: True
switch_epoch: -1

# fid
val_fid_epoch: 400
fid_n_samples: 50000


# diffusion
hybrid_transition_kernel: True

T_max: 1.
T_min_sampling: 0.001
T_min_eval: 0.001
T_min_train: 0.001

# sampling params
log_image_step: 50
log_image_size: 8
n_FEs: 500
sampling_method: EM
sampling_t_arr_fn: quadratic

Citation

If you find this code useful, please cite our paper:

Singhal, R., Goldstein, M., and Ranganath, R. Where to Diffuse, How to Diffuse, and How to Get Back: Automated Learning for Multivariate Diffusions. In The Eleventh International Conference on Learning Representations (ICLR), 2023.

Acknowledgements

We thank the authors of the following repositories for their code, which we have used in this repository:

  1. Yang Song: NCSNpp Code
  2. Jonathan Ho: UNet Code
  3. CW Huang: Logit Transform
  4. Xu Ma: ImageNet Dataloader
