# Large Language Diffusion Models (LLaDA)

The paper introduces **masked diffusion models** (MDM), bidrectional BERT-like models that operate on entire sequences in parallel.

## Inference

`[Prompt] [MASK] [MASK] [MASK] [MASK] [EOS]`

At each time step the model can choose which token to demask (e.g. greedily unmask the token with the highest confidence in a token).

- Can unmask multiple tokens per timestep, speeding up generation

### Generating in blocks

These can be somewhat mitigated by generating "blocks" of output autoregressively at a time -- aka generating the first 32 masked tokens, then the next 32. This allows us to have a non-fixed response length, and still use the KV cache for most of the output.

You can think about a block length of 1 being a true autoregressive transformer.


## Training

Given a sequence of text, "corrupt" a variable percentage with `[MASK]`. Model is trained to predict the original tokens that were masked with CEL.

- can remask randomly
- They do "CFG" where they combine the probability of the token with no prompt.


- Should loss be based on scaling ratio? 99% is much harder than 15%. I found the loss curve is very jagged.

In [None]:
from transformers import TrainingArguments

In [88]:

tok = AutoTokenizer.from_pretrained(model_id)
print(tok.chat_template)

None


In [None]:
tok.

In [None]:
import os, random, torch
from datasets import load_dataset
from accelerate import notebook_launcher
from transformers import (AutoTokenizer, AutoModelForMaskedLM,
                          DataCollatorWithPadding, TrainingArguments, Trainer)
def main():
    model_id = "answerdotai/ModernBERT-large"
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    mask_id, sep_id, sep = tok.mask_token_id, tok.sep_token_id, tok.sep_token
    tok.chat_template = (
        "User: {{ messages[0]['content'] }}\n" + sep + "\nAssistant:\n{{ messages[1]['content'] }}")
    
    ## Dataset
    def preprocess(batch):
        def _single(messages):
            text = tok.apply_chat_template(messages, tokenize=False)
            enc = tok(text, truncation=True, padding="max_length", max_length=512)
            ids, labels = enc["input_ids"], [-100] * len(enc["input_ids"])
            if sep_id not in ids: return None
            start = ids.index(sep_id) + 1
            cand = [i for i in range(start, len(ids))
                    if ids[i] not in (tok.pad_token_id, sep_id)]
            if not cand: return None
            n_mask = max(int(len(cand) * random.uniform(0.15, 0.99)), 1)
            for i in random.sample(cand, n_mask):
                labels[i] = ids[i]
                dice = random.random()
                ids[i] = (
                    mask_id if dice < .8
                    else random.randint(0, tok.vocab_size - 1) if dice < .9
                    else ids[i]
                )
            return ids, labels
            
        mapped = (_single(m) for m in batch["messages"])
        filtered = (tup for tup in mapped if tup is not None)
        ids, labels = zip(*filtered)
        return {"input_ids": list(ids), "labels": list(labels)}
    
    ds = load_dataset("allenai/tulu-3-sft-mixture-0225", split="train")
    dd = (ds
        .map(preprocess, num_proc=32, batched=True, remove_columns=ds.column_names)
        .train_test_split(0.05, seed=42))
    
    ## Train
    project_name, run_name = "dllm", "modernbert-dllm-tulu"
    os.environ.setdefault("WANDB_PROJECT", run_name)
    args = TrainingArguments(
        run_name,
        per_device_train_batch_size=32, per_device_eval_batch_size=32,
        bf16=True, remove_unused_columns=False,
        eval_strategy="steps", eval_steps=200, num_train_epochs=1,
        report_to="wandb", push_to_hub=True,
    )
    trainer = Trainer(
        model=model, args=args,
        train_dataset=dd["train"], eval_dataset=dd["test"],
    )
    trainer.train()
    trainer.push_to_hub()

notebook_launcher(main, num_processes=torch.cuda.device_count())

## Diffusion Duality

In [84]:
import sys
import time
import torch

PROMPT_TEMPLATE = "{cls_token_id} User: {user_prompt} {sep_token} Assistant:\n"


def mk_tokens(tokenizer, prompt, answer_length, random_toks):
    user_prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if random_toks:
        assistant_init_ids = torch.randint(len(tokenizer.vocab), (answer_length,)).tolist()
    else:
        assistant_init_ids = [tokenizer.mask_token_id] * answer_length

    return user_prompt_ids, assistant_init_ids + [tokenizer.sep_token_id]


@torch.no_grad()
def iter_mask_decode(model, tokenizer, prompt: str, answer_length = 32, random_toks = False) -> str:
    user_prompt_ids, assistant_init_ids = mk_tokens(tokenizer, prompt, answer_length, random_toks)
    ids = user_prompt_ids + assistant_init_ids
    
    answer_start = len(user_prompt_ids)
    answer_end = answer_start + answer_length

    device = next(model.parameters()).device
    idxs = []
    for step in range(answer_length):
        logits = model(input_ids=torch.tensor([ids]).to(device)).logits
        out_probs = torch.softmax(logits[0], dim=-1)
        if random_toks:
            candidate_locs = torch.arange(answer_start, answer_end)
        else:
            candidate_locs = (
                (torch.tensor(ids) == tokenizer.mask_token_id)
                .nonzero(as_tuple=True)[0]
            )
        
        assert len(candidate_locs) != 0, "out of masks"
        
        candidate_probs = out_probs[candidate_locs]
        candidate_max_probs = candidate_probs.max(dim=-1)[0]

        # Argmax top prob
        idx = candidate_max_probs.argmax()
        pos = candidate_locs[idx]
        new_token = candidate_probs[idx].argmax().item()
        ids[pos] = new_token

        # Token, Idx in answer mask
        yield new_token, pos.item() - answer_start


def iterative_print(
    model,
    tokenizer,
    user_prompt,
    answer_length,
    delay=0.0,
    random_toks=False,
    decode_func=iter_mask_decode,
):
    def _print_step(resp, n_clear):
        "1) move to start, 2) blank the full width, 3) move back, 4) write new text"
        resp = resp.encode('unicode_escape').decode('ascii')
        blank = " " * n_clear
        sys.stdout.write("\r" + blank + "\r" + resp)
        sys.stdout.flush()
        return len(resp)

    prompt = PROMPT_TEMPLATE.format(
        cls_token_id=tokenizer.cls_token,
        user_prompt=user_prompt,
        sep_token=tokenizer.sep_token,
    )
    print(prompt, end="")
    
    tok_iter = iter_mask_decode(model, tokenizer, prompt, answer_length, random_toks)
    
    masked = [""] * answer_length # ["[MASK]"] * answer_length
    n_clear = 0
    for tok, pos in tok_iter:
        masked[pos] = tokenizer.decode(tok)
        # print(tokenizer.decode(tok))
        resp = "".join(masked)
        # print(pos, resp)
        n_clear = _print_step(resp, n_clear)
        time.sleep(delay)

In [81]:
import os, random, itertools, math, torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device =  (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

# model_id = "tommyp111/modernbert-flowlm"
model_id = "tommyp111/modernbert-diffusion"
# model_id = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

In [82]:
user_prompt = "What is the nature of life?"
iterative_print(model, tokenizer, user_prompt, answer_length=8, delay=0.1, random_toks=True)

[CLS] User: What is the nature of life? [SEP] Assistant:
7
.7
.7
.7
.7
.7
.7
.7
.

In [83]:
user_prompt = "Who is the greatest soccer player?"
for i in range(1, 16):
    print("*" * 10, i, "*" * 10)
    iterative_print(model, tokenizer, user_prompt, answer_length=i, delay=0.0)
    print()

********** 1 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
0
Unknown
********** 2 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
0
Michael1
Michael Jordan
********** 3 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
0
M2
Minho1
Mourinho
********** 4 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
3
.0
M.2
Minho.1
Mourinho.
********** 5 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
4
.2
 Mess.3
 Messi.1
el Messi.0
 Lionel Messi.
********** 6 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
5
.1
 is.0
This is.4
This is question.2
This is a question.3
This is a subjective question.
********** 7 **********
[CLS] User: Who is the greatest soccer player? [SEP] Assistant:
6
.1
 is.0
There is.2
There is no.4
There is no soccer.5
There is no soccer player.3
There is no greatest soccer player.
********** 8 **********
[CLS] User: Who