Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (SiT)
Official PyTorch Implementation
Paper | Project Page |
This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring interpolant models with scalable transformers (SiTs).
Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers
Nanye Ma, Mark Goldstein, Michael Albergo, Nicholas Boffi, Eric Vanden-Eijnden, Saining Xie
New York University
We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler. By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06.
This repository contains:
- 🪐 A simple PyTorch implementation of SiT
- ⚡️ Pre-trained class-conditional SiT models trained on ImageNet 256x256
- 🛸 A SiT training script using PyTorch DDP
First, download and set up the repo:
git clone https://github.com/willisma/SiT.git
cd SiT
We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.
conda env create -f environment.yml
conda activate SiT
Pre-trained SiT checkpoints. You can sample from our pre-trained SiT models with sample.py. Weights for the pre-trained SiT model are downloaded automatically depending on the model you use. The script has arguments to adjust the sampler configuration (ODE or SDE), the number of sampling steps, the classifier-free guidance scale, etc. For example, to sample from our 256x256 SiT-XL model with the default ODE settings, you can use:
python sample.py ODE --image-size 256 --seed 1
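The classifier-free guidance scale mentioned above is commonly applied by extrapolating from the unconditional model output toward the conditional one. The following is a minimal sketch of that rule on plain Python lists; the function name `apply_cfg` is hypothetical and is not taken from this repo, and sample.py may apply guidance differently.

```python
def apply_cfg(cond, uncond, scale):
    """Blend conditional and unconditional model outputs.

    scale = 1.0 returns the conditional output unchanged;
    scale > 1.0 pushes further away from the unconditional output.
    """
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


cond = [0.5, -1.0, 2.0]    # toy conditional prediction
uncond = [0.0, -0.5, 1.0]  # toy unconditional prediction
guided = apply_cfg(cond, uncond, scale=4.0)
```

With a scale of 1.0 the guided output equals the conditional prediction, which is why that value usually corresponds to "no guidance".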
For convenience, our pre-trained SiT models can be downloaded directly here as well:
SiT Model | Image Resolution | FID-50K | Inception Score | GFLOPs |
---|---|---|---|---|
XL/2 | 256x256 | 2.06 | 270.27 | 119 |
Custom SiT checkpoints. If you've trained a new SiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 SiT-L/4 model with the ODE sampler, run:
python sample.py ODE --model SiT-L/4 --image-size 256 --ckpt /path/to/model.pt
Sampler | Argument | Type | Description |
---|---|---|---|
ODE | --atol | float | Absolute error tolerance |
ODE | --rtol | float | Relative error tolerance |
ODE | --sampling-method | str | Sampling method (refer to torchdiffeq) |
SDE | --diffusion-form | str | Form of the SDE's diffusion coefficient (refer to Tab. 2 in the paper) |
SDE | --diffusion-norm | float | Magnitude of the SDE's diffusion coefficient |
SDE | --last-step | str | Form of the SDE's last step: None (single SDE integration step), "Mean" (SDE integration step without diffusion coefficient), "Tweedie" (Tweedie's denoising step), "Euler" (single ODE integration step) |
SDE | --sampling-method | str | Sampling method: "Euler" (first-order integration), "Heun" (second-order integration) |
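To illustrate the difference between the first-order "Euler" and second-order "Heun" options, here is a small self-contained sketch on a toy scalar ODE. It is not the repo's sampler: the function names and the toy drift are assumptions chosen for illustration only.

```python
import math


def euler_step(f, x, t, dt):
    """First-order update: follow the slope at the current point."""
    return x + dt * f(x, t)


def heun_step(f, x, t, dt):
    """Second-order update: average the slope at the start and at the Euler prediction."""
    k1 = f(x, t)
    k2 = f(x + dt * k1, t + dt)
    return x + 0.5 * dt * (k1 + k2)


def integrate(step, f, x0, t0, t1, num_steps):
    """Apply `step` num_steps times on a uniform time grid from t0 to t1."""
    dt = (t1 - t0) / num_steps
    x, t = x0, t0
    for _ in range(num_steps):
        x = step(f, x, t, dt)
        t += dt
    return x


def decay(x, t):
    # Toy drift dx/dt = -x with exact solution x(t) = x0 * exp(-t).
    return -x


exact = math.exp(-1.0)
euler_err = abs(integrate(euler_step, decay, 1.0, 0.0, 1.0, 10) - exact)
heun_err = abs(integrate(heun_step, decay, 1.0, 0.0, 1.0, 10) - exact)
```

For the same number of steps, Heun costs two drift evaluations per step instead of one, but on this toy problem its error is noticeably smaller than Euler's.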
There are some more options; refer to train_utils.py for details.
We provide a training script for SiT in train.py. To launch SiT-XL/2 (256x256) training with N GPUs on one node:
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train
Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:
export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"
Then add the --wandb flag to the training command:
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb
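Before launching a long run, it can help to confirm that the three variables are actually set. The check below is a hypothetical helper, not part of the repo; it only relies on the variable names given above.

```python
import os

REQUIRED_VARS = ("WANDB_KEY", "ENTITY", "PROJECT")


def missing_wandb_vars(env=None):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_wandb_vars()
if missing:
    print("Set these before passing --wandb:", ", ".join(missing))
```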
Interpolant settings. We also support different choices of interpolant and model predictions. For example, to launch SiT-XL/2 (256x256) with the Linear interpolant and noise prediction:
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --path-type Linear --prediction noise
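As a rough illustration of what a linear interpolant and a noise-prediction target look like, the sketch below uses the convention x_t = (1 - t) * x_data + t * eps, with data at t = 0 and noise at t = 1. This convention, the function names, and the "velocity" label are assumptions made for the example; the repo's code may use a different time direction or parameterization.

```python
import random


def linear_interpolant(x_data, eps, t):
    """Return x_t = (1 - t) * x_data + t * eps for t in [0, 1]."""
    return [(1.0 - t) * x + t * e for x, e in zip(x_data, eps)]


def training_target(x_data, eps, prediction):
    """Regression target for each prediction type under the linear path above."""
    if prediction == "noise":
        return list(eps)
    if prediction == "velocity":
        # d/dt x_t = eps - x_data for the linear path above.
        return [e - x for x, e in zip(x_data, eps)]
    raise ValueError(f"unknown prediction type: {prediction}")


rng = random.Random(0)
x_data = [0.2, -0.7, 1.5]                    # toy "image" with three values
eps = [rng.gauss(0.0, 1.0) for _ in x_data]  # Gaussian noise sample
x_half = linear_interpolant(x_data, eps, t=0.5)
noise_target = training_target(x_data, eps, "noise")
```

Under this convention the interpolant returns the data at t = 0 and pure noise at t = 1, and the model would be trained to regress the chosen target from x_t and t.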
Resume training. To resume training from a custom checkpoint:
torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt
Caution. Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, from the checkpoint.
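For readers unfamiliar with EMA weights, the sketch below shows the usual exponential-moving-average update on plain lists. It is only an illustration: the function name is hypothetical, and the default decay of 0.9999 is an assumption that is not stated in this README.

```python
def ema_update(ema_params, model_params, decay=0.9999):
    """Move each EMA value a small step toward the current model value."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, model_params)]


ema = [0.0, 0.0]
for _ in range(3):
    model = [1.0, -2.0]  # stand-in for parameters after an optimizer step
    ema = ema_update(ema, model, decay=0.5)
```

Because the EMA state depends on the whole training history, it has to be stored in the checkpoint alongside the model and optimizer states for a resumed run to continue consistently.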
We include a sample_ddp.py script which samples a large number of images from a SiT model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained SiT-XL/2 model over N GPUs under default ODE sampler settings, run:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000
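When a fixed number of samples is split across several GPUs, the total is typically rounded up to a whole number of global batches and the surplus is discarded. The arithmetic below is a sketch of that idea under stated assumptions: the function, its parameter names, and the per-GPU batch size of 64 are hypothetical and are not read from sample_ddp.py.

```python
import math


def plan_sampling(num_samples, num_gpus, per_gpu_batch):
    """Split a sampling job evenly across GPUs, rounding up to whole batches."""
    global_batch = num_gpus * per_gpu_batch
    total = math.ceil(num_samples / global_batch) * global_batch
    per_gpu = total // num_gpus
    return {
        "total": total,                          # samples actually generated
        "per_gpu": per_gpu,                      # samples generated on each GPU
        "iterations": per_gpu // per_gpu_batch,  # sampling loops per GPU
        "discarded": total - num_samples,        # surplus beyond the request
    }


plan = plan_sampling(num_samples=50_000, num_gpus=8, per_gpu_batch=64)
```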
Likelihood. Likelihood evaluation is supported. To calculate likelihood, you can add the --likelihood flag to the ODE sampler:
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --likelihood
Note that likelihood can only be calculated with the ODE sampler; see sample_ddp.py for more details and settings.
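One way to see why the ODE sampler is needed is the instantaneous change-of-variables formula for continuous flows. The statement below is a generic sketch, assuming a deterministic flow dx/dt = v(x, t) with data at t = 0 and the Gaussian prior at t = 1; the symbols are illustrative and the repo's time convention may differ.

```latex
\frac{\mathrm{d}}{\mathrm{d}t} \log p_t\bigl(x(t)\bigr) = -\nabla \cdot v\bigl(x(t), t\bigr)
\quad\Longrightarrow\quad
\log p_0\bigl(x(0)\bigr) = \log p_1\bigl(x(1)\bigr) + \int_0^1 \nabla \cdot v\bigl(x(t), t\bigr)\,\mathrm{d}t
```

The identity relies on each sample following a single deterministic trajectory, which a stochastic (SDE) sampler does not provide.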
Training (and sampling) could likely be sped up significantly by:
- using Flash Attention in the SiT model
- using torch.compile in PyTorch 2.0
Basic features that would be nice to add:
- Monitor FID and other metrics
- AMP/bfloat16 support
Precision in likelihood calculation could likely be improved by:
- Uniform / Gaussian Dequantization
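As a rough idea of what uniform dequantization involves, the sketch below adds uniform noise to integer pixel values so that a continuous density model is evaluated on continuous inputs. It is an illustration only and is not implemented in this repo; the function name and the 8-bit pixel range are assumptions.

```python
import random


def uniform_dequantize(pixels, rng):
    """Map integer pixels in [0, 255] to continuous values in [0, 1)."""
    return [(p + rng.random()) / 256.0 for p in pixels]


rng = random.Random(0)
pixels = [0, 17, 128, 255]
continuous = uniform_dequantize(pixels, rng)
```

Each dequantized value stays inside the bin of its original pixel, so flooring `value * 256` recovers the integer input.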
Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling on different platforms (TPU vs. GPU). We observed that sampling on TPU performs marginally worse than GPU (2.15 FID versus 2.06 in the paper).
This project is under the MIT license. See LICENSE for details.