This repository implements VADD (Variational Autoencoding Discrete Diffusion), which extends MDLM (Masked Diffusion Language Models) with a variational autoencoder (VAE) framework. By introducing a latent variable, VADD enhances the modeling of correlations across token dimensions that a purely factorized masked-diffusion denoiser treats as independent.
```
.
├── main.py               # Entry point: training, PPL evaluation, sample generation
├── diffusion.py          # Diffusion process, VAE loss, sampling procedures
├── dataloader.py         # Dataset loading, tokenization, data utilities
├── noise_schedule.py     # Noise schedules (loglinear, linear, cosine, etc.)
├── utils.py              # LR scheduler, logging, fsspec helpers
├── models/
│   ├── __init__.py
│   ├── dit.py            # Transformer backbones: GenDIT (decoder), InferDIT (encoder)
│   ├── autoregressive.py # AR baseline (for reference)
│   └── ema.py            # Exponential Moving Average
├── configs/
│   ├── config.yaml       # Main Hydra config
│   ├── data/             # Dataset configs (OpenWebText, WikiText, LM1B, etc.)
│   ├── model/            # Model size configs (small, medium, tiny)
│   ├── decoder/          # Decoder (encoder network) configs
│   ├── noise/            # Noise schedule configs
│   ├── lr_scheduler/     # Learning rate scheduler configs
│   ├── callbacks/        # Lightning callback configs
│   └── strategy/         # Distributed training strategy configs (DDP, FSDP)
├── scripts/
│   ├── train_owt_vadd.sh # Single-node training (local GPUs)
│   └── eval_owt_vadd.sh  # Sample evaluation
├── requirements.yaml     # Conda environment specification
└── LICENSE
```
VADD consists of two main components:

- **Encoder (InferDIT)**: takes the clean text $x_0$ and the noisy text $x_t$, and produces the latent distribution parameters $(\mu, \log\sigma)$.
- **Decoder (GenDIT)**: takes the noisy text $x_t$ and a sampled latent $z$, and predicts the clean text $x_0$.

The latent variable $z$ conditions the decoder on information shared across all token positions, which is what allows VADD to model dimensional correlations.
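As a rough illustration of the data flow, here is a minimal NumPy sketch with made-up shapes and stand-in networks — not the repository's actual `InferDIT`/`GenDIT` modules:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, seq_len, latent_dim = 50, 8, 4
MASK = vocab - 1  # stand-in: use the last vocab id as the mask token

def infer_dit(x0, xt):
    # Stand-in encoder: the real model is a transformer over (x0, xt).
    h = np.tanh(np.concatenate([x0, xt]).astype(float))
    mu = np.full(latent_dim, h.mean())
    log_sigma = np.full(latent_dim, -1.0)
    return mu, log_sigma

def gen_dit(xt, z):
    # Stand-in decoder: logits over the vocabulary at every position.
    return rng.normal(size=(seq_len, vocab)) + z.mean()

x0 = rng.integers(0, vocab, size=seq_len)           # clean tokens
xt = x0.copy()
xt[rng.random(seq_len) < 0.5] = MASK                # mask a random subset

mu, log_sigma = infer_dit(x0, xt)
z = mu + np.exp(log_sigma) * rng.normal(size=latent_dim)  # reparameterization
logits = gen_dit(xt, z)                                   # predict x0
x0_pred = logits.argmax(axis=-1)
print(x0_pred.shape)  # (8,)
```

The reparameterization step ($z = \mu + \sigma \epsilon$) is what makes the latent sampling differentiable during training.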
```
conda env create -f requirements.yaml
conda activate vadd
```

Note: `flash-attn` requires a CUDA-compatible GPU and may need to be installed separately:

```
pip install flash-attn --no-build-isolation
```
```
mkdir -p outputs

# Usage: CUDA_VISIBLE_DEVICES=<gpus> bash scripts/train_owt_vadd.sh \
#   <batch_size> <seq_len> <num_particles> <encoder_training_times> \
#   <latent_struc> <mask_struc> <latent_dim>
# Example: 4 GPUs, batch_size=16, seq_len=1024, AdaLN structure, latent_dim=512
CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/train_owt_vadd.sh 16 1024 1 1 adaln adaln 512
```

To resume training from a checkpoint:

```
python main.py \
  <your_training_args> \
  checkpointing.resume_from_ckpt=True \
  checkpointing.resume_ckpt_path=/path/to/checkpoint.ckpt
```

| Parameter | Description | Default |
|---|---|---|
| `training.loss_type` | Loss function | `vae` |
| `training.num_particles` | Number of particles for IWAE bound | 1 |
| `training.loss_strategy` | KL weight strategy (`klweight`, `none`) | `klweight` |
| `training.init_kl_weight` | Initial KL weight for warmup | 0.0 |
| `training.kl_weight_interval` | Steps to linearly anneal KL weight to 1.0 | 100000 |
| `model.latent_dim` | Latent dimension | 512 |
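Under the `klweight` strategy, the defaults above suggest the KL term's weight ramps linearly from `training.init_kl_weight` to 1.0 over `training.kl_weight_interval` steps. A sketch of such a warmup schedule (an illustration of the config semantics; the actual implementation lives in `diffusion.py` and may differ):

```python
def kl_weight(step, init_kl_weight=0.0, kl_weight_interval=100_000):
    """Linearly anneal the KL weight from init_kl_weight to 1.0."""
    if step >= kl_weight_interval:
        return 1.0
    frac = step / kl_weight_interval
    return init_kl_weight + (1.0 - init_kl_weight) * frac

print(kl_weight(0))        # 0.0
print(kl_weight(50_000))   # 0.5
print(kl_weight(200_000))  # 1.0
```

Warming the KL weight up from 0 is a standard trick to prevent posterior collapse early in VAE training.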
```
# Usage: bash scripts/eval_owt_vadd.sh <fixz: True/False> [checkpoint_path]
bash scripts/eval_owt_vadd.sh True /path/to/checkpoint.ckpt
```

This sweeps over diffusion steps (16, 32, 64, 128, 256, 512, 1024) and measures generative perplexity under GPT-2 Large.
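Generative perplexity here is the exponentiated mean negative log-likelihood that the judge model (GPT-2 Large) assigns to the generated tokens. The arithmetic, with made-up log-probabilities rather than real model outputs:

```python
import math

def generative_ppl(token_logprobs):
    """Perplexity = exp(-mean log-likelihood per token)."""
    nll = -sum(token_logprobs) / len(token_logprobs)
    return math.exp(nll)

# Example: the judge assigns log-prob -2.0 to every generated token
print(generative_ppl([-2.0] * 10))  # ~7.389 (= e^2)
```

Lower is better; fluent text earns higher log-probs from the judge and hence lower perplexity.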
Evaluate on any dataset config under `configs/data/`:

```
python main.py \
  mode=ppl_eval \
  loader.batch_size=16 \
  loader.eval_batch_size=16 \
  data=lambada \
  model=small \
  parameterization=subs \
  backbone=dit \
  model.length=1024 \
  training.loss_type=vae \
  model.latentMLP=True \
  model.latent1d=True \
  model.latent_struc=adaln \
  model.latent_dim=512 \
  decoder.mask_struc=adaln \
  eval.checkpoint_path=/path/to/checkpoint.ckpt \
  +eval.ppl_times=5
```

Available zero-shot evaluation datasets: `lambada`, `wikitext2`, `wikitext103`, `ptb`, `ag_news`, `scientific_papers_arxiv`, `scientific_papers_pubmed`, `text8`, `lm1b`
```
python main.py \
  mode=sample_eval \
  data=openwebtext-split \
  model=small \
  parameterization=subs \
  model.length=1024 \
  training.loss_type=vae \
  model.latentMLP=True \
  model.latent1d=True \
  model.latent_struc=adaln \
  model.latent_dim=512 \
  decoder.mask_struc=adaln \
  sampling.predictor=vae \
  sampling.steps=256 \
  sampling.fixz=True \
  loader.eval_batch_size=8 \
  sampling.num_sample_batches=4 \
  eval.checkpoint_path=/path/to/checkpoint.ckpt
```

Use `sampling.predictor=vae_cache` for faster generation:
```
python main.py \
  mode=sample_eval \
  ... \
  sampling.predictor=vae_cache \
  sampling.steps=1024 \
  eval.checkpoint_path=/path/to/checkpoint.ckpt
```

This repository is built upon MDLM and SEDD.
```bibtex
@inproceedings{
  xie2026vadd,
  title={Variational Autoencoding Discrete Diffusion with Enhanced Dimensional Correlations Modeling},
  author={Tianyu Xie and Shuchen Xue and Zijin Feng and Tianyang Hu and Jiacheng Sun and Zhenguo Li and Cheng Zhang},
  booktitle={The Fourteenth International Conference on Learning Representations},
  year={2026},
  url={https://openreview.net/forum?id=yh7MV2V0ba}
}
```