In [26]:
import numpy as np


import torch
import torch.nn as nn
import torch.nn.functional as F


Hyperparameters

In [27]:
# data hyperparameters
seq_len = 8  # tokens per sequence

# model hyperparameters
embed_dim = 128  # width of each token embedding vector

# training hyperparameters
batch_size = 5  # sequences per batch

# reproducibility: fix torch's global RNG so layer weight initialization and the
# random demo inputs in later cells are identical on every Restart & Run All
seed = 42
torch.manual_seed(seed);  # trailing ';' suppresses the Generator repr in the notebook

One attention head

In [29]:
# create one attention head
class OneAttentionHead(nn.Module):
    """Single-head causal self-attention followed by an output projection.

    Holds four bias-free embed_dim -> embed_dim linear maps: key, query, value
    and the output projection W0.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # The four projections are built (and registered) in the order
        # key, query, value, W0, so parameter initialization draws from the
        # RNG in that order.
        self.key, self.query, self.value, self.W0 = (
            nn.Linear(embed_dim, embed_dim, bias=False) for _ in range(4)
        )

    def forward(self, x):
        """Attend over the token vectors in x, then project with W0.

        x is presumably (batch, seq_len, embed_dim), which is the shape used by
        the demo cell. The output has the same shape as x.
        """
        q, k, v = self.query(x), self.key(x), self.value(x)
        # is_causal=True applies the causal mask, so each position only
        # attends to itself and to earlier positions
        attended = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.W0(attended)  # final linear transform
    

In [30]:
# explore the attention head
onehead = OneAttentionHead(embed_dim)

print(onehead)

# run some fake data through: a random (batch_size, seq_len, embed_dim) tensor
# (the misspelled name `tokenEmebeds` is kept because later cells reference it)
tokenEmebeds = torch.randn(batch_size, seq_len, embed_dim)
out = onehead(tokenEmebeds)

# attention plus the W0 projection must preserve the input shape
assert out.shape == (batch_size, seq_len, embed_dim), out.shape

print(f'\nOutput ({out.shape}): \n{out}')

OneAttentionHead(
  (key): Linear(in_features=128, out_features=128, bias=False)
  (query): Linear(in_features=128, out_features=128, bias=False)
  (value): Linear(in_features=128, out_features=128, bias=False)
  (W0): Linear(in_features=128, out_features=128, bias=False)
)

Output (torch.Size([5, 8, 128])): 
tensor([[[-9.4601e-02,  4.4920e-01, -1.6067e-01,  ..., -2.6583e-01,
          -2.5213e-01, -8.2024e-02],
         [-1.8009e-01,  2.7770e-01, -3.0519e-01,  ..., -3.2870e-01,
          -1.7446e-01, -2.4908e-01],
         [-7.4834e-02,  2.3881e-01, -1.7252e-01,  ...,  3.9444e-02,
          -2.8884e-02, -1.0802e-01],
         ...,
         [-8.4594e-02,  3.7004e-02, -1.5335e-01,  ..., -1.3751e-01,
           5.5471e-02, -1.3126e-01],
         [-9.1445e-02,  5.3801e-02, -2.9243e-01,  ..., -1.8555e-01,
          -8.3495e-03, -2.5617e-01],
         [-4.2297e-02,  8.8682e-02, -2.2556e-01,  ..., -1.3645e-01,
          -5.6437e-02, -2.4334e-01]],

        [[-1.6051e-01, -7.5624e-02, -6.3366

In [33]:
# shape of the demo input: (batch_size, seq_len, embed_dim)
tokenEmebeds.size()

torch.Size([5, 8, 128])

In [34]:
# In single-head attention, all embedding dimensions are processed together by one head.
# In multi-head attention, the embedding dimensions are split into slices, one slice per head.


Transformer block

In [45]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: an attention sublayer, then an MLP sublayer.

    Each sublayer layer-normalizes its input, transforms it, and adds the result
    back onto the un-normalized input (residual connection). Because of this,
    the output has the same shape as the input.

    Args:
        embed_dim: size of the last dimension of the input/output tensors.
        expansion: width multiplier for the MLP hidden layer. The default of 4
            gives the embed_dim -> 4*embed_dim -> embed_dim MLP.

    Raises:
        ValueError: if expansion is smaller than 1.
    """

    def __init__(self, embed_dim, expansion=4):
        super().__init__()
        if expansion < 1:
            raise ValueError(f'expansion must be >= 1, got {expansion}')
        hidden_dim = expansion * embed_dim

        # attention sublayer
        self.layerNormAttn = nn.LayerNorm(embed_dim)
        self.attn = OneAttentionHead(embed_dim)

        # feedforward (MLP) sublayer
        self.layerNormMLP = nn.LayerNorm(embed_dim)
        self.W1 = nn.Linear(embed_dim, hidden_dim)  # expansion
        self.gelu = nn.GELU()
        self.W2 = nn.Linear(hidden_dim, embed_dim)  # contraction

    def forward(self, x):
        """Run x through the attention and MLP sublayers.

        x is presumably (batch_size, seq_len, embed_dim), as in the demo cells.
        Returns a tensor with the same shape as x.
        """
        ## ---- attention sublayer ---- ##
        # keep the pre-attention input for the residual connection
        residual = x

        # layernorm -> attention
        h = self.layerNormAttn(x)
        attn_out = self.attn(h)  # presumably [batch_size, seq_len, embed_dim]

        # combine pre-attention copy + attention adjustments
        x = residual + attn_out

        # could do this in one line:
        #   x = x + self.attn(self.layerNormAttn(x))
        ## ---------------------------- ##

        ## ------- MLP sublayer ------- ##
        # copy of pre-MLP data (output of the attention sublayer)
        residual2 = x

        # layernorm before the MLP
        h2 = self.layerNormMLP(x)

        # expansion -> nonlinearity -> contraction
        mlp_out = self.W2(self.gelu(self.W1(h2)))

        # combine pre-MLP copy + MLP adjustment
        y = residual2 + mlp_out
        ## ---------------------------- ##

        # y goes either to the next transformer block, or to the final
        # unembedding matrix and from there to tokens and text
        return y
        
        

In [46]:
# build one transformer block and display its sub-module structure
transblock = TransformerBlock(embed_dim=embed_dim)
print(transblock)

TransformerBlock(
  (layerNormAttn): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (attn): OneAttentionHead(
    (key): Linear(in_features=128, out_features=128, bias=False)
    (query): Linear(in_features=128, out_features=128, bias=False)
    (value): Linear(in_features=128, out_features=128, bias=False)
    (W0): Linear(in_features=128, out_features=128, bias=False)
  )
  (layerNormMLP): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (W1): Linear(in_features=128, out_features=512, bias=True)
  (gelu): GELU(approximate='none')
  (W2): Linear(in_features=512, out_features=128, bias=True)
)


In [47]:
# push the demo token embeddings through the transformer block
out = transblock(tokenEmebeds)

# both residual connections require the block to preserve the input shape
assert out.shape == tokenEmebeds.shape, out.shape

print(f'\nOutput ({out.shape}): \n{out}')



Output (torch.Size([5, 8, 128])): 
tensor([[[-1.0108, -2.0522,  1.7651,  ..., -1.0905,  0.5553,  0.2084],
         [-1.4508,  0.0729,  0.1084,  ...,  0.4905,  0.6473, -2.0556],
         [ 1.6218,  0.6352,  0.1972,  ..., -1.1110, -1.0756, -1.0169],
         ...,
         [-0.3379, -0.2691,  0.7948,  ...,  0.1477, -0.0889, -1.7009],
         [-0.9677, -0.0556, -1.2894,  ..., -1.0842, -0.8767, -1.3239],
         [ 0.6473,  0.1347, -1.1578,  ..., -0.6961, -1.1902, -0.3936]],

        [[-0.3241, -0.0264,  0.6526,  ..., -1.0890,  0.1171, -1.1695],
         [ 0.7390,  1.5468,  0.8615,  ..., -0.2498, -0.6122,  0.5031],
         [-0.3949,  1.5369, -1.9716,  ..., -0.5512,  1.7584,  0.1625],
         ...,
         [ 0.9793,  0.5105, -0.3675,  ...,  0.7062, -0.0030,  1.0130],
         [ 0.3585, -0.0196, -1.2885,  ...,  0.5898,  0.6689,  0.5115],
         [-0.7787, -0.1190, -0.2454,  ...,  0.0558,  0.0030, -0.8798]],

        [[-0.0938, -0.8779, -0.5781,  ...,  1.5676, -1.3632,  1.5236],
         