# Diffusion LLMs

Current closed model: 
[https://chat.inceptionlabs.ai/](https://chat.inceptionlabs.ai/)

# Large Language Diffusion Models (LLaDA)

_"Large Language Diffusion Models" by Nie, Zhu, et al. (arXiv:2502.09992)_

Proposes that "_Generative modeling principles_, rather than the autoregressive formulation [...] fundamentally underpin the essential properties of LLMs".

The paper introduces **masked diffusion models** (MDM), bidirectional BERT-like models that operate on entire sequences in parallel.

Instead of predicting the next token, a diffusion LLM must "de-mask" a corrupted sequence of text.

This is very similar to BERT! Except that we use a variable percentage of noise, rather than a fixed masking rate of 15%.
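The difference between a fixed and a variable masking rate can be sketched in a few lines of plain Python. This is a toy illustration, not the paper's code: the `corrupt` helper, the string `MASK` symbol, and drawing the rate uniformly from (0, 1) are all assumptions made for the example.

```python
import random

MASK = "[MASK]"  # placeholder mask symbol for this toy example

def corrupt(tokens, rate, rng):
    """Replace round(rate * len(tokens)) randomly chosen tokens (at least one) with MASK."""
    n_mask = max(1, round(rate * len(tokens)))
    masked_positions = set(rng.sample(range(len(tokens)), n_mask))
    return [MASK if i in masked_positions else t for i, t in enumerate(tokens)]

rng = random.Random(0)
tokens = "The capital of France is Paris".split()

# BERT-style: the same masking rate for every sequence.
bert_style = corrupt(tokens, 0.15, rng)

# Diffusion-style: the rate is resampled for every sequence (assumed uniform here).
diffusion_style = corrupt(tokens, rng.uniform(0.0, 1.0), rng)

print(bert_style)
print(diffusion_style)
```

With a resampled rate the model sees everything from nearly clean to nearly fully masked sequences, which is what lets it start generation from an all-`[MASK]` answer.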

In [3]:
# ids[i] = (
#     mask_id if dice < .8
#     else random.randint(0, tok.vocab_size - 1) if dice < .9
#     else ids[i]
# )

## Training

A major disadvantage of all "BERT-like" models: you only get a training signal from the masked tokens, which are a small portion of the sequence.

```
The [MASK] of France [MASK] Paris [EOS]
       ^                ^
  Only the masked tokens contribute to the loss
```

> For the tokens it selects, BERT replaces the token with `[MASK]` 80% of the time, swaps in a random vocabulary token 10% of the time, and leaves it unchanged the remaining 10%.

This is a real disadvantage. In autoregressive models every token predicts the next one, so all of them contribute to the loss.

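This isn't possible in BERT-like models, since there is no causal attention mask: every position can already see the tokens that are not masked, so only the masked positions have anything to predict.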
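To put rough numbers on this, the sketch below counts the fraction of positions that carry a loss term. It is only arithmetic under stated assumptions: a sequence length of 512, a BERT-style rate of 15%, and a variable rate drawn uniformly from 15-99%; the function names are made up for the example.

```python
import random

def supervised_fraction_masked(seq_len, rate):
    """Fraction of positions that carry a loss term when `rate` of the tokens are masked."""
    n_mask = max(1, int(seq_len * rate))
    return n_mask / seq_len

def supervised_fraction_autoregressive(seq_len):
    """Every position except the last one predicts its successor."""
    return (seq_len - 1) / seq_len

seq_len = 512
print(supervised_fraction_autoregressive(seq_len))  # 511 of 512 positions
print(supervised_fraction_masked(seq_len, 0.15))    # 76 of 512 positions

# Variable rate: average over many rates sampled from [0.15, 0.99].
rng = random.Random(0)
rates = [rng.uniform(0.15, 0.99) for _ in range(10_000)]
avg = sum(supervised_fraction_masked(seq_len, r) for r in rates) / len(rates)
print(round(avg, 2))  # close to 0.57, the midpoint of the rate range
```

So a variable masking rate recovers a good part of the signal compared with a fixed 15%, but still supervises fewer positions per sequence than next-token prediction.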

### A Minimal Training Script

Given a sequence of text, "corrupt" a variable percentage of its tokens with `[MASK]`. The model is trained with cross-entropy loss (CEL) to predict the original tokens that were masked.

In [None]:
import os, random, torch
from datasets import load_dataset
from accelerate import notebook_launcher
from transformers import (AutoTokenizer, AutoModelForMaskedLM,
                          DataCollatorWithPadding, TrainingArguments, Trainer)
def main():
    model_id = "answerdotai/ModernBERT-large"
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    mask_id, sep_id, sep = tok.mask_token_id, tok.sep_token_id, tok.sep_token
    tok.chat_template = (
        "User: {{ messages[0]['content'] }}\n" + sep + "\nAssistant:\n{% generation %}{{ messages[1]['content'] }}{% endgeneration %}")
    
    ## Dataset
    def preprocess(batch):
        def _single(messages):
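            # return_dict=True so that enc["input_ids"] below is valid (a plain list of ids is returned otherwise).
            enc = tok.apply_chat_template(messages, truncation=True, padding="max_length", max_length=512, return_dict=True)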
            ids, labels = enc["input_ids"], [-100] * len(enc["input_ids"])
            if sep_id not in ids: return None
            start = ids.index(sep_id) + 1
            cand = [i for i in range(start, len(ids))
                    if ids[i] not in (tok.pad_token_id, sep_id)]
            if not cand: return None
            # Each row we apply a random amount of noise (15-99%) and train as usual.
            n_mask = max(int(len(cand) * random.uniform(0.15, 0.99)), 1)
            for i in random.sample(cand, n_mask):
                labels[i] = ids[i] # If masked, we apply loss on the token.
                ids[i] = mask_id 
            return ids, labels
            
        mapped = (_single(m) for m in batch["messages"])
        filtered = (tup for tup in mapped if tup is not None)
        ids, labels = zip(*filtered)
        return {"input_ids": list(ids), "labels": list(labels)}
    
    ds = load_dataset("allenai/tulu-3-sft-mixture-0225", split="train")
    dd = (ds
        .map(preprocess, num_proc=32, batched=True, remove_columns=ds.column_names)
        .train_test_split(0.05, seed=42))
    
    ## Train
    project_name, run_name = "dllm", "modernbert-dllm-tulu"
    os.environ.setdefault("WANDB_PROJECT", run_name)
    args = TrainingArguments(
        run_name, bf16=True,
        per_device_train_batch_size=32, per_device_eval_batch_size=32,
        eval_strategy="steps", eval_steps=200, num_train_epochs=1,
        report_to="wandb", push_to_hub=True,
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=dd["train"], eval_dataset=dd["test"],
    )
    trainer.train()
    tok.push_to_hub(f"tommyp111/{run_name}")

notebook_launcher(main, num_processes=torch.cuda.device_count()) # DDP over number of ranks.

## Inference

`[Prompt] [MASK] [MASK] [MASK] [MASK] ... [EOS]`

### Iterative Denoising

At each iteration step, unmask the token the model is most confident about (greedy strategy).

To speed up generation, we may unmask multiple tokens per step.
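A minimal sketch of this loop, assuming a stand-in `predict` function in place of the model. `toy_predict` is hypothetical: it always predicts a fixed target sentence and is more confident next to tokens that are already known. `tokens_per_step` is likewise an illustrative parameter name.

```python
MASK = None  # sentinel for a masked position in this toy example

def denoise(ids, predict, tokens_per_step=1):
    """Iteratively fill masked positions, most confident predictions first.

    `predict(ids)` stands in for the model: it returns one
    (token, confidence) pair per position.
    """
    ids = list(ids)
    while MASK in ids:
        preds = predict(ids)
        masked = [i for i, t in enumerate(ids) if t is MASK]
        # Rank the masked positions by confidence, highest first.
        masked.sort(key=lambda i: preds[i][1], reverse=True)
        for i in masked[:tokens_per_step]:
            ids[i] = preds[i][0]
    return ids

TARGET = ["The", "capital", "of", "France", "is", "Paris"]

def toy_predict(ids):
    """Hypothetical model: predicts TARGET, more confident next to known tokens."""
    preds = []
    for i, target in enumerate(TARGET):
        neighbours = [ids[j] for j in (i - 1, i + 1) if 0 <= j < len(ids)]
        known = sum(t is not MASK for t in neighbours)
        preds.append((target, 0.5 + 0.2 * known))
    return preds

prompt = ["The"] + [MASK] * 5
print(denoise(prompt, toy_predict, tokens_per_step=1))  # 5 model calls
print(denoise(prompt, toy_predict, tokens_per_step=2))  # 3 model calls
```

Unmasking several tokens per step trades quality for speed: the tokens committed in the same step are chosen without seeing each other.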

### Block Generation

One of the key limitations of diffusion inference is that KV caching is not possible, since the attention is bidirectional and all-to-all.

This can be somewhat mitigated by generating "blocks" of tokens at a time, autoregressively.

- E.g. when generating 128 tokens, generate the first 32, add them to the KV cache, then generate the next 32, and so on.
- This also allows us to have a non-fixed response length.
- A block length of 1 is somewhat equivalent to an autoregressive transformer.
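The block schedule can be sketched as follows. This only illustrates the ordering (blocks left to right, confidence-based unmasking inside each block); it does not implement a KV cache, and `block_ranges`, `blockwise_decode`, and `toy_predict` are hypothetical names for this example.

```python
def block_ranges(answer_start, answer_length, block_length):
    """Split the answer span into consecutive [start, end) blocks."""
    end = answer_start + answer_length
    return [(s, min(s + block_length, end))
            for s in range(answer_start, end, block_length)]

def blockwise_decode(ids, mask_id, predict, answer_start, block_length):
    """Fill blocks left to right; inside a block, unmask by confidence."""
    ids = list(ids)
    for start, end in block_ranges(answer_start, len(ids) - answer_start, block_length):
        while any(ids[i] == mask_id for i in range(start, end)):
            preds = predict(ids)  # stand-in for one forward pass
            masked = [i for i in range(start, end) if ids[i] == mask_id]
            best = max(masked, key=lambda i: preds[i][1])
            ids[best] = preds[best][0]
        # Positions before `end` are final from here on; this is the point
        # at which the finished block would be added to the cache.
    return ids

def toy_predict(ids):
    """Hypothetical model: predicts token 100 + i at position i, less sure further right."""
    return [(100 + i, 1.0 / (1 + i)) for i in range(len(ids))]

print(block_ranges(answer_start=10, answer_length=128, block_length=32))
# [(10, 42), (42, 74), (74, 106), (106, 138)]
print(blockwise_decode([7, 8, -1, -1, -1, -1], -1, toy_predict, answer_start=2, block_length=2))
# [7, 8, 102, 103, 104, 105]
```

With `block_length=1` each block holds a single position, so tokens are committed strictly left to right, which is the autoregressive-like limit mentioned in the list above.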


### Remasking & CFG

Some more inference tricks:

We can remask generated tokens to give the model another attempt at predicting them.

This can be done randomly, or by remasking only the low-confidence tokens.
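A small sketch of the low-confidence variant, assuming we kept the probability the model assigned to each generated token. The function name, the `MASK_ID` value, and the toy numbers are made up for illustration.

```python
def remask_low_confidence(ids, confidences, mask_id, answer_start, n_remask):
    """Re-mask the `n_remask` least confident generated tokens.

    `confidences[i]` is the probability the model assigned to `ids[i]` when it
    was generated; prompt positions (< answer_start) are never touched.
    """
    generated = [i for i in range(answer_start, len(ids)) if ids[i] != mask_id]
    weakest = sorted(generated, key=lambda i: confidences[i])[:n_remask]
    return [mask_id if i in weakest else t for i, t in enumerate(ids)]

MASK_ID = -1  # hypothetical mask id for this toy example
ids         = [11, 12, 21, 22, 23, 24]        # two prompt tokens, four generated
confidences = [1.0, 1.0, 0.9, 0.2, 0.7, 0.4]
print(remask_low_confidence(ids, confidences, MASK_ID, answer_start=2, n_remask=2))
# [11, 12, 21, -1, 23, -1]
```

The re-masked positions then go through the usual denoising step again, now with more context filled in around them.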

---

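We can also use a form of "CFG" (classifier-free guidance).

In this case, that just means also passing the model only the partially generated response, without the prompt, and contrasting that prediction with the prompt-conditioned one.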
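One common way to combine the two predictions is to extrapolate from the prompt-free logits towards the prompt-conditioned ones. The sketch below assumes that parameterisation; the `scale` name and the exact formula are assumptions, not something stated in these notes.

```python
def cfg_logits(cond_logits, uncond_logits, scale):
    """Extrapolate from the prompt-free logits towards the prompt-conditioned ones."""
    return [u + (scale + 1.0) * (c - u) for c, u in zip(cond_logits, uncond_logits)]

cond   = [2.0, 0.5, -1.0]   # logits with the prompt visible
uncond = [1.0, 1.0, -1.0]   # logits with the prompt hidden

print(cfg_logits(cond, uncond, scale=0.0))  # [2.0, 0.5, -1.0] (no guidance)
print(cfg_logits(cond, uncond, scale=1.0))  # [3.0, 0.0, -1.0]
```

Tokens that become more likely once the prompt is visible are boosted, tokens the model would have produced anyway are left alone, and each step costs two forward passes instead of one.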

It's unclear whether these optimizations helped.

In [4]:
import sys
import time
import torch

PROMPT_TEMPLATE = "{cls_token_id} User: {user_prompt} {sep_token} Assistant:\n"

def mk_assistant_message(tokenizer, answer_length, random_toks):
    if random_toks:
        assistant_init_ids = torch.randint(len(tokenizer.vocab), (answer_length,)).tolist()
    else:
        assistant_init_ids = [tokenizer.mask_token_id] * answer_length

    return tokenizer.decode(assistant_init_ids)

def mk_tokens(tokenizer, prompt, answer_length, random_toks):
    # Unused alternative to mk_assistant_message: builds prompt + placeholder answer in one call.
    if random_toks:
        assistant_init_ids = torch.randint(len(tokenizer.vocab), (answer_length,)).tolist()
    else:
        assistant_init_ids = [tokenizer.mask_token_id] * answer_length  # ids (not strings), so decode() works

    return tokenizer.apply_chat_template(
        [{"content": prompt}, {"content": tokenizer.decode(assistant_init_ids)}])


@torch.no_grad()
def iter_mask_decode(model, tokenizer, prompt: str, assistant_message, answer_length = 32, random_toks = False):
    """Generator: yields (token_id, position_in_answer) for each token it unmasks, most confident first."""
    toks_dict = tokenizer.apply_chat_template(
        [{"content": prompt}, {"content": assistant_message}],
        return_dict=True, return_assistant_tokens_mask=True, return_tensors="pt")
    ids = toks_dict['input_ids'][0].tolist()
    assistant_mask = toks_dict['assistant_masks']

    answer_start =  assistant_mask[0].nonzero().min().item()
    answer_end = answer_start + answer_length
    device = next(model.parameters()).device
    idxs = []
    for step in range(answer_length):
        logits = model(input_ids=torch.tensor([ids]).to(device)).logits
        out_probs = torch.softmax(logits[0], dim=-1)
        if random_toks:
            candidate_locs = torch.arange(answer_start, answer_end)
        else:
            candidate_locs = (
                (torch.tensor(ids) == tokenizer.mask_token_id)
                .nonzero(as_tuple=True)[0]
            )
        
        assert len(candidate_locs) != 0, "out of masks"
        
        candidate_probs = out_probs[candidate_locs]
        candidate_max_probs = candidate_probs.max(dim=-1)[0]

        # Argmax top prob
        idx = candidate_max_probs.argmax()
        pos = candidate_locs[idx]
        new_token = candidate_probs[idx].argmax().item()
        ids[pos] = new_token

        # Token, Idx in answer mask
        yield new_token, pos.item() - answer_start


def iterative_print(
    model,
    tokenizer,
    user_prompt,
    answer_length,
    delay=0.0,
    random_toks=False,
):
    def _print_step(resp, n_clear):
        "1) move to start, 2) blank the full width, 3) move back, 4) write new text"
        resp = resp.encode('unicode_escape').decode('ascii')
        blank = " " * n_clear
        sys.stdout.write("\r" + blank + "\r" + resp)
        sys.stdout.flush()
        return len(resp)

    prompt = PROMPT_TEMPLATE.format(
        cls_token_id=tokenizer.cls_token,
        user_prompt=user_prompt,
        sep_token=tokenizer.sep_token,
    )
    print(prompt, end="")
    
    assistant_message = mk_assistant_message(tokenizer, answer_length, random_toks)
    tok_iter = iter_mask_decode(model, tokenizer, prompt, assistant_message, answer_length, random_toks)
    
    # masked = tokenizer.convert_ids_to_tokens(tokenizer.encode(assistant_message))
    
    masked = ["[MASK]"] * answer_length  # placeholder text for positions not decoded yet
    n_clear = 0
    for tok, pos in tok_iter:
        masked[pos] = tokenizer.decode(tok)
        # print(tokenizer.decode(tok))
        resp = "".join(masked)
        # print(pos, resp)
        n_clear = _print_step(resp, n_clear)
        time.sleep(delay)

In [5]:
import os, random, itertools, math, torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device =  (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

model_id = "tommyp111/modernbert-dllm-tulu" # https://wandb.ai/graphcore/modernbert-dllm-tulu/runs/nfblha6t?nw=nwusertompollak
# model_id = "tommyp111/modernbert-flowlm-tulu" # https://wandb.ai/graphcore/dllm/runs/ldah4164?nw=nwusertompollak
# model_id = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()


prompts = [
    "Who is the greatest soccer player?",
    "What is the meaning of life?",
    "What is the best place to go for a holiady?",
    "x = 2y + 6; y = 3 * 4. Find x.",
    "What is singular value decomposition?",
    "What is the captial of Australia?",
]
for p in prompts:
    iterative_print(model, tokenizer, p, answer_length=25, delay=0.1, random_toks=False)
    print("\n", "*" * 25, "\n")

[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
\nThe greatest soccer player in history is the Brazilian football player, Ronaldinho Jr, who was born in 1986.[SEP].                                
 ************************* 

[CLS] User: What is the meaning of life? [SEP] Assistant:
\nThe meaning of life is to live, to love, to learn, to give, to share, to grow.[SEP]                                                             
 ************************* 

[CLS] User: What is the best place to go for a holiady? [SEP] Assistant:
\nThe best place to go for a holiady is a beach. The beach is full of sun and fun.[SEP]                                                              
 ************************* 

[CLS] User: x = 2y + 6; y = 3 * 4. Find x. [SEP] Assistant:
\nx = 2y + 6; y = 3 * 4.\nThe solution is:\nx = 2y[SEP]                                                                                              
 ************************* 

[CLS] User: What is singular valu

# Diffusion Duality

_"The Diffusion Duality" by Sahoo, Deschenaux, et al. (arXiv:2506.10892)_


Instead of replacing tokens with `[MASK]`, corrupt the input embeddings with noise; the model is trained to denoise them back to the original embedding.

- They argue that cross-entropy loss (CEL) with "hard" one-hot targets can lead to high-variance gradients, making training unstable when the input is very noisy.
- The model's predictions swing wildly between steps as it tries to reach that certainty.

### Duo Corruption

```python
clean_weight = 1 - random.uniform(0.03, 0.15)
noise_weight = sqrt(1 - clean_weight ** 2)

hot = one_hot(input_ids) # (B,S,V)
w = clean_weight * hot + noise_weight * randn_like(hot)
soft_latents = softmax(w / temperature) # as temperature -> 0, this becomes one_hot(argmax(w)): the original token unless the noise flips the argmax

soft_embs = soft_latents @ model.W_embedding # (B,S,V) @ (V,D) -> (B,S,D)
logits = model(input_embs=soft_embs)
```

Temperature is annealed from `1e-3` to `0` over the course of training.

This involves a very large matmul whose inner dimension K is the vocab size.
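For a single token, the corruption above can be run end to end in plain Python. This is a toy sketch of the pseudocode, not the paper's implementation: `soft_embedding`, the 5x2 embedding table, and the fixed `clean_weight=0.9` are assumptions chosen to keep the example small.

```python
import math
import random

def softmax(xs, temperature):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    m = max(xs)
    exps = [math.exp((x - m) / temperature) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def soft_embedding(token_id, embedding, clean_weight, temperature, rng):
    """Noisy one-hot -> soft distribution over the vocab -> mixture of embedding rows."""
    vocab_size = len(embedding)
    noise_weight = math.sqrt(1.0 - clean_weight ** 2)
    # Scaled one-hot plus Gaussian noise, one value per vocab entry.
    w = [clean_weight * (v == token_id) + noise_weight * rng.gauss(0.0, 1.0)
         for v in range(vocab_size)]
    latent = softmax(w, temperature)                      # (V,)
    dim = len(embedding[0])
    emb = [sum(latent[v] * embedding[v][d] for v in range(vocab_size))
           for d in range(dim)]                           # (V,) @ (V, D) -> (D,)
    return latent, emb

rng = random.Random(0)
embedding = [[float(v), float(-v)] for v in range(5)]     # toy table: V=5, D=2
latent, emb = soft_embedding(2, embedding, clean_weight=0.9, temperature=1e-3, rng=rng)
print(latent)
print(emb)
```

The inner sum over `vocab_size` is the expensive part: for every position it touches every row of the embedding table, which is the large matmul noted above.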

### Duo Loss

```python
p_log  = log_softmax(logits)
p_prob = exp(p_log)

q_prob = one_hot(target_ids)
q_log  = log(clamp(q_prob, min=1e-8))  # clamp first: log(0) = -inf would make kl_qp blow up

kl_pq  = kl_div(p_log, q_prob)      # (B,S,V)
kl_qp  = kl_div(q_log, p_prob)      # (B,S,V)
sym_kl = sum(kl_pq + kl_qp, dim=-1) # (B,S)

# As clean_weight -> 1, view_scale -> 0 (nearly clean inputs barely count); as clean_weight -> 0, view_scale -> 1 and the loss counts fully.
view_scale = (1.0 - clean_weight) / (vocab_size * clean_weight + 1.0 - clean_weight)
loss = (view_scale * sym_kl).sum()
```
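A runnable single-position version of the loss, in plain Python. It is a sketch under assumptions: probabilities are clamped at a small `eps` so the logarithm of a one-hot target stays finite, and the vocab size of 50,000 is only an example value.

```python
import math

def sym_kl(p, q, eps=1e-8):
    """KL(p || q) + KL(q || p) over the vocab, with probabilities clamped at eps."""
    p = [max(x, eps) for x in p]
    q = [max(x, eps) for x in q]
    kl_pq = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
    kl_qp = sum(qi * (math.log(qi) - math.log(pi)) for pi, qi in zip(p, q))
    return kl_pq + kl_qp

def view_scale(clean_weight, vocab_size):
    """Per-example loss weight, exactly as written in the pseudocode above."""
    return (1.0 - clean_weight) / (vocab_size * clean_weight + 1.0 - clean_weight)

target = [0.0, 1.0, 0.0]        # one-hot target
good   = [0.05, 0.90, 0.05]     # prediction close to the target
bad    = [0.60, 0.20, 0.20]     # prediction far from the target
print(sym_kl(good, target) < sym_kl(bad, target))  # True

for cw in (0.0, 0.85, 0.97, 1.0):
    print(cw, view_scale(cw, vocab_size=50_000))
```

Evaluating `view_scale` at the endpoints gives 1.0 at `clean_weight = 0` and 0.0 at `clean_weight = 1`, and with a large vocab the weight is already tiny across the 0.85-0.97 range used in the corruption step.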