Learn masked diffusion language model (MDLM)

The model design and architecture here are taken from this paper: MDLM paper

I tried to simulate the Karpathy "Let's Build GPT" experience.

[Figure: visualizing the denoising process during text generation]

Code and Content

  • learning_guide.md is the central learning guide generated by Opus 4.5

  • train_model.ipynb contains the model and training-loop code, the relevant math (worked through by me), and other notes that are helpful for learning the ins and outs of the model implementation and the underlying concept

  • sampling.ipynb contains the inference code for testing trained checkpoints

  • bpe_tokenizer.py is a simple byte pair encoding tokenizer

  • sample_visual.py is a vibe-coded pygame script that visualizes the model's denoising (de-masking) process during inference, and serves as a reminder that this really is a diffusion model, not just a transformer encoder :)

Deviations From Guide

There are a few things I decided to implement along the way that deviate from the learning guide:

  • Using RMSNorm, SwiGLU, and RoPE in place of LayerNorm, GELU, and simple positional embeddings, respectively

  • Using a BPE tokenizer instead of a simple character-level tokenizer

  • Using mixed precision and gradient accumulation to speed up training

  • Using the original ancestral sampling instead of top-k, to stay true to the soul of the model (see the sampling sketch after this list)

  • Training on Harry Potter instead of Tiny Shakespeare, just for fun

  • Maybe other stuff too, lost count
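For concreteness, here is a minimal sketch of what ancestral sampling for an MDLM can look like, assuming a linear masking schedule α_t = 1 − t; `model`, `MASK_ID`, and the step count are placeholders, not the exact code in sampling.ipynb.

```python
import torch

MASK_ID = 0  # placeholder mask-token id

@torch.no_grad()
def ancestral_sample(model, seq_len, num_steps=128, device="cpu"):
    """Reverse-diffusion (ancestral) sampling sketch for an MDLM."""
    # Start from the fully masked sequence at t = 1
    x = torch.full((1, seq_len), MASK_ID, dtype=torch.long, device=device)
    ts = torch.linspace(1.0, 0.0, num_steps + 1)

    for i in range(num_steps):
        t, s = ts[i].item(), ts[i + 1].item()
        still_masked = x == MASK_ID
        if not still_masked.any():
            break

        # The encoder predicts clean tokens for the whole sequence
        logits = model(x)  # (1, seq_len, vocab)
        sampled = torch.distributions.Categorical(logits=logits).sample()

        # Under alpha_t = 1 - t, a masked token is revealed at this step
        # with probability (alpha_s - alpha_t) / (1 - alpha_t) = (t - s) / t;
        # revealed tokens are never re-masked (absorbing state, in reverse)
        reveal = torch.rand(1, seq_len, device=device) < (t - s) / max(t, 1e-8)
        x = torch.where(still_masked & reveal, sampled, x)

    # Anything still masked at t = 0 gets the model's prediction
    if (x == MASK_ID).any():
        x = torch.where(x == MASK_ID, model(x).argmax(dim=-1), x)
    return x
```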

Note

The entire setup and training is a bit casual; the goal is not to produce production code here, just to learn diffusion LMs.

Masked diffusion LM is kinda cool. It's technically a diffusion model, but the model architecture is essentially a bidirectional transformer encoder. Because it corrupts data with masks, the forward process has the "absorbing state" property (once a token is masked at time t, it stays masked at every t' > t), so mathematically the whole noising process collapses: you can jump to any intermediate state without simulating the steps in between. Monte Carlo sampling over the diffusion time approximates the integral in the objective, and after the elegant math the loss reduces to the average cross-entropy over masked tokens (with a schedule-dependent weight). The math beautifully collapses in on itself, and what's left is a simple, clean implementation that looks suspiciously like BERT.
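To make that collapse concrete, here is a rough sketch of a single training step under those simplifications, assuming a linear schedule α_t = 1 − t (so the continuous-time weight on masked-token cross entropy is 1/t); `model`, `MASK_ID`, and `VOCAB_SIZE` are placeholders rather than the repo's actual names.

```python
import torch
import torch.nn.functional as F

MASK_ID = 0        # placeholder mask-token id
VOCAB_SIZE = 8192  # placeholder vocab size

def mdlm_step(model, x0):
    """One MDLM training step (sketch); x0: (B, L) clean token ids."""
    B, L = x0.shape
    # Monte Carlo sample of the diffusion time, one t per sequence
    t = torch.rand(B, 1, device=x0.device).clamp_min(1e-3)

    # Forward (absorbing) process: each token is masked independently
    # with probability 1 - alpha_t = t; no intermediate steps are simulated
    masked = torch.rand(B, L, device=x0.device) < t
    xt = torch.where(masked, torch.full_like(x0, MASK_ID), x0)

    # The BERT-like encoder predicts the clean token at every position
    logits = model(xt)  # (B, L, VOCAB_SIZE)

    # Cross entropy on masked positions only, weighted by 1/t
    ce = F.cross_entropy(
        logits.view(-1, VOCAB_SIZE), x0.view(-1), reduction="none"
    ).view(B, L)
    loss = (ce * masked / t).sum() / (B * L)
    return loss
```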

The "Soul" of the Model is at that little touch of time factor.

Great fun learning and training it!
