In [7]:
import torch

from diyt.tokenizer import SimpleTokenizer
from diyt.model import Transformer, ModelConfig
from diyt.paths import ASSETS_DIR

# Restore the tokenizer and the trained Transformer from the "hp" asset folder.
tokenizer = SimpleTokenizer.load(ASSETS_DIR / "hp" / "tokenizer.json")

# map_location="cpu" keeps the notebook runnable on machines without the
# accelerator the checkpoint may have been saved from; the model below is
# constructed on the CPU, so the weights end up there either way.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    ASSETS_DIR / "hp" / "model_checkpoints" / "checkpoint_200.pth",
    map_location="cpu",
)

# The checkpoint carries its own hyper-parameters under "model_config"; the
# vocabulary size is taken from the tokenizer instead.
model_config = ModelConfig.model_validate(checkpoint["model_config"])
model = Transformer(vocab_size=tokenizer.vocab_size, **model_config.model_dump())
model.load_state_dict(checkpoint["model_state_dict"])

# Switch to evaluation mode (turns off the Dropout layers); as the last
# expression of the cell this also displays the module tree.
model.eval()

Transformer(
  (semantic_vocab_embeddings): Embedding(4115, 512)
  (positional_embeddings): Embedding(256, 512)
  (decoder_blocks): ModuleList(
    (0-7): 8 x DecoderBlock(
      (self_attention): SelfAttention(
        (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (W_Q): Linear(in_features=512, out_features=2048, bias=True)
        (W_K): Linear(in_features=512, out_features=2048, bias=True)
        (W_V): Linear(in_features=512, out_features=2048, bias=True)
        (W): Linear(in_features=2048, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feed_forward): Sequential(
        (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=512, out_features=512, bias=True)
        (2): GELU(approximate='none')
        (3): Linear(in_features=512, out_features=512, bias=True)
        (4): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (layer_norm): LayerNorm((512,), eps=1e-05, elem

In [5]:
import json

from diyt.paths import DATA_DIR

# Read the training split (a JSON file under DATA_DIR/harry_potter) and
# preview its first entry as the cell output.
train_path = DATA_DIR / "harry_potter" / "train.json"
with open(train_path, "r", encoding="utf-8") as train_file:
    texts = json.load(train_file)
texts[0]

"Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much. They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense."

In [6]:
# Greedy continuation of a prompt: at every step the whole running text is
# re-encoded, the model is run once, and the highest-scoring token at the last
# attended position is appended to the text.
text = "Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly"
num_gen = 100  # number of tokens to append to the prompt

# Inference only: disabling gradient tracking avoids building an autograd graph
# on each of the num_gen forward passes. The argmax results are unaffected.
with torch.no_grad():
    for _ in range(num_gen):
        encoded = tokenizer([text], max_seq_length=model_config.context_length)
        # Summing the attention mask of the single sequence gives the index
        # just past its last attended position.
        # NOTE(review): presumably the mask is 1 for real tokens and 0 for
        # padding -- confirm against SimpleTokenizer.
        next_idx = sum(encoded.attention_mask[0])
        logits = model(**encoded.model_dump())
        # The logits at the last attended position score the next token.
        next_token_id = logits[0, next_idx - 1].argmax(dim=-1)
        next_token = tokenizer.decode([next_token_id])
        # Tokens are joined with a single space, which is why punctuation in
        # the generated text appears detached from the preceding word.
        text += " " + next_token
text

"Mr. and Mrs. Dursley, of number four, Privet Drive, were proud to say that they were perfectly normal, thank you very much . They were the last people you'd expect to be involved in anything strange or mysterious, because they just didn't hold with such nonsense . <eos> Harry looked up at the giant . <eos> The giant chuckled darkly . <eos> The Dursleys had everything they wanted, but they also had a secret, and their greatest fear was that somebody would discover it . They didn't think they could bear it if anyone found out about the Potters . Mrs . Potter was Mrs . Dursley's sister, but they hadn't met for several years; in fact,"