In [2]:
# Decoder-Only Transformer with Causal Masking, <sos>/<eos>, 
# Attention Mask, Larger Dataset, and Robust Inference
import torch
import math
import torch.nn.functional as F
from torch.nn import Parameter

In [3]:
# Step 1: Prepare a slightly larger dataset with <sos> for all sequences
# Each sentence is listed without its boundary markers; the markers are added
# uniformly below so every sentence starts with <sos> and ends with <eos>.
sentences = (
    ('the', 'sun', 'rises', 'in', 'the', 'east'),
    ('the', 'moon', 'shines', 'at', 'night'),
    ('stars', 'twinkle', 'in', 'the', 'sky'),
)
# Flatten into a single token stream: <sos> w1 ... wn <eos> <sos> ...
paragraph = [tok for words in sentences for tok in ('<sos>', *words, '<eos>')]

In [4]:
# Vocabulary: the distinct tokens of the paragraph in sorted order, plus
# lookup tables in both directions (index -> word and word -> index).
vocab = sorted({*paragraph})
vocab_size = len(vocab)
idx2word = dict(enumerate(vocab))
word2idx = {word: idx for idx, word in idx2word.items()}

In [5]:
# Display the token -> index lookup table.
word2idx

{'<eos>': 0,
 '<sos>': 1,
 'at': 2,
 'east': 3,
 'in': 4,
 'moon': 5,
 'night': 6,
 'rises': 7,
 'shines': 8,
 'sky': 9,
 'stars': 10,
 'sun': 11,
 'the': 12,
 'twinkle': 13}

In [6]:
# Display the index -> token lookup table (inverse of word2idx).
idx2word

{0: '<eos>',
 1: '<sos>',
 2: 'at',
 3: 'east',
 4: 'in',
 5: 'moon',
 6: 'night',
 7: 'rises',
 8: 'shines',
 9: 'sky',
 10: 'stars',
 11: 'sun',
 12: 'the',
 13: 'twinkle'}

In [7]:
# Display the number of distinct tokens in the vocabulary.
vocab_size

14

In [9]:
# Preview the start offsets of every 4-token window that still has a
# following token available in the paragraph.
n_window_starts = len(paragraph) - 4
for start in range(n_window_starts):
    print(start)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17


In [10]:
# Create input-output pairs for next-word prediction
seq_len = 4

# Every window of seq_len consecutive tokens, paired with the token that follows it.
windows = [
    (paragraph[start:start + seq_len], paragraph[start + seq_len])
    for start in range(len(paragraph) - seq_len)
]
# A window containing <eos> would straddle two sentences, so it is dropped.
kept = [(context, nxt) for context, nxt in windows if '<eos>' not in context]

inputs = [[word2idx[tok] for tok in context] for context, _ in kept]
targets = [word2idx[nxt] for _, nxt in kept]

In [11]:
# Display the encoded context windows (one row of seq_len token ids per example).
inputs

[[1, 12, 11, 7],
 [12, 11, 7, 4],
 [11, 7, 4, 12],
 [7, 4, 12, 3],
 [1, 12, 5, 8],
 [12, 5, 8, 2],
 [5, 8, 2, 6],
 [1, 10, 13, 4],
 [10, 13, 4, 12],
 [13, 4, 12, 9]]

In [12]:
# Display the encoded next-token target for each context window.
targets

[4, 12, 3, 0, 2, 6, 0, 12, 9, 0]

In [13]:
# Pack the training pairs into tensors: X holds the context windows
# (one row per example), Y the id of the token that follows each window.
X, Y = map(torch.tensor, (inputs, targets))

In [14]:
# The whole dataset is processed as one batch, so the batch size is the
# number of rows (examples) in X.
batch_size = X.size(0)

In [15]:
# Display the number of training examples in the batch.
batch_size

10

In [16]:
# Step 2: Model configuration
embed_dim, num_heads = 16, 2        # embedding width and number of attention heads
head_dim = embed_dim // num_heads   # slice of the embedding handled by each head
epochs = 100                        # number of optimisation steps over the full batch

In [17]:
# Learnable parameters
# A dedicated, seeded generator makes the random initialisation (and therefore
# training and every downstream prediction) reproducible across kernel restarts
# without altering PyTorch's global RNG state.
init_rng = torch.Generator().manual_seed(0)


def _randn_param(*shape):
    """Return a trainable tensor of the given shape drawn from N(0, 1) using init_rng."""
    return Parameter(torch.randn(*shape, generator=init_rng))


# Token and position embeddings; pos_embedding has one row per window position.
embedding_matrix = _randn_param(vocab_size, embed_dim)
pos_embedding = _randn_param(seq_len, embed_dim)

# Attention projections (all heads packed into a single embed_dim-wide matrix).
W_q = _randn_param(embed_dim, embed_dim)
W_k = _randn_param(embed_dim, embed_dim)
W_v = _randn_param(embed_dim, embed_dim)

# Two-layer feed-forward network; biases start at zero.
W1 = _randn_param(embed_dim, embed_dim)
b1 = Parameter(torch.zeros(embed_dim))
W2 = _randn_param(embed_dim, embed_dim)
b2 = Parameter(torch.zeros(embed_dim))

# Output projection from the embedding space to vocabulary logits.
W_out = _randn_param(embed_dim, vocab_size)
b_out = Parameter(torch.zeros(vocab_size))

optimizer = torch.optim.Adam([
    embedding_matrix, pos_embedding, W_q, W_k, W_v,
    W1, b1, W2, b2, W_out, b_out
], lr=0.01)

In [18]:
# Step 3: Training loop
def reshape(x):
    """Split the embedding axis into heads.

    (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, head_dim)
    """
    return x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)


# Lower-triangular causal mask with leading singleton axes so it broadcasts over
# batch and heads. It depends only on seq_len, so it is built once here rather
# than on every epoch.
attn_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)

losses = []
for epoch in range(epochs):
    optimizer.zero_grad()

    # Token embeddings plus positional embeddings (broadcast over the batch).
    embedded = embedding_matrix[X] + pos_embedding
    Q = embedded @ W_q
    K = embedded @ W_k
    V = embedded @ W_v

    Qh, Kh, Vh = map(reshape, (Q, K, V))

    # Scaled dot-product attention; positions after the query are set to -inf
    # so they receive zero weight after the softmax.
    scores = (Qh @ Kh.transpose(-2, -1)) / math.sqrt(head_dim)
    scores = scores.masked_fill(attn_mask == 0, float('-inf'))
    attn_weights = F.softmax(scores, dim=-1)

    # Merge the heads back into a single embed_dim-wide axis.
    attn_output = attn_weights @ Vh
    attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

    # Position-wise feed-forward network.
    ffn = torch.relu(attn_output @ W1 + b1)
    ffn = ffn @ W2 + b2

    # Only the last position's representation is used to predict the next token.
    final_token = ffn[:, -1, :]
    logits = final_token @ W_out + b_out
    loss = F.cross_entropy(logits, Y)
    loss_value = loss.item()
    losses.append(loss_value)

    loss.backward()
    optimizer.step()

    # Report progress every 10th epoch.
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

Epoch 0, Loss: 214.2892
Epoch 10, Loss: 10.2746
Epoch 20, Loss: 0.7195
Epoch 30, Loss: 0.2090
Epoch 40, Loss: 0.1528
Epoch 50, Loss: 0.1469
Epoch 60, Loss: 0.1360
Epoch 70, Loss: 0.0127
Epoch 80, Loss: 0.0014
Epoch 90, Loss: 0.0010


In [19]:
# Step 4: Save model locally in a structured file
# Trained tensors are stored detached from the autograd graph.
trained_tensors = {
    'embedding_matrix': embedding_matrix,
    'pos_embedding': pos_embedding,
    'W_q': W_q,
    'W_k': W_k,
    'W_v': W_v,
    'W1': W1,
    'b1': b1,
    'W2': W2,
    'b2': b2,
    'W_out': W_out,
    'b_out': b_out,
}
# Vocabulary tables and the hyper-parameters needed to rebuild the forward pass.
vocab_and_config = {
    'word2idx': word2idx,
    'idx2word': idx2word,
    'embed_dim': embed_dim,
    'seq_len': seq_len,
    'num_heads': num_heads,
}

model_state = {name: tensor.detach() for name, tensor in trained_tensors.items()}
model_state.update(vocab_and_config)
torch.save(model_state, "decoder_transformer_model.pt")

In [20]:
# Step 5: Load model and inference function
# torch.load unpickles the file. The checkpoint written above holds only tensors,
# dicts of str/int and plain ints, so weights_only=True can be used to restrict
# unpickling to such data and refuse arbitrary pickled objects.
checkpoint = torch.load("decoder_transformer_model.pt", weights_only=True)

In [21]:
# Rebind the module-level model variables to the values stored in the checkpoint.
# Weight tensors first, in the same order they were saved.
(embedding_matrix, pos_embedding,
 W_q, W_k, W_v,
 W1, b1, W2, b2,
 W_out, b_out) = (
    checkpoint[name]
    for name in (
        'embedding_matrix', 'pos_embedding',
        'W_q', 'W_k', 'W_v',
        'W1', 'b1', 'W2', 'b2',
        'W_out', 'b_out',
    )
)

# Vocabulary lookup tables.
word2idx, idx2word = checkpoint['word2idx'], checkpoint['idx2word']

# Hyper-parameters; head_dim is not stored, so it is derived again.
embed_dim, seq_len, num_heads = (
    checkpoint[name] for name in ('embed_dim', 'seq_len', 'num_heads')
)
head_dim = embed_dim // num_heads

In [22]:
# Display the positional embedding tensor restored from the checkpoint.
pos_embedding

tensor([[ 0.3655, -1.0750, -0.1643,  0.6425,  1.5937, -1.7208, -0.0904, -0.0357,
         -0.0707, -0.7346,  2.1541, -1.5807, -2.5454, -0.5540, -0.1947,  0.5991],
        [-1.4843,  0.7852, -0.9336,  2.7172, -0.5522, -0.7768, -1.1033,  0.4119,
         -1.2031, -0.2711, -0.5309,  0.2537, -1.7941,  1.6523, -0.9307, -0.5694],
        [-0.8823, -0.4147,  2.1827, -1.0573, -0.2713, -0.4346, -1.1253,  0.9946,
         -1.0674,  1.9790,  1.0854, -0.9306,  0.0748, -0.2926, -1.6343, -0.3283],
        [-0.6138,  0.3872, -0.1049,  0.1742, -0.5154, -1.0625, -0.2511,  0.6225,
         -1.0498, -1.4465, -0.6444,  0.6280, -0.6765,  0.4764,  0.9294,  1.4273]])

In [23]:
# Display the query projection matrix restored from the checkpoint.
W_q

tensor([[-0.6735, -1.3010, -0.5142,  0.1470, -0.4942, -0.6457,  1.6273,  1.1840,
         -0.3979,  0.0145,  0.6088, -0.1372, -1.7032, -0.7718,  0.3392,  0.6545],
        [-0.8289,  0.2249,  2.5723,  1.2621, -1.5539,  0.0320,  1.2931, -0.2951,
          0.3752,  0.6632,  0.7567, -0.7760, -1.3101, -1.3830, -0.3203, -0.1739],
        [ 2.5386, -0.4766,  0.5736,  0.5614,  0.7321,  1.1841, -2.4389, -0.0283,
         -0.0579,  0.8476, -0.0036,  0.5378, -0.7895, -0.4302,  0.9561, -0.9286],
        [ 0.8756,  0.9265, -0.6452,  0.4425, -0.6957, -0.8279, -0.0613, -0.2311,
         -1.7024, -0.9615,  0.9508, -0.4698, -2.0479, -0.4157,  0.3622, -0.0191],
        [-0.8027,  1.1205,  1.0070,  0.4953, -0.1670, -0.7380,  2.1886,  0.1097,
         -1.6248, -1.3093,  0.4047,  0.3114, -1.0942,  0.6767,  1.0555,  0.5374],
        [ 1.1496,  1.6786,  0.8893, -1.0376,  0.0752,  0.7339,  1.1957,  0.0477,
         -1.2821, -1.0999,  1.0371,  1.3299, -0.9239, -0.1270,  0.1070,  2.0305],
        [ 0.7131, -0.0

In [24]:
# Display the shape of the restored query projection matrix.
W_q.shape

torch.Size([16, 16])

In [25]:
# Step 6: Prediction utility
def predict_next_tokens(input_seq, max_len=5):
    """Greedily extend a token sequence using the module-level model weights.

    On every step the most recent ``seq_len`` tokens are run through the
    attention block and feed-forward layers, the arg-max of the final
    position's logits is looked up in ``idx2word`` and the resulting word is
    appended. Generation stops after ``max_len`` new tokens or as soon as
    ``'<eos>'`` is produced.

    Args:
        input_seq: Sequence of token strings to continue. It is copied, so the
            caller's object is never mutated.
        max_len: Maximum number of tokens to append.

    Returns:
        A new list with the prompt tokens followed by the generated tokens.
    """
    generated = list(input_seq)

    # The causal mask depends only on seq_len, so build it once instead of on
    # every generation step. True marks positions that must not be attended to.
    future_positions = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0) == 0

    # Inference never back-propagates; no_grad stops autograd from recording
    # the forward pass even if the weights in scope still require gradients.
    with torch.no_grad():
        for _ in range(max_len):
            shortfall = seq_len - len(generated)
            if shortfall > 0:
                # Left-pad short prompts with '<sos>' so the window always has
                # exactly seq_len tokens (pos_embedding has seq_len rows).
                window = ['<sos>'] * shortfall + generated
            else:
                window = generated[-seq_len:]

            # Best-effort handling of out-of-vocabulary tokens: they map to index 0.
            # NOTE(review): in this notebook's vocabulary index 0 is '<eos>', so an
            # unknown word is silently treated as end-of-sequence - confirm intended.
            token_ids = torch.tensor([[word2idx.get(tok, 0) for tok in window]])

            embedded = embedding_matrix[token_ids] + pos_embedding
            Q = embedded @ W_q
            K = embedded @ W_k
            V = embedded @ W_v

            # (1, seq_len, embed_dim) -> (1, num_heads, seq_len, head_dim)
            Qh, Kh, Vh = (
                proj.view(1, seq_len, num_heads, head_dim).transpose(1, 2)
                for proj in (Q, K, V)
            )

            scores = (Qh @ Kh.transpose(-2, -1)) / math.sqrt(head_dim)
            scores = scores.masked_fill(future_positions, float('-inf'))
            attn_weights = F.softmax(scores, dim=-1)

            # Merge the heads back into one embed_dim-wide axis.
            attn_output = attn_weights @ Vh
            attn_output = attn_output.transpose(1, 2).contiguous().view(1, seq_len, embed_dim)

            ffn = torch.relu(attn_output @ W1 + b1)
            ffn = ffn @ W2 + b2

            # Only the last position's representation predicts the next token.
            logits = ffn[:, -1, :] @ W_out + b_out
            next_word = idx2word[torch.argmax(logits, dim=1).item()]

            generated.append(next_word)
            if next_word == '<eos>':
                break

    return generated

In [26]:
# Step 7: Predictions on loaded model
# Each entry pairs the label to print with the prompt handed to the model.
demo_runs = [
    ("Generated sequence 1:", ['<sos>', 'the', 'moon']),
    ("Generated sequence 2:", ['<sos>', 'stars']),
    ("Generated sequence 3:", ['<sos>', 'the']),
]

print("\nPredictions on loaded model:")
for label, prompt in demo_runs:
    print(label, predict_next_tokens(prompt))


Predictions on loaded model:
Generated sequence 1: ['<sos>', 'the', 'moon', 'sky', 'in', 'the', 'at', '<eos>']
Generated sequence 2: ['<sos>', 'stars', 'in', 'in', 'the', '<eos>']
Generated sequence 3: ['<sos>', 'the', 'night', '<eos>']
