# Large Language Diffusion Models (LLaDA)

The paper introduces **masked diffusion models**, bidrectional BERT-like models that operate 


## Diffusion Duality

In [20]:
len(tokenizer.vocab)

50368

In [36]:
torch.randint(len(tokenizer.vocab), (32,)).tolist()

[33986,
 29625,
 42119,
 36341,
 42868,
 19619,
 4407,
 13624,
 24001,
 15302,
 33402,
 549,
 42438,
 3555,
 32822,
 465,
 36802,
 17610,
 22496,
 12877,
 15244,
 28122,
 45055,
 22231,
 3860,
 43852,
 10751,
 9378,
 1102,
 32589,
 17177,
 36561]

In [54]:
import sys
import time
import torch

PROMPT_TEMPLATE = "{cls_token_id} User: {user_prompt} {sep_token} Assistant:\n"

@torch.no_grad()
def iter_mask_decode(model, tokenizer, prompt: str, answer_length = 32, random_toks = False) -> str:
    user_prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if random_toks:
        assistant_init_ids = torch.randint(len(tokenizer.vocab), (answer_length,)).tolist()
    else:
        assistant_init_ids = [tokenizer.mask_token_id] * answer_length
        
    ids = (
        user_prompt_ids
        + assistant_init_ids
        + [tokenizer.sep_token_id]
    )

    answer_start = len(user_prompt_ids)
    answer_end = answer_start + answer_length

    device = next(model.parameters()).device
    idxs = []
    for step in range(answer_length):
        logits = model(input_ids=torch.tensor([ids]).to(device)).logits
        out_probs = torch.softmax(logits[0], dim=-1)
        if random_toks:
            mask_locs = torch.arange(answer_start, answer_end)
        else:
            mask_locs = (
                (torch.tensor(ids) == tokenizer.mask_token_id)
                .nonzero(as_tuple=True)[0]
            )
        
        assert len(mask_locs) != 0, "out of masks"
        
        candidate_probs = out_probs[mask_locs]
        candidate_max_probs = candidate_probs.max(dim=-1)[0]

        # Argmax top prob
        idx = candidate_max_probs.argmax()
        pos = mask_locs[idx]
        new_token = candidate_probs[idx].argmax().item()
        ids[pos] = new_token

        # Token, Idx in answer mask
        yield new_token, pos.item() - answer_end

def iterative_print(
    model,
    tokenizer,
    user_prompt,
    answer_length,
    delay=0.0,
    random_toks=False,
    decode_func=iter_mask_decode,
):
    def _print_step(resp, n_clear):
        "1) move to start, 2) blank the full width, 3) move back, 4) write new text"
        resp = resp.encode('unicode_escape').decode('ascii')
        blank = " " * n_clear
        sys.stdout.write("\r" + blank + "\r" + resp)
        sys.stdout.flush()
        return len(resp)

    prompt = PROMPT_TEMPLATE.format(
        cls_token_id=tokenizer.cls_token,
        user_prompt=user_prompt,
        sep_token=tokenizer.sep_token,
    )
    print(prompt, end="")
    
    tok_iter = iter_mask_decode(model, tokenizer, prompt, answer_length, random_toks)
    
    masked = ["[MASK]"] * answer_length
    n_clear = 0
    for tok, pos in tok_iter:
        masked[pos] = tokenizer.decode(tok)
        # print(tokenizer.decode(tok))
        resp = "".join(masked)
        print(pos, resp)
        n_clear = _print_step(resp, n_clear)
        time.sleep(delay)

In [55]:
import os, random, itertools, math, torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

device =  (
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

model_id = "tommyp111/modernbert-flowlm"
# model_id = "tommyp111/modernbert-diffusion"
# model_id = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device).eval()

In [56]:
user_prompt = "What is the nature of life?"
iterative_print(model, tokenizer, user_prompt, answer_length=8, delay=0.1, random_toks=True)

[CLS] User: What is the nature of life? [SEP] Assistant:
-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]-6 [MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]
[MASK][MASK] of[MASK][MASK][MASK][MASK][MASK]

In [None]:
user_prompt = "Who is the greatest soccer player?"
for i in range(1, 16):
    print("*" * 10, i, "*" * 10)
    iterative_print(model, tokenizer, user_prompt, answer_length=i, delay=0.0)
    print()