In [1]:
import torch
from miniTransformer import Transformer
from miniGPTDataset import SimpleLetterTokenizer

batch_size = 128
max_seq = 4
d_model = 4
d_ff = 4
n_blocks = 1
n_heads = 1
drop_out_rate = 0.1
learning_rate = 1e-3
epochs = 10
v_size = SimpleLetterTokenizer().n_vocab
print(f"Vocabulary size: {v_size}")
start_token_id = v_size + 1
end_token_id = v_size + 2
v_size = v_size + 2
print(f"Start token ID: {start_token_id}, End token ID: {end_token_id}")
print(f"Using vocabulary size: {v_size} (including start and end tokens)")

checkpoint = torch.load('checkpoints/best_model.pth')  # Load the model checkpoint
model = Transformer(v_size=v_size, max_seq=max_seq, d_model=d_model, drop_out_rate=drop_out_rate, d_ff=d_ff, n_blocks=n_blocks, n_heads=n_heads, pad_idx=0)
model.load_state_dict(checkpoint['model_state_dict'])  # Load the model state dict
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Using device: {device}")
model = model.to(device)

# test an example
input_seq = "ABC"
input_ids = SimpleLetterTokenizer().encode(input_seq)
input_ids = input_ids + [0]
print(f"Start token ID: {start_token_id}, End token ID: {end_token_id}")
# convert to tensor and add batch dimension
input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
print(f"Input sequence: {input_seq}, Input IDs: {input_ids}")
output_token_ids = model.generate(
    input_ids=input_ids,
    max_length=max_seq,
    start_token_id=start_token_id,
    end_token_id=end_token_id
)

output_token_ids = output_token_ids[0].tolist()
# remove start token
output_token_ids = output_token_ids[1:]  # remove the start token
# remove end token if it exists
if end_token_id in output_token_ids:
    output_token_ids = output_token_ids[:output_token_ids.index(end_token_id)]

output_seq = SimpleLetterTokenizer().decode(output_token_ids)
print(f"Input sequence: {input_seq}, Generated sequence: {output_seq}")


Vocabulary size: 27
Start token ID: 28, End token ID: 29
Using vocabulary size: 29 (including start and end tokens)
Using device: mps
Start token ID: 28, End token ID: 29
Input sequence: ABC, Input IDs: tensor([[1, 2, 3, 0]], device='mps:0')
output_ids: tensor([[28]], device='mps:0')
output_ids: tensor([[28,  9]], device='mps:0')
output_ids: tensor([[28,  9,  2]], device='mps:0')
output_ids: tensor([[28,  9,  2,  1]], device='mps:0')
Input sequence: ABC, Generated sequence: IBAJ


In [2]:
src_mask = model._create_padding_mask(input_ids)
encoder_outputs = model.encoder(input_ids, src_mask=src_mask)
print(f"Encoder outputs shape: {encoder_outputs.shape}")  # b, max_seq, d_model
print(f"Encoder outputs: {encoder_outputs}")

Encoder outputs shape: torch.Size([1, 4, 4])
Encoder outputs: tensor([[[ 1.1810,  1.4783, -2.4246, -0.0530],
         [ 2.4247, -2.0914,  0.1005, -0.1978],
         [-0.7326, -1.8711,  2.1021,  0.2646],
         [-0.9289,  2.6003, -1.2676, -0.3874]]], device='mps:0',
       grad_fn=<NativeLayerNormBackward0>)


  nonzero_finite_vals = torch.masked_select(


In [12]:
# initialize output_ids with start_token_id
output_ids = torch.full((input_ids.size(0), 1), fill_value=start_token_id, dtype=torch.long, device=input_ids.device)
print(f"Shape of output IDs: {output_ids.shape}")  # should be (batch_size, 1)
print(f"Initial output IDs: {output_ids}")
output_ids = torch.cat([output_ids, torch.tensor([[3]], dtype=torch.long, device=input_ids.device)], dim=1)
print(f"Shape of output IDs: {output_ids.shape}")  # should be (batch_size, 2)
print(f"Updated output IDs: {output_ids}")

Shape of output IDs: torch.Size([1, 1])
Initial output IDs: tensor([[28]], device='mps:0')
Shape of output IDs: torch.Size([1, 2])
Updated output IDs: tensor([[28,  3]], device='mps:0')


In [13]:
tgt_mask = model._create_causal_mask(output_ids.size(1)) # b, 1, s+1, s+1
tgt_mask = tgt_mask.to(output_ids.device)
print(f"Target mask shape: {tgt_mask.shape}")  # should be (batch_size, 1, seq_len, seq_len)
print(f"Target mask: {tgt_mask}")

Target mask shape: torch.Size([1, 1, 2, 2])
Target mask: tensor([[[[ True, False],
          [ True,  True]]]], device='mps:0')


In [14]:
decoder_output, decoder_self_attentions, decoder_cross_attentions = model.decoder(output_ids, encoder_outputs, src_mask, tgt_mask, return_attention=True)
print(f"Decoder output shape: {decoder_output.shape}")  # should be (batch_size, seq_len, d_model)
print(f"Decoder output: {decoder_output}")
# print last layer of attentions
print(f"last layer decoder self attention shape: {decoder_self_attentions[-1].shape}")  # should be (batch_size, n_heads, seq_len, seq_len)
print(f"last layer decoder self attention: {decoder_self_attentions[-1]}")
print(f"last layer decoder cross attention shape: {decoder_cross_attentions[-1].shape}")  # should be (batch_size, n_heads, seq_len, seq_len)
print(f"last layer decoder cross attention: {decoder_cross_attentions[-1]}")


Decoder output shape: torch.Size([1, 2, 29])
Decoder output: tensor([[[-3.5205,  3.8691,  2.2394, -3.8453, -1.0840,  6.2557,  1.2145,
           6.2666,  1.6552,  6.5308,  3.8158,  4.5190, -9.8917, -9.0440,
          -3.9611, -7.1893,  3.5974, -4.9453, -3.3011, -1.4426,  2.1608,
           4.1235, -5.0371, -6.7411,  5.1051,  6.3331, -4.2807, -3.5562,
          -3.2285],
         [-6.7916,  5.2714,  6.2297,  3.3506, -6.1825, -0.8195,  5.9423,
           2.3901, -5.6700,  0.4123,  5.4526,  4.3191, -5.3846, -4.6253,
          -9.6467, -8.9116, -4.0615, -7.8830,  3.0376,  4.3956,  6.1333,
           4.7061,  2.3845, -1.2127, -2.5649,  2.5135, -0.6261, -6.2159,
          -5.9123]]], device='mps:0', grad_fn=<LinearBackward0>)
last layer decoder self attention shape: torch.Size([1, 1, 2, 2])
last layer decoder self attention: tensor([[[[1.0000, 0.0000],
          [0.4825, 0.5175]]]], device='mps:0', grad_fn=<SoftmaxBackward0>)
last layer decoder cross attention shape: torch.Size([1, 1, 2, 4])