In [1]:
# Re-import edited modules automatically before each cell runs, so changes to
# local helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import nltk, torch, tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import clear_output, display, HTML

import infer_biblically

# Fetch the 'punkt_tab' NLTK resource without progress output.
# NOTE(review): presumably this is the tokenizer data nltk.word_tokenize needs
# in render_html — confirm against the installed NLTK version.
nltk.download('punkt_tab', quiet=True)

True

In [5]:
# Load the tokenizer and the causal LM from the same checkpoint; the weights
# are loaded in bfloat16 and then moved onto the first CUDA device.
model_name = "meta-llama/Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda:0")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Index the corpus: a token trie (passed to generation later) plus a lowercase
# lookup table that render_html consults for its hover tooltips.
token_trie = infer_biblically.TokenTrie()
word_mentions = {}
verses = list(infer_biblically.kjv_verses("kjv.txt"))  # materialized so tqdm can show a total
for verse, _, words in tqdm.tqdm(verses):
    token_trie.add_verse(verse, words, tokenizer)
    # File the verse under its own lowercased text and under each of its words.
    for key in (verse, *words):
        word_mentions.setdefault(key.lower(), []).append(verse)

# NOTE(review): presumably the tokens that are allowed to terminate a word —
# confirm against infer_biblically.tokens_that_can_end_word.
can_end_word = infer_biblically.tokens_that_can_end_word(tokenizer)

100%|██████████| 31102/31102 [00:54<00:00, 569.20it/s]


In [8]:
def render_html(messages):
    """Render a chat transcript as an HTML object for notebook display.

    Each message becomes a ``<p>`` that starts with the role in bold. A
    message whose role is "user" is shown as a single ``<span>``. Any other
    message is split on single spaces and each piece is tokenized with
    ``nltk.word_tokenize``; a token whose lowercase form is a key of
    ``word_mentions`` gets a hover tooltip listing the entries recorded for
    it (the first five followed by "..." when there are more), and any other
    token is outlined in red.

    All text taken from the messages or from ``word_mentions`` is
    HTML-escaped before insertion, so characters such as ``<``, ``&`` or
    ``"`` in the typed input, the generated text or a tooltip are displayed
    literally instead of being interpreted as markup.

    Args:
        messages: Iterable of dicts with string values under "role" and
            "content".

    Returns:
        An ``HTML`` display object wrapping the concatenated markup.
    """
    # Imported here so the function carries its own dependency; the name
    # `html` is deliberately not used for a local variable any more.
    from html import escape

    parts = []
    for m in messages:
        role = m["role"]
        parts.append(f"""<p><strong>{escape(role)}</strong>: """)
        if role == "user":
            parts.append(f"<span>{escape(m['content'])}</span>")
        else:
            # Splitting on " " first preserves the original spacing between
            # pieces; tokens produced from one piece are emitted back to back.
            for i, spaced_part in enumerate(m["content"].split(" ")):
                if i > 0:
                    parts.append(" ")
                for word in nltk.word_tokenize(spaced_part):
                    key = word.lower()
                    if key in word_mentions:
                        mentions = word_mentions[key]
                        if len(mentions) > 5:
                            tooltip = ", ".join(mentions[:5]) + "..."
                        else:
                            tooltip = ", ".join(mentions)
                        # escape() also converts quotes, which keeps the
                        # tooltip from terminating the title attribute early.
                        parts.append(f"""<span title="{escape(tooltip)}">{escape(word)}</span>""")
                    else:
                        parts.append(f"""<span style="border: 1px solid red">{escape(word)}</span>""")
        parts.append("</p>")
    return HTML("".join(parts))

(*Hover over any word in the assistant's responses below to see the first five verses in which that word appears. A word outlined in red was not found in the verse index.*)

In [15]:
# Interactive chat: read a prompt, stream the constrained reply, and redraw
# the whole transcript after every streamed piece. Ctrl-C or end-of-input
# leaves the loop.
messages = []
try:
    while True:
        question = input("> ").strip()
        messages.append({"role": "user", "content": question})
        # Build the model input before the empty assistant turn is appended,
        # so the template only sees the conversation up to the new question.
        prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        reply = {"role": "assistant", "content": ""}
        messages.append(reply)
        # NOTE(review): 128 is presumably a cap on generated tokens — confirm
        # against the signature of infer_biblically.generate.
        for piece in infer_biblically.generate(model, tokenizer, prompt, 128, token_trie, can_end_word):
            reply["content"] += piece
            clear_output(wait=True)
            display(render_html(messages))
except (KeyboardInterrupt, EOFError):
    print("\nQuit.")


Quit.
