In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import nltk, torch, tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import clear_output, display, HTML

import infer_biblically

nltk.download('punkt_tab', quiet=True)

True

In [3]:
# Load the Llama 3.3 70B Instruct tokenizer and weights; `tokenizer` and `model` are used by
# every cell below. bfloat16 uses half the memory of float32; device_map="auto" asks
# transformers to place the checkpoint shards across the available devices.
model_name = "meta-llama/Llama-3.3-70B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [4]:
# Build the lookup structures used by constrained generation and by render_html:
#   - token_trie: built from every verse's words via infer_biblically.TokenTrie.add_verse.
#   - word_mentions: lower-cased word (or lower-cased verse identifier) -> list of the verse
#     identifiers yielded by kjv_verses that contain it, in file order (hover tooltips).
#   - can_end_word: passed to infer_biblically.generate alongside token_trie.
word_mentions = {}
token_trie = infer_biblically.TokenTrie()
verses = list(infer_biblically.kjv_verses("kjv.txt"))
for verse, _, words in tqdm.tqdm(verses):
    token_trie.add_verse(verse, words, tokenizer)
    word_mentions.setdefault(verse.lower(), []).append(verse)
    # `words` may repeat a word, and different casings ("The"/"the") collapse to the same
    # lower-cased key; record each verse at most once per key so the tooltip's
    # "first five verses" are five distinct verses. dict.fromkeys keeps first-seen order.
    for key in dict.fromkeys(word.lower() for word in words):
        word_mentions.setdefault(key, []).append(verse)

can_end_word = infer_biblically.tokens_that_can_end_word(tokenizer)


100%|████████████████████████████████████████████████████████████| 31102/31102 [00:25<00:00, 1237.01it/s]

In [5]:
def render_html(messages):
    """Render a chat transcript as HTML, annotating assistant words with KJV provenance.

    Each message is a dict with "role" and "content" keys. User text is shown as-is.
    Assistant text is split on whitespace and each chunk into NLTK word tokens; a token
    found (case-insensitively) in the global ``word_mentions`` index gets a hover tooltip
    listing its first five verses (followed by "..." when there are more), while tokens
    not in the index are outlined in red.

    All message text is HTML-escaped, so characters such as ``<`` or ``&`` are displayed
    literally rather than interpreted as markup, and the original whitespace (including
    newlines) is kept and shown via ``white-space: pre-wrap``.

    Returns an ``IPython.display.HTML`` object.
    """
    import re
    from html import escape

    parts = []
    for m in messages:
        # pre-wrap makes the spaces/newlines re-emitted below actually visible.
        parts.append(f"<p style='white-space: pre-wrap'><strong>{escape(m['role'])}</strong>: ")
        if m["role"] == "user":
            parts.append(f"<span>{escape(m['content'])}</span>")
        else:
            # The capturing group keeps the whitespace runs in the split result, so the
            # spacing between chunks (word_tokenize would discard it) is written back out.
            for chunk in re.split(r"(\s+)", m["content"]):
                if not chunk or chunk.isspace():
                    parts.append(chunk)
                    continue
                for word in nltk.word_tokenize(chunk):
                    key = word.lower()
                    if key in word_mentions:
                        verses = word_mentions[key]
                        verses_str = ", ".join(verses[:5]) + ("..." if len(verses) > 5 else "")
                        parts.append(f'<span title="{escape(verses_str)}">{escape(word)}</span>')
                    else:
                        parts.append(f'<span style="border: 1px solid red">{escape(word)}</span>')
        parts.append("</p>")
    return HTML("".join(parts))

(*Hover over a word in the assistant's responses below to see the first five verses it appears in (followed by "..." when there are more). Words outlined in red were not found in the KJV word index.*)

In [154]:
# Interactive chat. Each turn reads a question, then streams a reply from
# infer_biblically.generate (constrained by token_trie / can_end_word), re-rendering the
# whole transcript after every yielded piece. Blank input re-prompts instead of sending an
# empty user turn. Interrupt the kernel (or send EOF) to stop; `messages` keeps the
# transcript, including any partial reply, for the cells below.
GENERATE_TOKEN_LIMIT = 128  # 4th positional arg of generate; presumably max new tokens — confirm
messages = []
try:
    while True:
        q = input("> ").strip()
        if not q:
            continue
        messages.append({"role": "user", "content": q})
        inp = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        messages.append({"role": "assistant", "content": ""})
        # Show the question immediately; the first generated piece can take a while.
        clear_output(wait=True)
        display(render_html(messages))
        for p in infer_biblically.generate(model, tokenizer, inp, GENERATE_TOKEN_LIMIT, token_trie, can_end_word):
            messages[-1]["content"] += p
            clear_output(wait=True)
            display(render_html(messages))
except (KeyboardInterrupt, EOFError):
    print("\nQuit.")


Quit.


In [134]:
def mk_table(tokens_to_ouch, ouch_preds, pos, n):
    """Render the model's ranked next-token candidates at ``n`` consecutive positions.

    Row ``i`` shows the token ``tokens_to_ouch[pos + i]`` and, after the arrow, the
    candidate ids in ``ouch_preds[0, i, :]`` in the order given (e.g. argsort by
    descending logit). The candidate equal to the token that actually follows,
    ``tokens_to_ouch[pos + i + 1]``, is boxed in green; when a row's token is the last one
    in the sequence there is no following token and nothing is boxed.

    Args:
        tokens_to_ouch: token ids of the whole sequence.
        ouch_preds: integer tensor indexed as [batch, position, rank]; only batch 0 is
            read, and position ``i`` must be the prediction made after
            ``tokens_to_ouch[pos + i]``.
        pos: index in ``tokens_to_ouch`` of the first row's token.
        n: number of rows.

    Tokens are decoded with the global ``tokenizer``. Their reprs are HTML-escaped so
    tokens containing ``<``/``>`` (e.g. special tokens) are shown instead of being
    swallowed as unknown tags.

    Returns an ``IPython.display.HTML`` table.
    """
    from html import escape

    def code_cell(token_id):
        # repr() makes whitespace/newline tokens visible; escape() keeps <, >, & literal.
        return f"<code style='white-space: nowrap'>{escape(repr(tokenizer.decode([token_id])))}</code>"

    n_options = ouch_preds.shape[2]
    html = (
        "<table style='display: block; font-size: 10px'><tr><th>pred</th><th>→</th>"
        f"<th colspan={n_options} style='text-align: left'>options</th></tr>"
    )
    for i in range(n):
        next_idx = pos + i + 1
        actual_next = tokens_to_ouch[next_idx] if next_idx < len(tokens_to_ouch) else None
        html += (
            "<tr style='text-align: left; border-top: 1px solid grey'>"
            f"<td>{code_cell(tokens_to_ouch[pos + i])}</td><td>→</td>"
        )
        for token_id in ouch_preds[0, i, :].tolist():
            cell = code_cell(token_id)
            if token_id == actual_next:
                cell = f"<span style='border: 2px solid green; padding: 2px'>{cell}</span>"
            html += f"<td style='display: inline-block'>{cell}</td>"
        html += "</tr>"
    return HTML(html + "</table>")

In [136]:
# Inspect the model's predictions inside the current transcript: for each of `n` positions
# starting at token index `pos`, take the 50 highest-scoring next tokens and tabulate them.
# pos/n are hand-picked for one particular transcript.
# NOTE(review): depends on `messages` as left behind by the chat cell.
tokens_to_ouch = tokenizer.apply_chat_template(messages)
pos, n = 396, 4
print("pred = ..." + tokenizer.decode(tokens_to_ouch[:pos+n])[-150:])
# no_grad: a plain inspection pass, so don't let autograd record the graph (and keep every
# layer's activations alive) for the 70B model. Only the n positions of interest go to CPU.
with torch.no_grad():
    logits = model(input_ids=torch.tensor([tokens_to_ouch]).to(model.device)).logits[:, pos:pos+n, :].cpu()
ouch_preds = torch.argsort(logits, descending=True, dim=2)[:, :, :50]
mk_table(tokens_to_ouch, ouch_preds, pos, n)

pred = ... and the language I used to answer your questions ( lots of simple sentences and somewhat archer or old fashioned language) I would take a wild ouches


pred,→,options,options.1,options.2,options.3,options.4,options.5,options.6,options.7,options.8,options.9,Unnamed: 12,Unnamed: 13,Unnamed: 14,Unnamed: 15,Unnamed: 16,Unnamed: 17,Unnamed: 18,Unnamed: 19,Unnamed: 20,Unnamed: 21,Unnamed: 22,Unnamed: 23,Unnamed: 24,Unnamed: 25,Unnamed: 26,Unnamed: 27,Unnamed: 28,Unnamed: 29,Unnamed: 30,Unnamed: 31,Unnamed: 32,Unnamed: 33,Unnamed: 34,Unnamed: 35,Unnamed: 36,Unnamed: 37,Unnamed: 38,Unnamed: 39,Unnamed: 40,Unnamed: 41,Unnamed: 42,Unnamed: 43,Unnamed: 44,Unnamed: 45,Unnamed: 46,Unnamed: 47,Unnamed: 48,Unnamed: 49,Unnamed: 50,Unnamed: 51
' wild',→,' guess',' stab',' ',' educated',' and',' speculation',' speculative',' chance',' conject',' but',' leap',' hypothesis',' qu',' estimate',' g',' guest',' shot',' inference',' bet','-g',"' ""'",' spec',"','",' swing',' theory',' a',' card',' guessing',' guesses',' wild',' Guess',' assumption',' prediction',' to',' h',' probability',' possibility',' idea',' hypothetical',' G',' estimation','...',' step',' **',' sup',' guessed',' informed',' speculate',' wager',' theoretical'
' ',→,' guess',' and',' Peter',' educated',' I',' to',' hypothesis',' estimate',' guest',' conject','1',' Guess',' a',""" '""",' guessing',' g',"' ""'",' that',' -',' but',' the',' possibility',' speculation',' or',' qu',' stab',' inference',' chance',' guesses',' bet',' an',' assumption',' **',' my',' is','...',' with',' attempt',' am',' (',' possible',' speculative',' estimation','guess',' best',' at',' wild','2','ouch','unction'
'ouch',→,' and',' guess',' that','...',' (',' ',"','",'ie','.\n\n',' \n\n','.',' or',' I','...\n\n',' to',' an',' a',' but','...','ance',' chance',' Guess','!','/g','ay','guess','er',' -',' here','s','a',':','t',' of',"','",' \n','!\n\n','es','est',' in','é','i','..',' educated',"' ""'",' the','ess',' at','ouch',')'
'es',→,' and',' that',' ',' (','...',' to',' \n\n',"','",' guess','.',' or','.\n\n',' I',' an',' at',' but',' a',' -','...\n\n','...',"','",':',' the',' here',"' ""'",' is','.',' my',')','/g',' of',' as',' it','(','.\n\n',' educated','\n\n',' its',"""'s""",' nd',' guest','....',' )','(g',"""'""",' on','d','..','-g',' \n\n'


In [153]:
# How the tokenizer splits " ouch" and " unction": decode each token id on its own.
[
    [tokenizer.decode(token_id) for token_id in tokenizer.encode(text, add_special_tokens=False)]
    for text in (" ouch", " unction")
]

[[' ', 'ouch'], [' ', 'unction']]

In [155]:
# FIXME(review): `long_convo` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel (Restart & Run All). Define or load it in an earlier cell.
render_html(long_convo)