<a href="https://colab.research.google.com/github/teelch0/Data-Mining/blob/main/trainable_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Attention with trainable weights

In [2]:
import torch

In [3]:
inputs = torch.nn.Embedding( 4, 8 )

In [4]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,  0.8885],
        [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,  0.6611],
        [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,  0.8136],
        [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,  0.2474]],
       requires_grad=True)

In [5]:
inputs = inputs.data
inputs

tensor([[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,  0.8885],
        [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,  0.6611],
        [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,  0.8136],
        [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,  0.2474]])

In [6]:
# set dimensions
# d_in matches the embedding size of `inputs` (8 columns);
# d_out is the size of every projected query / key / value vector.
d_in = 8
d_out = 6

# create weight matrices
# Each one is a (d_in, d_out) matrix sampled from torch.rand.
# requires_grad=False means autograd does not track them, which is why the
# tensors computed from them in the following cells print without a grad_fn.
# No random seed is set, so the values differ on every fresh run.
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [7]:
# choose an input vector and transform it into our query vector using W_q
query = inputs[2] @ W_q
query

tensor([ 0.2240,  1.3947, -0.6227,  0.1006, -0.4559, -0.2327])

In [8]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[ 0.4340, -0.4435,  0.4869, -0.9997, -0.1238, -0.4004],
        [ 1.3744, -0.2518,  1.2144, -0.3979, -0.8276, -0.6358],
        [-0.4079,  1.1825,  0.7480, -0.0312,  1.5904,  0.8623],
        [-4.3143, -3.2547, -4.3417, -2.8424, -2.8922, -2.9691]])
Values: tensor([[ 0.2843,  0.5977, -0.1439, -1.5276,  0.3573, -0.4165],
        [-0.2346,  0.0294,  0.8861,  0.2544,  2.3080,  1.7589],
        [ 1.5109,  2.1382, -0.1074, -0.2093, -0.0315, -1.0489],
        [-4.7379, -4.7088, -2.3949, -1.8229, -2.1963, -2.1774]])


In [9]:
attention_scores = query @ keys.T
attention_scores

tensor([-0.7755, -0.3143,  0.1631, -1.0785])

In [10]:
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([0.2194, 0.2649, 0.3219, 0.1939])

In [11]:
attention_weights.sum()

tensor(1.0000)

In [12]:
context_vector = attention_weights @ values
context_vector

tensor([-0.4321, -0.0858, -0.2958, -0.6886,  0.2537, -0.3853])

In [13]:
import torch.nn as nn


In [14]:
# here's a first version of a SimpleAttention class:

class SimpleAttention( nn.Module ):
  """Single-head scaled dot-product self-attention with fixed random projections.

  The query, key and value projections are ``(d_in, d_out)`` matrices sampled
  from ``torch.rand``.  They are wrapped in ``nn.Parameter`` with
  ``requires_grad=False``: they are registered on the module, but autograd
  does not track them.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query / key / value (and therefore context) vector.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  def forward( self, x ):
    """Compute one context vector per input token.

    Args:
      x: embedding vectors of shape ``(num_tokens, d_in)``; any number of
        leading batch dimensions is also accepted, e.g.
        ``(batch, num_tokens, d_in)``.

    Returns:
      Context vectors with the same leading dimensions as ``x`` and a last
      dimension of size ``d_out``.
    """
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v
    # Swap only the last two axes.  `keys.T` reverses *all* axes, which is
    # only correct for a 2-D tensor and fails for batched input.
    scores = queries @ keys.transpose( -2, -1 )
    # Scale by sqrt(d_out) before normalising each row of scores.
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    return context

In [15]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [16]:
simple.W_v

Parameter containing:
tensor([[0.4713, 0.5174, 0.5561, 0.6869, 0.9157, 0.6996],
        [0.7924, 0.2744, 0.8399, 0.3575, 0.8785, 0.7665],
        [0.8209, 0.9577, 0.6348, 0.1826, 0.4509, 0.0427],
        [0.9900, 0.1020, 0.5539, 0.4737, 0.0176, 0.2476],
        [0.0241, 0.5900, 0.3420, 0.5407, 0.4753, 0.4030],
        [0.8273, 0.3048, 0.5215, 0.7960, 0.2324, 0.8107],
        [0.7775, 0.3525, 0.7343, 0.3644, 0.5825, 0.1267],
        [0.2497, 0.5370, 0.7109, 0.1379, 0.9974, 0.4581]])

In [17]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.0183, -0.3816, -0.1456, -0.1467, -0.2976, -0.0868],
        [ 1.6009,  1.0210,  0.8057, -0.0985, -0.1908, -0.4073],
        [ 0.8315, -0.2106,  0.7974,  0.8025,  0.8686,  0.9164],
        [-5.3481, -2.3589, -4.4127, -2.7622, -3.2405, -2.4485]])

In [18]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently

class SimpleAttention( nn.Module ):
  """Single-head scaled dot-product self-attention built from ``nn.Linear`` layers.

  The query, key and value projections are bias-free ``nn.Linear( d_in, d_out )``
  layers, so their weights are trainable.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query / key / value vector.
  """

  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  def forward( self, x, return_context=False ):
    """Compute the attention weights (default) or the context vectors.

    Args:
      x: embedding vectors of shape ``(num_tokens, d_in)``; leading batch
        dimensions are also accepted.
      return_context: when ``False`` (the default) the normalised attention
        weights are returned, which is what the masking cells later in this
        notebook consume.  When ``True`` the context vectors
        ``weights @ values`` are returned instead.

    Returns:
      ``(..., num_tokens, num_tokens)`` attention weights whose rows sum to 1,
      or ``(..., num_tokens, d_out)`` context vectors if ``return_context``
      is set.
    """
    queries = self.W_q( x )
    keys = self.W_k( x )
    # Swap only the last two axes so batched input works; `.T` would reverse
    # every axis.
    scores = queries @ keys.transpose( -2, -1 )
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    if not return_context:
      return weights
    values = self.W_v( x )
    return weights @ values

In [19]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [20]:
context_vectors = simple( inputs )
context_vectors

tensor([[0.2056, 0.2703, 0.2361, 0.2880],
        [0.1612, 0.4250, 0.1154, 0.2984],
        [0.2681, 0.1686, 0.2787, 0.2845],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<SoftmaxBackward0>)

In [21]:
# the problem with this is that each context vector uses information from ALL of the embedding vectors,
# including the ones that come later in the sequence
# in practice, each position should only use information from itself and the preceding embedding vectors
# to accomplish this, we'll implement causal attention (also known as masked attention)

In [22]:
# this is a hack to get some example weights to work with:
# the second SimpleAttention.forward returns the attention weights rather than the context vectors
weights = simple( inputs )
weights

tensor([[0.2056, 0.2703, 0.2361, 0.2880],
        [0.1612, 0.4250, 0.1154, 0.2984],
        [0.2681, 0.1686, 0.2787, 0.2845],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<SoftmaxBackward0>)

In [23]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [24]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) )
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [25]:
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2056, 0.0000, 0.0000, 0.0000],
        [0.1612, 0.4250, 0.0000, 0.0000],
        [0.2681, 0.1686, 0.2787, 0.0000],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<MulBackward0>)

In [26]:
masked_weights.sum( dim=-1 )

tensor([0.2056, 0.5862, 0.7155, 1.0000], grad_fn=<SumBackward1>)

In [27]:
# now, we need to normalize the masked_weights so that each row has sum 1
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2056],
        [0.5862],
        [0.7155],
        [1.0000]], grad_fn=<SumBackward1>)

In [28]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [29]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.2750, 0.7250, 0.0000, 0.0000],
        [0.3748, 0.2356, 0.3896, 0.0000],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<DivBackward0>)

In [30]:
# masking method #2
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [31]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [32]:
weights

tensor([[0.2056, 0.2703, 0.2361, 0.2880],
        [0.1612, 0.4250, 0.1154, 0.2984],
        [0.2681, 0.1686, 0.2787, 0.2845],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<SoftmaxBackward0>)

In [33]:
# Fill every position above the diagonal (the "future" tokens) with -inf so
# that the softmax in the next cell assigns them a weight of exactly 0.
# NOTE(review): `weights` here is already softmax-normalised (its rows sum to 1
# above), so this masks normalised weights and the next cell applies softmax a
# second time. That is why the result differs from masking method #1; to get
# the same numbers the mask would have to be applied to the raw attention
# scores before the first softmax.
# This cell also overwrites `weights`, so the unmasked weights are no longer
# available afterwards.
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2056,   -inf,   -inf,   -inf],
        [0.1612, 0.4250,   -inf,   -inf],
        [0.2681, 0.1686, 0.2787,   -inf],
        [0.1823, 0.2210, 0.1754, 0.4213]], grad_fn=<MaskedFillBackward0>)

In [34]:
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4344, 0.5656, 0.0000, 0.0000],
        [0.3429, 0.3104, 0.3466, 0.0000],
        [0.2324, 0.2416, 0.2308, 0.2952]], grad_fn=<SoftmaxBackward0>)

In [35]:
## Dropout
# idea: randomly zero out some of the values during training to avoid overfitting
# nn.Dropout( 0.5 ) zeroes each element with probability 0.5 and rescales the
# surviving elements by 1 / (1 - 0.5) = 2, which is why the non-zero entries in
# the next output are twice the corresponding masked_weights entries.
# The selection is random and no seed is set, so the output changes between runs.
dropout = nn.Dropout( 0.5 )

In [36]:
dropout( masked_weights )

tensor([[2.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.6859, 0.6209, 0.6932, 0.0000],
        [0.4649, 0.4832, 0.4616, 0.0000]], grad_fn=<MulBackward0>)

In [37]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0)

In [38]:
batches

tensor([[[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,
           0.8885],
         [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,
           0.6611],
         [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,
           0.8136],
         [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,
           0.2474]],

        [[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,
           0.8885],
         [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,
           0.6611],
         [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,
           0.8136],
         [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,
           0.2474]]])

In [39]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  """Single-head causal (masked) self-attention for batched input.

  Each position may attend only to itself and to earlier positions; later
  positions are masked out before the softmax.

  Args:
    d_in: size of each input embedding vector.
    d_out: size of each query / key / value (and context) vector.
    context_length: maximum number of tokens per sequence; determines the
      size of the causal mask.  Longer inputs are not supported.
    dropout: dropout probability applied to the attention weights.
    qkv_bias: whether the query / key / value projections carry a bias term.
  """

  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=qkv_bias )
    self.W_k = nn.Linear( d_in, d_out, bias=qkv_bias )
    self.W_v = nn.Linear( d_in, d_out, bias=qkv_bias )
    self.dropout = nn.Dropout( dropout ) #include dropout
    # A buffer is part of the module's state (it follows the module across
    # devices and is stored in state_dict) without being a trainable parameter.
    # The mask is sized from `context_length`, not from any notebook global,
    # so the class works on a fresh kernel.
    self.register_buffer(
        'mask',
        torch.triu( torch.ones( context_length, context_length ), diagonal = 1 )
    )


  def forward( self, x ):
    """Compute causally masked context vectors.

    Args:
      x: embedding vectors of shape ``(batch, num_tokens, d_in)``.

    Returns:
      Context vectors of shape ``(batch, num_tokens, d_out)``.
    """
    _, num_tokens, _ = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    # (batch, num_tokens, d_out) @ (batch, d_out, num_tokens)
    scores = queries @ keys.transpose(1,2)
    # Ones above the diagonal mark future positions; crop the mask to the
    # actual sequence length and set those scores to -inf so that softmax
    # gives them zero weight.
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    # Randomly drop attention weights (only active in training mode).
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [40]:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0)
#instantiate a causal attention mechanism

In [41]:
causal(batches)

tensor([[[-0.3222,  0.0286,  1.0227,  0.0562,  0.5668,  0.6717],
         [-0.4803,  0.3176,  0.6440,  0.3903,  0.3554, -0.1116],
         [-0.4202,  0.1089,  0.5496,  0.1852,  0.3861,  0.0888],
         [-0.2863, -0.0042,  0.3866, -0.0514,  0.4464,  0.3881]],

        [[-0.3222,  0.0286,  1.0227,  0.0562,  0.5668,  0.6717],
         [-0.4803,  0.3176,  0.6440,  0.3903,  0.3554, -0.1116],
         [-0.4202,  0.1089,  0.5496,  0.1852,  0.3861,  0.0888],
         [-0.2863, -0.0042,  0.3866, -0.0514,  0.4464,  0.3881]]],
       grad_fn=<UnsafeViewBackward0>)

In [42]:
W_q = nn.Linear( d_in, d_out, bias=False )
W_k = nn.Linear( d_in, d_out, bias=False )
W_v = nn.Linear( d_in, d_out, bias=False )

In [43]:
queries= W_q( batches )
queries

tensor([[[ 0.1981,  0.6930, -0.5793,  0.3826,  0.0747, -0.0669],
         [ 0.3678, -0.1562, -0.3005, -0.8592, -0.4347, -0.0084],
         [ 0.3455,  0.3496, -0.5403,  0.4469,  0.8124, -0.4717],
         [-1.0737, -0.9779,  0.6552, -0.3107,  0.4692, -0.4736]],

        [[ 0.1981,  0.6930, -0.5793,  0.3826,  0.0747, -0.0669],
         [ 0.3678, -0.1562, -0.3005, -0.8592, -0.4347, -0.0084],
         [ 0.3455,  0.3496, -0.5403,  0.4469,  0.8124, -0.4717],
         [-1.0737, -0.9779,  0.6552, -0.3107,  0.4692, -0.4736]]],
       grad_fn=<UnsafeViewBackward0>)

In [44]:
keys= W_k( batches )
keys

tensor([[[ 0.7103,  0.7806, -0.4968, -0.9811,  0.5417, -0.0483],
         [-0.8315, -0.1817,  0.5227,  0.5549, -0.6145,  0.5099],
         [ 0.3417,  0.9014, -0.3325, -0.9042,  0.4625, -0.6542],
         [-0.4073, -0.1351, -0.6780,  0.5060,  0.2485,  0.1818]],

        [[ 0.7103,  0.7806, -0.4968, -0.9811,  0.5417, -0.0483],
         [-0.8315, -0.1817,  0.5227,  0.5549, -0.6145,  0.5099],
         [ 0.3417,  0.9014, -0.3325, -0.9042,  0.4625, -0.6542],
         [-0.4073, -0.1351, -0.6780,  0.5060,  0.2485,  0.1818]]],
       grad_fn=<UnsafeViewBackward0>)

In [45]:
keys.transpose(1,2)

tensor([[[ 0.7103, -0.8315,  0.3417, -0.4073],
         [ 0.7806, -0.1817,  0.9014, -0.1351],
         [-0.4968,  0.5227, -0.3325, -0.6780],
         [-0.9811,  0.5549, -0.9042,  0.5060],
         [ 0.5417, -0.6145,  0.4625,  0.2485],
         [-0.0483,  0.5099, -0.6542,  0.1818]],

        [[ 0.7103, -0.8315,  0.3417, -0.4073],
         [ 0.7806, -0.1817,  0.9014, -0.1351],
         [-0.4968,  0.5227, -0.3325, -0.6780],
         [-0.9811,  0.5549, -0.9042,  0.5060],
         [ 0.5417, -0.6145,  0.4625,  0.2485],
         [-0.0483,  0.5099, -0.6542,  0.1818]]], grad_fn=<TransposeBackward0>)

In [46]:
#Multi-head attention
#multiple causal attentions stuck together
class MultiHeadAttention(nn.Module):
  """Multi-head attention built by running several CausalAttention heads side by side.

  Every head receives the same constructor arguments and the same input; the
  per-head results are joined along the last dimension.

  Args:
    d_in: size of each input embedding vector.
    d_out: output size passed to each individual head.
    context_length: context length passed to each head.
    dropout: dropout value passed to each head.
    num_heads: number of CausalAttention heads to create.
    qkv_bias: bias flag passed to each head.
  """

  def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
    super().__init__()
    # ModuleList registers every head as a sub-module of this module.
    self.heads = nn.ModuleList()
    for _ in range(num_heads):
      self.heads.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))

  def forward(self, x):
    """Run every head on ``x`` and concatenate the results along the last axis."""
    head_outputs = []
    for head in self.heads:
      head_outputs.append(head(x))
    return torch.cat(head_outputs, dim=-1)


In [54]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length = 4, dropout = 0, num_heads = 3)

In [55]:
mha_out = mha(batches)

In [56]:
mha_out

tensor([[[ 1.7469e-01,  1.2094e+00, -4.5389e-02,  3.8637e-01, -4.3214e-01,
           6.6472e-01, -2.3259e-01, -3.2051e-01, -4.5996e-01,  1.8837e-01,
           1.7489e-01,  4.7684e-01, -3.2374e-01,  1.1922e-01,  9.7719e-02,
           6.8529e-01,  5.3678e-02,  5.9794e-01],
         [ 1.1060e-02,  9.0605e-01, -2.8432e-01, -1.1351e-01,  9.9134e-02,
           4.7537e-01, -2.6532e-02, -2.8162e-01, -4.0482e-02, -2.5999e-01,
           3.0705e-01, -2.9023e-01, -2.2019e-01,  3.7587e-01,  2.5520e-01,
           4.8670e-01,  1.2887e-01,  3.4912e-01],
         [ 1.5671e-01,  7.2261e-01, -2.6303e-01,  9.4548e-02, -1.4341e-02,
           4.4316e-01,  2.3497e-02, -2.1352e-01,  3.1542e-02, -2.9462e-01,
           2.9261e-01, -3.6400e-01,  1.6516e-02, -6.3373e-02, -2.0408e-01,
           1.2524e-01,  2.2403e-02,  5.2884e-02],
         [ 1.2673e-01,  4.3821e-01, -1.9526e-01,  7.7893e-02, -5.5941e-02,
           3.6452e-01, -1.7185e-01, -2.2007e-01, -2.5148e-01,  1.3783e-01,
           2.6300e-02,  2

In [57]:
mha_out.shape

torch.Size([2, 4, 18])

In [58]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head attention using one fused projection per Q / K / V.

    The ``d_out``-wide projections are split into ``num_heads`` slices of
    ``head_dim = d_out // num_heads`` each, attention is computed for all
    heads at once, and the heads are merged back and passed through a final
    linear layer.

    Args:
        d_in: size of each input embedding vector.
        d_out: total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: side length of the causal mask, i.e. the longest
            supported sequence.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query / key / value projections carry a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of one head, chosen so that the heads together span d_out.
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Mixes the concatenated head outputs.
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Return context vectors of shape ``(b, num_tokens, d_out)``.

        Args:
            x: input of shape ``(b, num_tokens, d_in)``.  Sequences longer
                than ``context_length`` are not supported, because the cropped
                mask would no longer match the score matrix.
        """
        batch_size, num_tokens, _ = x.shape
        per_head_shape = (batch_size, num_tokens, self.num_heads, self.head_dim)

        def split_heads(projected):
            # (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
            #                        -> (b, num_heads, num_tokens, head_dim)
            return projected.view(per_head_shape).transpose(1, 2)

        keys = split_heads(self.W_key(x))
        queries = split_heads(self.W_query(x))
        values = split_heads(self.W_value(x))

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens).
        raw_scores = queries @ keys.transpose(2, 3)

        # Crop the stored mask to the actual sequence length and block the
        # future positions with -inf so that softmax gives them zero weight.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        masked_scores = raw_scores.masked_fill(causal_mask, -torch.inf)

        scale = self.head_dim ** 0.5
        attn_weights = self.dropout(torch.softmax(masked_scores / scale, dim=-1))

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads:
        # d_out = num_heads * head_dim.
        per_head_context = (attn_weights @ values).transpose(1, 2)
        merged = per_head_context.contiguous().view(batch_size, num_tokens, self.d_out)

        return self.out_proj(merged)

In [59]:
batches.shape

torch.Size([2, 4, 8])

In [60]:
batches

tensor([[[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,
           0.8885],
         [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,
           0.6611],
         [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,
           0.8136],
         [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,
           0.2474]],

        [[-0.5144, -0.2876, -0.5925,  0.6853, -1.4830,  1.1301,  0.5501,
           0.8885],
         [-1.0101, -0.8224,  2.7662,  1.2267, -0.6145,  0.0572, -0.7568,
           0.6611],
         [ 1.6995, -0.3372, -1.7805,  0.3423, -0.3928,  0.2949,  0.7520,
           0.8136],
         [-0.3884, -1.6392, -0.8317, -1.4207, -0.1643, -0.4159, -1.9173,
           0.2474]]])

In [65]:
batches.view(2, 4, 2, 4)

tensor([[[[-0.5144, -0.2876, -0.5925,  0.6853],
          [-1.4830,  1.1301,  0.5501,  0.8885]],

         [[-1.0101, -0.8224,  2.7662,  1.2267],
          [-0.6145,  0.0572, -0.7568,  0.6611]],

         [[ 1.6995, -0.3372, -1.7805,  0.3423],
          [-0.3928,  0.2949,  0.7520,  0.8136]],

         [[-0.3884, -1.6392, -0.8317, -1.4207],
          [-0.1643, -0.4159, -1.9173,  0.2474]]],


        [[[-0.5144, -0.2876, -0.5925,  0.6853],
          [-1.4830,  1.1301,  0.5501,  0.8885]],

         [[-1.0101, -0.8224,  2.7662,  1.2267],
          [-0.6145,  0.0572, -0.7568,  0.6611]],

         [[ 1.6995, -0.3372, -1.7805,  0.3423],
          [-0.3928,  0.2949,  0.7520,  0.8136]],

         [[-0.3884, -1.6392, -0.8317, -1.4207],
          [-0.1643, -0.4159, -1.9173,  0.2474]]]])

In [66]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length=4, dropout=0, num_heads=3 )

In [67]:
mha_out= mha(batches)

In [68]:
mha_out.shape

torch.Size([2, 4, 6])