In [None]:
import torch
from miniTransformer import Transformer
from miniGPTDataset import SimpleLetterTokenizer

# --- Hyperparameters -------------------------------------------------------
# The architecture values (max_seq, d_model, d_ff, n_blocks, n_heads) must
# match the ones the checkpoint was trained with, otherwise load_state_dict
# below will fail on shape mismatches.
batch_size = 128
max_seq = 4
d_model = 4
d_ff = 4
n_blocks = 1
n_heads = 1
drop_out_rate = 0.1
learning_rate = 1e-3
epochs = 10

# --- Vocabulary and special tokens ----------------------------------------
v_size = SimpleLetterTokenizer().n_vocab
print(f"Vocabulary size: {v_size}")
# NOTE(review): these IDs look off by one. After v_size is grown by 2 below,
# valid indices are 0..v_size-1, so end_token_id == v_size is outside the
# range passed to Transformer(v_size=...), and the end-token check during
# decoding may never fire. Left unchanged here because the checkpoint was
# presumably trained with the same values - confirm against the training
# script before changing.
start_token_id = v_size + 1
end_token_id = v_size + 2
v_size = v_size + 2
print(f"Start token ID: {start_token_id}, End token ID: {end_token_id}")
print(f"Using vocabulary size: {v_size} (including start and end tokens)")

# --- Model -----------------------------------------------------------------
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Load onto CPU first so a checkpoint saved on another accelerator (MPS/CUDA)
# can still be restored on a machine that lacks it; the model is moved to the
# target device afterwards.
checkpoint = torch.load('checkpoints/best_model.pth', map_location="cpu")
model = Transformer(
    v_size=v_size,
    max_seq=max_seq,
    d_model=d_model,
    drop_out_rate=drop_out_rate,
    d_ff=d_ff,
    n_blocks=n_blocks,
    n_heads=n_heads,
    pad_idx=0,
)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Using device: {device}")
model = model.to(device)
# Switch to evaluation mode so dropout is disabled during generation.
model.eval()

# Run a single example through the model: encode -> generate -> decode.
tokenizer = SimpleLetterTokenizer()  # one instance for both encode and decode
pad_id = 0  # must match the pad_idx the Transformer was constructed with

input_seq = "WIL"
input_ids = tokenizer.encode(input_seq)
# Right-pad up to max_seq (for a 3-letter input with max_seq == 4 this is the
# single trailing pad the script always appended). Inputs already at or above
# max_seq are passed through without padding.
# NOTE(review): assumes the trailing 0 is padding to max_seq - confirm.
input_ids = input_ids + [pad_id] * (max_seq - len(input_ids))
print(f"Start token ID: {start_token_id}, End token ID: {end_token_id}")
# Convert to a tensor and add the batch dimension.
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
print(f"Input sequence: {input_seq}, Input IDs: {input_ids}")

# No gradients are needed for inference; skip building the autograd graph.
with torch.no_grad():
    output_token_ids = model.generate(
        input_ids=input_ids,
        max_length=max_seq,
        start_token_id=start_token_id,
        end_token_id=end_token_id,
    )

# Take the first (only) sequence of the batch and drop the leading start token.
output_token_ids = output_token_ids[0].tolist()[1:]
# Cut the sequence at the end token, if one was generated.
if end_token_id in output_token_ids:
    output_token_ids = output_token_ids[:output_token_ids.index(end_token_id)]

output_seq = tokenizer.decode(output_token_ids)
print(f"Input sequence: {input_seq}, Generated sequence: {output_seq}")


Vocabulary size: 27
Start token ID: 28, End token ID: 29
Using vocabulary size: 29 (including start and end tokens)
Using device: mps
Start token ID: 28, End token ID: 29
Input sequence: YAN, Input IDs: tensor([[25,  1, 14,  0]], device='mps:0')
output_ids: tensor([[28]], device='mps:0')
output_ids: tensor([[28,  8]], device='mps:0')
output_ids: tensor([[28,  8, 11]], device='mps:0')
output_ids: tensor([[28,  8, 11, 25]], device='mps:0')
Input sequence: YAN, Generated sequence: HKYK
