In [16]:
import sys
import os
import torch

# Make the project root (the parent of the notebook's working directory)
# importable for the project-local imports that follow.
#
# A notebook kernel does not define `__file__`. The previous code resolved the
# *string* '__file__' against the working directory and took its dirname, which
# is simply the current working directory -- so state that directly.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, '..'))

# Guard the append so re-running this cell does not stack duplicate entries.
# NOTE(review): the entry is appended (lowest priority), so an installed package
# named `transformers` or `config` would win over the project-local one --
# switch to sys.path.insert(0, project_root) if the imports resolve wrongly.
if project_root not in sys.path:
    sys.path.append(project_root)

from transformers.TokenTransformer import DecoderTransformer, SimpleTokenizer
from config import *

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")



In [17]:
# Rebuild the tokenizer from the same text that was used for training, so the
# vocabulary it produces matches the one the saved model was built with.
#
# Look for the corpus under <project_root>/data first so the notebook runs on
# any checkout (assumes the notebook lives one level below the project root --
# TODO confirm). Fall back to the original author's absolute location so
# existing setups keep working unchanged.
corpus_path = os.path.join(project_root, 'data', 'book_of_mormon.txt')
if not os.path.exists(corpus_path):
    corpus_path = '/Users/Tom/Documents/dev/deep-learning-edu/1828-embedding-model/transformer-1830/data/book_of_mormon.txt'

with open(corpus_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Two further corpora were disabled in the original notebook:
# 'data/cleaned_dictionary.txt' and 'data/lady_of_perth.txt'.
# NOTE(review): if the checkpoint was trained with them, append their contents
# to `text` before fitting -- verify against the training script.

tokenizer = SimpleTokenizer()
tokenizer.fit(text)



In [18]:
def generate_text(prompt, max_new_tokens=100, temperature=temperature, top_k=top_k):
    """Generate text from the global ``model``, conditioned on ``prompt``.

    Relies on the module-level ``tokenizer``, ``model`` and ``device``; they are
    looked up when the function is called, so the cell that builds ``model``
    must have run before the first call.

    Args:
        prompt: Text to condition on; encoded with ``tokenizer.encode``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``. The default is whatever
            the module-level ``temperature`` held when this cell was executed.
        top_k: Forwarded to ``model.generate``. The default is whatever the
            module-level ``top_k`` held when this cell was executed.

    Returns:
        The string obtained by decoding the first row of the tensor returned by
        ``model.generate`` (presumably the prompt followed by the new tokens --
        confirm against ``DecoderTransformer.generate``).
    """
    # Encode the prompt as a batch containing a single sequence.
    encoded = tokenizer.encode(prompt)
    context = torch.tensor([encoded], dtype=torch.long, device=device)

    # Sample in eval mode (dropout disabled) and without autograd bookkeeping,
    # then restore whatever mode the model was in before the call.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = model.generate(
                context,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k
            )
    finally:
        model.train(was_training)

    # Decode and return the result
    return tokenizer.decode(generated[0].tolist())

In [19]:
# Build the network and load the trained weights.
# vocab_size comes from the fitted tokenizer; the remaining hyperparameters are
# module-level names (presumably provided by `from config import *`) and must
# match the values the checkpoint was trained with for the weights to load.
model = DecoderTransformer(
    vocab_size=tokenizer.vocab_size,
    n_embd=n_embd,
    n_head=n_head,
    n_layer=n_layer,
    block_size=block_size,
    dropout=dropout
).to(device)

model_path = os.path.join('..', 'models', 'final_model_256_8head.pth')
# map_location remaps the stored tensors onto the device selected for this
# session, so a checkpoint saved on a different backend still loads here.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (recent PyTorch versions offer weights_only=True for plain state dicts).
model.load_state_dict(torch.load(model_path, map_location=device))
# This notebook only runs inference: switch to eval mode so dropout is disabled.
model.eval()
print(f"Model loaded from {model_path}")



In [20]:
# Smoke-test generation with a short prompt.
# temperature and top_k are passed explicitly here to override the defaults
# that generate_text would otherwise use.
prompt = "The Lord said"
generated_text = generate_text(prompt, max_new_tokens=50, temperature=0.9, top_k=10)
print(generated_text)

