Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers (SiT)
Official PyTorch Implementation

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring interpolant models with scalable transformers (SiTs).

Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers
Nanye Ma, Mark Goldstein, Michael Albergo, Nicholas Boffi, Eric Vanden-Eijnden, Saining Xie
New York University

We present Scalable Interpolant Transformers (SiT), a family of generative models built on the backbone of Diffusion Transformers (DiT). The interpolant framework, which allows for connecting two distributions in a more flexible way than standard diffusion models, makes possible a modular study of various design choices impacting generative models built on dynamical transport: using discrete vs. continuous time learning, deciding the model to learn, choosing the interpolant connecting the distributions, and deploying a deterministic or stochastic sampler. By carefully introducing the above ingredients, SiT surpasses DiT uniformly across model sizes on the conditional ImageNet 256x256 benchmark using the exact same backbone, number of parameters, and GFLOPs. By exploring various diffusion coefficients, which can be tuned separately from learning, SiT achieves an FID-50K score of 2.06.

This repository contains:

  • 🪐 A simple PyTorch implementation of SiT
  • ⚡️ Pre-trained class-conditional SiT models trained on ImageNet 256x256
  • 🛸 A SiT training script using PyTorch DDP

Setup

First, download and set up the repo:

git clone https://github.com/willisma/SiT.git
cd SiT

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate SiT

Sampling

Pre-trained SiT checkpoints. You can sample from our pre-trained SiT models with sample.py. Weights for the pre-trained model you select are downloaded automatically. The script has arguments for choosing the sampler configuration (ODE or SDE), setting the number of sampling steps, changing the classifier-free guidance scale, etc. For example, to sample from our 256x256 SiT-XL model with the default ODE settings, you can use:

python sample.py ODE --image-size 256 --seed 1
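
The classifier-free guidance scale mentioned above controls how a class-conditional and an unconditional model output are blended. The following is a minimal sketch of the usual combination rule, using plain Python lists in place of tensors. The function name apply_cfg is hypothetical and is not taken from sample.py; the repo's actual implementation may differ in detail.

```python
def apply_cfg(cond, uncond, scale):
    """Blend conditional and unconditional predictions elementwise.

    scale = 1.0 returns the conditional prediction unchanged;
    larger values move the output further away from the unconditional one.
    """
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


# Toy model outputs (illustrative numbers only).
cond = [0.8, -0.2, 0.5]
uncond = [0.5, 0.0, 0.5]

print(apply_cfg(cond, uncond, 1.0))  # plain conditional prediction
print(apply_cfg(cond, uncond, 4.0))  # guided prediction
```

With a scale of 1.0 the guidance term has no effect, which is why guidance scales above 1.0 are the interesting regime.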

For convenience, our pre-trained SiT models can be downloaded directly here as well:

SiT Model  Image Resolution  FID-50K  Inception Score  GFLOPs
XL/2       256x256           2.06     270.27           119

Custom SiT checkpoints. If you've trained a new SiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 SiT-L/4 model with the ODE sampler, run:

python sample.py ODE --model SiT-L/4 --image-size 256 --ckpt /path/to/model.pt

Advanced sampler settings

Sampler  Argument           Type   Description
ODE      --atol             float  Absolute error tolerance
         --rtol             float  Relative error tolerance
         --sampling-method  str    Sampling method (refer to torchdiffeq)
SDE      --diffusion-form   str    Form of the SDE's diffusion coefficient (refer to Tab. 2 in the paper)
         --diffusion-norm   float  Magnitude of the SDE's diffusion coefficient
         --last-step        str    Form of the SDE's last step:
                                     None - Single SDE integration step
                                     "Mean" - SDE integration step without diffusion coefficient
                                     "Tweedie" - Tweedie's denoising step
                                     "Euler" - Single ODE integration step
         --sampling-method  str    Sampling method:
                                     "Euler" - First order integration
                                     "Heun" - Second order integration

There are some more options; refer to train_utils.py for details.
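
To illustrate the difference between the "Euler" (first order) and "Heun" (second order) options listed above, here is a minimal sketch of both integrators on a toy scalar ODE. It is not the repo's sampler: the function names and the test equation dx/dt = -x are assumptions chosen for illustration, and a real sampler would integrate the learned velocity field over image latents.

```python
import math


def euler_step(f, x, t, dt):
    """First-order step: follow the slope at the start of the interval."""
    return x + dt * f(x, t)


def heun_step(f, x, t, dt):
    """Second-order step: average the slope at the start and at the Euler-predicted end."""
    k1 = f(x, t)
    k2 = f(x + dt * k1, t + dt)
    return x + 0.5 * dt * (k1 + k2)


def integrate(step, f, x0, t0, t1, num_steps):
    """Apply `step` num_steps times on a uniform grid from t0 to t1."""
    dt = (t1 - t0) / num_steps
    x, t = x0, t0
    for _ in range(num_steps):
        x = step(f, x, t, dt)
        t += dt
    return x


def decay(x, t):
    # Toy ODE dx/dt = -x with exact solution x(t) = x0 * exp(-t).
    return -x


exact = math.exp(-1.0)
euler_err = abs(integrate(euler_step, decay, 1.0, 0.0, 1.0, 50) - exact)
heun_err = abs(integrate(heun_step, decay, 1.0, 0.0, 1.0, 50) - exact)
print(f"Euler error: {euler_err:.2e}, Heun error: {heun_err:.2e}")
```

Heun costs two function evaluations per step instead of one, so for a fixed budget of model evaluations it takes half as many steps; whether that trade is worthwhile depends on the model and step count.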

Training SiT

We provide a training script for SiT in train.py. To launch SiT-XL/2 (256x256) training with N GPUs on one node:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train

Logging. To enable wandb, first set WANDB_KEY, ENTITY, and PROJECT as environment variables:

export WANDB_KEY="key"
export ENTITY="entity name"
export PROJECT="project name"

Then add the --wandb flag to the training command:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --wandb
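
Before launching a long run with --wandb, it can help to confirm that the three environment variables named above are actually set. A small sketch of such a check follows; the helper missing_wandb_vars is hypothetical and not part of the repo.

```python
import os

# Variable names as listed in the logging instructions above.
REQUIRED_VARS = ("WANDB_KEY", "ENTITY", "PROJECT")


def missing_wandb_vars(env=None):
    """Return the required variable names that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


missing = missing_wandb_vars()
if missing:
    print("Set these before passing --wandb:", ", ".join(missing))
else:
    print("wandb environment variables are set")
```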

Interpolant settings. We also support different choices of interpolant and model prediction. For example, to launch SiT-XL/2 (256x256) with the Linear interpolant and noise prediction:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-XL/2 --data-path /path/to/imagenet/train --path-type Linear --prediction noise
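
As a rough illustration of what a linear interpolant and a prediction target mean, the sketch below builds x_t = alpha_t * x + sigma_t * eps and the matching regression targets. It assumes the convention alpha_t = 1 - t and sigma_t = t (data at t = 0, noise at t = 1); the time direction and parameterization used in the repo's code may differ. The function names are hypothetical, and lists stand in for tensors.

```python
import random


def linear_interpolant(x, eps, t):
    """Return x_t = alpha_t * x + sigma_t * eps with alpha_t = 1 - t, sigma_t = t (assumed convention)."""
    alpha_t, sigma_t = 1.0 - t, t
    return [alpha_t * xi + sigma_t * ei for xi, ei in zip(x, eps)]


def training_target(x, eps, prediction):
    """Regression target for the chosen prediction type under the linear path above."""
    if prediction == "noise":
        # The model is trained to recover the noise sample itself.
        return list(eps)
    if prediction == "velocity":
        # d/dt x_t = (d alpha_t/dt) * x + (d sigma_t/dt) * eps = eps - x
        return [ei - xi for xi, ei in zip(x, eps)]
    raise ValueError(f"unsupported prediction type in this sketch: {prediction}")


rng = random.Random(0)
x = [0.3, -1.2, 0.7]                       # stand-in for a data sample
eps = [rng.gauss(0.0, 1.0) for _ in x]     # Gaussian noise
t = rng.random()                           # time drawn uniformly from [0, 1)

x_t = linear_interpolant(x, eps, t)
print("x_t:", x_t)
print("noise target:", training_target(x, eps, "noise"))
print("velocity target:", training_target(x, eps, "velocity"))
```

Under this convention the velocity target does not depend on t, which is a special property of the linear path.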

Resume training. To resume training from a custom checkpoint:

torchrun --nnodes=1 --nproc_per_node=N train.py --model SiT-L/2 --data-path /path/to/imagenet/train --ckpt /path/to/model.pt

Caution. Resuming training automatically restores the model, EMA, and optimizer states, as well as the training configs, from the checkpoint.

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a SiT model in parallel. This script generates a folder of samples as well as a .npz file which can be directly used with ADM's TensorFlow evaluation suite to compute FID, Inception Score and other metrics. For example, to sample 50K images from our pre-trained SiT-XL/2 model over N GPUs under default ODE sampler settings, run:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --num-fid-samples 50000
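
When sampling in parallel, the requested number of images usually has to be divided evenly across GPUs and batches. The sketch below shows one common way to do this: round the total up to a multiple of the global batch size and discard the surplus afterwards. It is an assumption about the general approach, not a copy of sample_ddp.py, and the names plan_sampling and per_gpu_batch_size are hypothetical.

```python
import math


def plan_sampling(num_fid_samples, world_size, per_gpu_batch_size):
    """Split a sample budget evenly across GPUs, rounding up to whole batches."""
    global_batch = world_size * per_gpu_batch_size
    # Round up so that every GPU runs the same number of full batches.
    total = math.ceil(num_fid_samples / global_batch) * global_batch
    per_gpu = total // world_size
    return {
        "total_generated": total,
        "per_gpu": per_gpu,
        "iterations_per_gpu": per_gpu // per_gpu_batch_size,
        "discarded": total - num_fid_samples,
    }


plan = plan_sampling(num_fid_samples=50000, world_size=8, per_gpu_batch_size=64)
for key, value in plan.items():
    print(f"{key}: {value}")
```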

Likelihood. Likelihood evaluation is supported. To calculate likelihood, add the --likelihood flag to the ODE sampler:

torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py ODE --model SiT-XL/2 --likelihood

Note that likelihood can be calculated only with the ODE sampler; see sample_ddp.py for more details and settings.

Enhancements

Training (and sampling) could likely be sped up significantly by:

  • using Flash Attention in the SiT model
  • using torch.compile in PyTorch 2.0

Basic features that would be nice to add:

  • Monitor FID and other metrics
  • AMP/bfloat16 support

Precision in likelihood calculation could likely be improved by:

  • Uniform / Gaussian Dequantization
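
For context on the dequantization item above: image pixels are discrete, while a likelihood computed by a continuous model is a density, so a common remedy is to add noise that spreads each integer value over a unit-width bin. The sketch below shows uniform dequantization on a list of 8-bit pixel values. It is not implemented in this repo, and the function name is hypothetical.

```python
import random


def uniform_dequantize(pixels, rng):
    """Map integer pixels in {0, ..., 255} to continuous values by adding U[0, 1) noise.

    Each integer p is spread uniformly over the bin [p / 256, (p + 1) / 256).
    """
    return [(p + rng.random()) / 256.0 for p in pixels]


rng = random.Random(0)
pixels = [0, 128, 255]
continuous = uniform_dequantize(pixels, rng)
print(continuous)

# Flooring after rescaling recovers the original integers.
print([int(v * 256) for v in continuous])
```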

Differences from JAX

Our models were originally trained in JAX on TPUs. The weights in this repo are ported directly from the JAX models. There may be minor differences in results stemming from sampling on different platforms (TPU vs. GPU). We observed that sampling on TPU performs marginally worse than on GPU (2.15 FID versus 2.06 in the paper).

License

This project is under the MIT license. See LICENSE for details.
