tyuxie/VADD


Variational Autoencoding Discrete Diffusion with Enhanced Dimensional Correlations Modeling

ICLR 2026 Poster

Tianyu Xie, Shuchen Xue, Zijin Feng, Tianyang Hu, Jiacheng Sun, Zhenguo Li, Cheng Zhang

This repository implements VADD (Variational Autoencoding Discrete Diffusion), which extends MDLM (Masked Diffusion Language Models) with a variational autoencoder (VAE) framework. By introducing a latent variable $z$ into the masked diffusion process, the model captures global document-level semantics while retaining the token-level generative capability of discrete diffusion.
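The training objective is a variational bound on the data log-likelihood; with `training.num_particles > 1` (see the hyperparameters below) a tighter multi-sample, IWAE-style bound can be used. The sketch below shows only the bound computation in isolation; `iwae_bound` and its inputs are illustrative names, not the repository's API.

```python
import numpy as np

def iwae_bound(log_joint, log_q):
    """Multi-sample (IWAE-style) lower bound on log p(x).

    log_joint: shape (K,), log p(x, z_k) for K sampled particles z_k
    log_q:     shape (K,), log q(z_k | x) under the encoder
    """
    log_w = np.asarray(log_joint) - np.asarray(log_q)  # importance log-weights
    K = log_w.shape[0]
    # log( (1/K) * sum_k exp(log_w_k) ), computed stably via the max trick
    m = log_w.max()
    return m + np.log(np.exp(log_w - m).sum()) - np.log(K)

# With a single particle (K = 1) this reduces to the standard ELBO term.
```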


Code Structure

.
├── main.py                 # Entry point: training, PPL evaluation, sample generation
├── diffusion.py            # Diffusion process, VAE loss, sampling procedures
├── dataloader.py           # Dataset loading, tokenization, data utilities
├── noise_schedule.py       # Noise schedules (loglinear, linear, cosine, etc.)
├── utils.py                # LR scheduler, logging, fsspec helpers
├── models/
│   ├── __init__.py
│   ├── dit.py              # Transformer backbones: GenDIT (decoder), InferDIT (encoder)
│   ├── autoregressive.py   # AR baseline (for reference)
│   └── ema.py              # Exponential Moving Average
├── configs/
│   ├── config.yaml         # Main Hydra config
│   ├── data/               # Dataset configs (OpenWebText, WikiText, LM1B, etc.)
│   ├── model/              # Model size configs (small, medium, tiny)
│   ├── decoder/            # Decoder (encoder network) configs
│   ├── noise/              # Noise schedule configs
│   ├── lr_scheduler/       # Learning rate scheduler configs
│   ├── callbacks/          # Lightning callback configs
│   └── strategy/           # Distributed training strategy configs (DDP, FSDP)
├── scripts/
│   ├── train_owt_vadd.sh          # Single-node training (local GPUs)
│   └── eval_owt_vadd.sh           # Sample evaluation
├── requirements.yaml       # Conda environment specification
└── LICENSE

Architecture Overview

VADD consists of two main components:

  • Encoder (InferDIT): Takes clean text $x_0$ and noisy text $x_t$, produces latent distribution parameters $(\mu, \log\sigma)$.
  • Decoder (GenDIT): Takes noisy text $x_t$ and a sampled latent $z$, predicts the clean text $x_0$.

The latent variable $z$ is injected into the decoder via Adaptive Layer Normalization (AdaLN), conditioning the transformer blocks on both the diffusion timestep and the latent.
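AdaLN rescales the normalized activations with a scale and shift predicted from the conditioning signal. Below is a minimal NumPy sketch, assuming a single conditioning vector formed from the timestep embedding and $z$; the function name, shapes, and projection matrices are illustrative, not the code in models/dit.py.

```python
import numpy as np

def ada_layer_norm(x, cond, W_scale, W_shift, eps=1e-5):
    """Adaptive LayerNorm: normalize x, then modulate it with a scale
    and shift predicted from a conditioning vector (in VADD: a function
    of the diffusion timestep embedding and the latent z).

    x:    (seq_len, d_model) token activations
    cond: (d_cond,) conditioning vector
    W_scale, W_shift: (d_cond, d_model) projections (learned in practice)
    """
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_norm = (x - mu) / np.sqrt(var + eps)   # plain LayerNorm statistics
    scale = cond @ W_scale                    # per-feature scale from the condition
    shift = cond @ W_shift                    # per-feature shift from the condition
    return (1.0 + scale) * x_norm + shift
```

With zero projections the block degenerates to ordinary LayerNorm, which is why the `1.0 +` parameterization is a common initialization choice for AdaLN.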


Getting Started

1. Environment Setup

conda env create -f requirements.yaml
conda activate vadd

Note: flash-attn requires a CUDA-compatible GPU and may need to be installed separately:

pip install flash-attn --no-build-isolation

2. Create Output Directories

mkdir -p outputs

Training

Single-node training (local GPUs)

# Usage: CUDA_VISIBLE_DEVICES=<gpus> bash scripts/train_owt_vadd.sh \
#   <batch_size> <seq_len> <num_particles> <encoder_training_times> \
#   <latent_struc> <mask_struc> <latent_dim>

# Example: 4 GPUs, batch_size=16, seq_len=1024, AdaLN structure, latent_dim=512
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/train_owt_vadd.sh 16 1024 1 1 adaln adaln 512

Resume from checkpoint

python main.py \
  <your_training_args> \
  checkpointing.resume_from_ckpt=True \
  checkpointing.resume_ckpt_path=/path/to/checkpoint.ckpt

Key training hyperparameters

Parameter                     Description                                     Default
training.loss_type            Loss function                                   vae
training.num_particles        Number of particles for the IWAE bound          1
training.loss_strategy        KL weight strategy (klweight, none)             klweight
training.init_kl_weight       Initial KL weight for warmup                    0.0
training.kl_weight_interval   Steps to linearly anneal the KL weight to 1.0   100000
model.latent_dim              Latent dimension                                512
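Assuming training.loss_strategy=klweight behaves as the defaults above suggest, the anneal can be sketched as a linear warmup from training.init_kl_weight to 1.0 over training.kl_weight_interval steps. This is an illustrative sketch, not the repository's implementation.

```python
def kl_weight(step, init_kl_weight=0.0, kl_weight_interval=100_000):
    """Linearly anneal the KL weight from init_kl_weight to 1.0 over
    kl_weight_interval steps, then hold it at 1.0.
    (Illustrative; defaults mirror the table above.)"""
    if step >= kl_weight_interval:
        return 1.0
    frac = step / kl_weight_interval
    return init_kl_weight + (1.0 - init_kl_weight) * frac
```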

Evaluation

Sample generation & generative perplexity

# Usage: bash scripts/eval_owt_vadd.sh <fixz: True/False> [checkpoint_path]
bash scripts/eval_owt_vadd.sh True /path/to/checkpoint.ckpt

This sweeps over diffusion steps (16, 32, 64, 128, 256, 512, 1024) and measures generative perplexity under GPT-2 Large.
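Generative perplexity is the exponentiated average negative log-likelihood that the evaluator LM (here GPT-2 Large) assigns to the generated tokens. A minimal sketch of the final reduction, assuming per-token log-probabilities are already available; `generative_perplexity` is an illustrative name, not the evaluation script's API.

```python
import math

def generative_perplexity(token_logprobs):
    """Perplexity = exp(-mean per-token log-probability).

    token_logprobs: per-token log-probs (natural log) of the generated
    text under the evaluator LM, e.g. GPT-2 Large.
    """
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# Sanity check: a uniform distribution over a 4-token vocabulary
# yields a perplexity of 4 (up to floating-point rounding).
```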

Zero-shot perplexity evaluation

Evaluate on any dataset config under configs/data/:

python main.py \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=lambada \
  model=small \
  parameterization=subs \
  backbone=dit \
  model.length=1024 \
  training.loss_type=vae \
  model.latentMLP=True \
  model.latent1d=True \
  model.latent_struc=adaln \
  model.latent_dim=512 \
  decoder.mask_struc=adaln \
  eval.checkpoint_path=/path/to/checkpoint.ckpt \
  +eval.ppl_times=5

Available zero-shot evaluation datasets:

  • lambada, wikitext2, wikitext103, ptb
  • ag_news, scientific_papers_arxiv, scientific_papers_pubmed
  • text8, lm1b

Sampling

Standard sampling (VAE decoder with prior $z \sim \mathcal{N}(0, I)$)

python main.py \
  mode=sample_eval \
  data=openwebtext-split \
  model=small \
  parameterization=subs \
  model.length=1024 \
  training.loss_type=vae \
  model.latentMLP=True \
  model.latent1d=True \
  model.latent_struc=adaln \
  model.latent_dim=512 \
  decoder.mask_struc=adaln \
  sampling.predictor=vae \
  sampling.steps=256 \
  sampling.fixz=True \
  loader.eval_batch_size=8 \
  sampling.num_sample_batches=4 \
  eval.checkpoint_path=/path/to/checkpoint.ckpt

Cached sampling (faster, with token caching)

Use sampling.predictor=vae_cache for faster generation:

python main.py \
  mode=sample_eval \
  ... \
  sampling.predictor=vae_cache \
  sampling.steps=1024 \
  eval.checkpoint_path=/path/to/checkpoint.ckpt

Acknowledgements

This repository is built upon MDLM and SEDD.

Citation

@inproceedings{xie2026vadd,
  title={Variational Autoencoding Discrete Diffusion with Enhanced Dimensional Correlations Modeling},
  author={Tianyu Xie and Shuchen Xue and Zijin Feng and Tianyang Hu and Jiacheng Sun and Zhenguo Li and Cheng Zhang},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=yh7MV2V0ba}
}

About

The official codebase for VADD (ICLR 2026)
