In [None]:
import os, random, itertools, math, torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM,
    get_cosine_schedule_with_warmup
)
from torch.optim import AdamW
from datasets import load_dataset
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
# Checkpoint name; the tokenizer and the masked-LM weights are both loaded from it.
model_id = "johnowhitaker/modernbert-diffusion"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Special-token ids, used by the cells below to assemble input sequences by hand.
SEP_ID = tokenizer.sep_token_id
CLS_ID = tokenizer.cls_token_id
MASK_ID = tokenizer.mask_token_id

# Load the weights onto `device` and switch the module to evaluation mode.
# The trailing semicolon keeps the notebook from echoing the model repr.
model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device)
model.eval();

In [None]:
# One-shot decoding: predict every masked answer slot from a single forward pass.
ans_len = 12
prompt = "User: Which is the best programming language? " + tokenizer.sep_token + " Assistant:"
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
# Sequence layout: [CLS] prompt [SEP] <ans_len x MASK> [SEP]
ids = [CLS_ID, *prompt_ids, SEP_ID, *([MASK_ID] * ans_len), SEP_ID]
with torch.no_grad():
  outs = model(input_ids=torch.tensor([ids], device=device)).logits
print(outs.shape)
# Take the argmax token at every position (prompt positions included) and decode the lot.
out_ids = torch.argmax(outs[0], dim=-1).tolist()
print(tokenizer.decode(out_ids))

torch.Size([1, 28, 50368])
[CLS]User: Which is the best programming language? 
 Assistant: Python, Python,,,,,, is Python..[SEP]


In [None]:
# In a loop, keeping the most confident prediction at each step.
prompt = "User: Which is the best programming language? " + tokenizer.sep_token + " Assistant:"
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
ans_len = 32
# Number of tokens committed per forward pass (i.e. keep the top `refresh_every`
# predictions from each pass). Set to 1 to re-run the model for every token.
refresh_every = 4
outs = None
out_probs = None
ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]
for i in range(ans_len):
  if i % refresh_every == 0:
    with torch.no_grad():
      outs = model(input_ids=torch.tensor([ids]).to(device)).logits
    # The logits only change when the model is re-run, so the softmax is
    # computed here rather than on every iteration.
    out_probs = torch.softmax(outs[0], dim=-1)
  # Positions that are still masked (built on the same device as the
  # probabilities so `max_loc` prints with its device, as before).
  mask_locs = (torch.tensor(ids, device=out_probs.device) == MASK_ID).nonzero(as_tuple=True)[0]
  # Best token and its probability for each masked position only; this avoids
  # allocating a zeroed (seq_len, vocab) tensor on every step.
  mask_conf, mask_toks = out_probs[mask_locs].max(dim=-1)
  best = mask_conf.argmax(dim=-1)
  max_loc = mask_locs[best]
  new_id = mask_toks[best].item()
  print(max_loc, tokenizer.decode(new_id))
  # Commit the most confident prediction; the position is no longer masked next step.
  ids[max_loc] = new_id
print(tokenizer.decode(ids))

tensor(46, device='cuda:0') .
tensor(45, device='cuda:0') .
tensor(15, device='cuda:0')  is
tensor(23, device='cuda:0') .
tensor(21, device='cuda:0')  programming
tensor(22, device='cuda:0')  languages
tensor(16, device='cuda:0')  a
tensor(17, device='cuda:0')  best
tensor(44, device='cuda:0')  etc
tensor(19, device='cuda:0')  for
tensor(20, device='cuda:0')  all
tensor(43, device='cuda:0') ,
tensor(41, device='cuda:0') ,
tensor(39, device='cuda:0') ,
tensor(18, device='cuda:0')  language
tensor(42, device='cuda:0')  Python
tensor(38, device='cuda:0')  Java
tensor(37, device='cuda:0') ,
tensor(40, device='cuda:0')  Java
tensor(24, device='cuda:0')  There
tensor(25, device='cuda:0')  are
tensor(26, device='cuda:0')  many
tensor(29, device='cuda:0')  languages
tensor(30, device='cuda:0')  languages
tensor(28, device='cuda:0')  programming
tensor(35, device='cuda:0') ,
tensor(32, device='cuda:0')  as
tensor(27, device='cuda:0')  popular
tensor(31, device='cuda:0')  such
tensor(33, device=

In [None]:
# Wrapping that in a function
def sample(q, ans_len=32, tokens_per_step=1):
  """Answer `q` by iteratively unmasking a fixed-length answer span.

  The input is laid out as ``[CLS] User: {q} [SEP] Assistant: [SEP] <ans_len x MASK> [SEP]``.
  On each step the masked position whose top prediction has the highest
  probability is replaced by that token, until `ans_len` tokens have been
  committed.

  Args:
    q: The user question, inserted verbatim into the prompt template.
    ans_len: Number of mask tokens to place (and fill) for the answer.
    tokens_per_step: How many tokens to commit per forward pass. The default
      of 1 re-runs the model after every committed token; larger values reuse
      the same predictions for several picks, trading quality for speed.

  Returns:
    The decoded full sequence as a string, including the prompt and the
    special tokens.

  Raises:
    ValueError: If `tokens_per_step` is smaller than 1.
  """
  if tokens_per_step < 1:
    raise ValueError("tokens_per_step must be >= 1")
  prompt = f"User: {q} " + tokenizer.sep_token + " Assistant:"
  prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
  ids = [CLS_ID] + prompt_ids + [SEP_ID] + [MASK_ID]*ans_len + [SEP_ID]
  probs = None
  for step in range(ans_len):
    if step % tokens_per_step == 0:
      with torch.no_grad():
        logits = model(input_ids=torch.tensor([ids]).to(device)).logits
      probs = torch.softmax(logits[0], dim=-1)
    # Positions that are still masked. There is always at least one here: the
    # sequence starts with `ans_len` masks and each step fills at most one.
    mask_locs = (torch.tensor(ids, device=probs.device) == MASK_ID).nonzero(as_tuple=True)[0]
    # Top token and its probability for the masked rows only, instead of
    # zero-filling a full (seq_len, vocab) tensor on every step.
    mask_conf, mask_toks = probs[mask_locs].max(dim=-1)
    best = mask_conf.argmax(dim=-1)
    ids[mask_locs[best].item()] = mask_toks[best].item()
  return tokenizer.decode(ids)

In [None]:
# Try the sampler on a fun-fact prompt, using the default answer length.
sample("Tell me a fun fact about cows")

"[CLS]User: Tell me a fun fact about cows [SEP] Assistant:[SEP], here's a fun fact about cows:\n\nThe fact is that cows are the most intelligent animals in the world. They can think and make decisions.[SEP]"

In [None]:
# Try the sampler on a joke prompt, using the default answer length.
sample("Tell me a funny joke about lemons")

'[CLS]User: Tell me a funny joke about lemons [SEP] Assistant:[SEP]\'s a funny joke about lemons: "I have a lemonade stand, and I\'m going to sell lemons."\n Assistant: That\'s funny.[SEP]'

In [None]:
# Try the sampler on a short opinion question, using the default answer length.
sample("Which OS is best?")

"[CLS]User: Which OS is best? [SEP] Assistant:[SEP], I don't know. I haven't used them personally. I'm sure there are some that are better than others, but I can't tell you.[SEP]"

In [None]:
# Variation on the earlier cow prompt, with an extra qualifier appended to the question.
sample("Tell me a fun fact about cows - a good one")

"[CLS]User: Tell me a fun fact about cows - a good one [SEP] Assistant:[SEP]'s a fun fact about cows: they can't read.\n\nComment: I'm sorry, but I can't help you with that one, either.[SEP]"