In [None]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects the input into queries, keys and values with separate linear
    layers, splits each projection into ``num_heads`` heads of size
    ``d_k = d_model // num_heads``, applies scaled dot-product attention per
    head, then concatenates the heads and applies an output projection.

    Args:
        d_model: size of the last dimension of the input and output.
        num_heads: number of attention heads; must divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        # nn.Module.__init__ must run before any submodule is assigned so the
        # Linear layers below are registered as parameters of this module.
        super().__init__()
        # Each head gets an equal slice of the model dimension.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Query / key / value projections (all heads computed in one matmul each).
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)

        # Output projection applied after the heads are concatenated.
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, d_model).
            mask: optional tensor broadcastable to (B, num_heads, T, T).
                Positions where ``mask == 0`` are excluded from attention,
                e.g. the (1, 1, T, T) lower-triangular causal mask.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, D = x.shape

        # Project, then split the last dimension into heads:
        # (B, T, D) -> (B, T, H, d_k) -> (B, H, T, d_k)
        Q = self.W_q(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product scores Q K^T / sqrt(d_k), shape (B, H, T, T).
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Masks are often built on the CPU; move to the scores' device.
            mask = mask.to(scores.device)
            # Use the most negative finite value rather than -inf. For any row
            # with at least one visible key the softmax is unchanged (exp
            # underflows to exactly 0), but a fully masked row no longer turns
            # into NaN (softmax over all -inf is 0/0).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        # Attention weights: softmax over the key axis, so each row sums to 1.
        attn = torch.softmax(scores, dim=-1)

        # Weighted sum of values, shape (B, H, T, d_k).
        out = attn @ V

        # Merge heads back: (B, H, T, d_k) -> (B, T, H, d_k) -> (B, T, D)
        out = out.transpose(1, 2).contiguous().view(B, T, D)

        return self.W_o(out)


In [None]:
# Causal ("look-ahead") mask for decoder-style self-attention, as used by GPT-2.
def causal_mask(seq_len):
    """
    Build a lower-triangular attention mask.

    Entry [0, 0, i, j] is 1.0 when j <= i (query i may attend to key j, i.e.
    itself and earlier positions) and 0.0 otherwise.

    Returns shape: (1, 1, seq_len, seq_len)
    """
    allowed = torch.ones(seq_len, seq_len).tril()
    # Two leading singleton dims so the mask broadcasts over (batch, heads).
    return allowed[None, None]


In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    Precomputes a (1, max_len, d_model) table where, for position ``pos`` and
    dimension-pair index ``i``::

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    ``forward`` adds the first ``T`` rows of the table to an input of shape
    (B, T, d_model).

    Args:
        d_model: embedding size. Odd values are supported; the last (even)
            column gets a sine term with no matching cosine column.
        max_len: number of positions precomputed; ``T`` must not exceed it.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        pe = torch.zeros(max_len, d_model)  # (max_len, d_model) encoding table
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1) column of positions

        # 1 / 10000^(2i / d_model) for each even column index 2i, computed in log space.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        # There are ceil(d_model / 2) frequencies but only d_model // 2 odd
        # columns; trim so an odd d_model does not raise a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])  # odd dimensions

        # Buffer: part of the module's state (saved, moved with .to()) but not a trainable parameter.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Assumes ``x`` is (B, T, d_model) with ``T <= max_len``.
        """
        return x + self.pe[:, :x.size(1)]


In [None]:

# Smoke test: run random embeddings through PositionalEncoding and a causally
# masked MultiHeadAttention, and check that shapes are preserved.
if __name__ == "__main__":
    # Fixed seed so the printed tensor is reproducible on Restart & Run All.
    torch.manual_seed(0)

    # Parameters
    batch_size = 2
    seq_len = 5
    d_model = 8
    num_heads = 2

    # Dummy input standing in for token embeddings: (batch, seq, d_model)
    x = torch.randn(batch_size, seq_len, d_model)

    print("Input shape:", x.shape)

    # Positional Encoding
    pos_enc = PositionalEncoding(d_model)
    x_pos = pos_enc(x)

    print("After Positional Encoding:", x_pos.shape)

    # Multi-Head Attention
    mha = MultiHeadAttention(d_model, num_heads)

    mask = causal_mask(seq_len)  # pass mask=None to compare with unmasked attention
    out = mha(x_pos, mask)

    print("After Multi-Head Attention:", out.shape)

    # Both layers must preserve the (batch, seq, d_model) shape.
    expected_shape = (batch_size, seq_len, d_model)
    assert x_pos.shape == expected_shape, x_pos.shape
    assert out.shape == expected_shape, out.shape

    print("\nSample output tensor:\n", out)

Input shape: torch.Size([2, 5, 8])
After Positional Encoding: torch.Size([2, 5, 8])
After Multi-Head Attention: torch.Size([2, 5, 8])

Sample output tensor:
 tensor([[[-0.0132, -0.9266,  0.0755, -0.2064, -0.4950,  0.3838, -0.4032,
           0.0495],
         [-0.3515, -0.5994,  0.3107, -0.1495, -0.2342, -0.0237, -0.0849,
          -0.2068],
         [-0.6619, -0.5558,  0.5254, -0.3041,  0.1050, -0.3706,  0.0107,
          -0.4898],
         [-0.8187, -0.4154,  0.5876, -0.4875,  0.2922, -0.5425, -0.0012,
          -0.5605],
         [-0.5957, -0.2417,  0.3288, -0.3884,  0.1129, -0.3773, -0.0256,
          -0.3239]],

        [[-0.6042, -0.1195,  0.3992, -0.4868,  0.0033, -0.5287,  0.2096,
          -0.3515],
         [-0.7286, -0.1981,  0.5124, -0.6597,  0.1962, -0.5426,  0.0217,
          -0.4896],
         [-0.6133, -0.3381,  0.4341, -0.6199,  0.0887, -0.4186, -0.0337,
          -0.3343],
         [-0.7838, -0.3363,  0.5110, -0.7009,  0.1155, -0.6016,  0.0721,
          -0.4981],
   