## Dependencies

In [1]:
import torch

In [2]:
# Show which PyTorch build produced the recorded outputs in this notebook.
torch.__version__

'2.8.0+cpu'

## Simplified Self Attention

In [3]:
# Our goal is to calculate the context vector for the current query vector by attending over the whole input sequence.
# The example input sentence is "Your journey starts with one step".

In [4]:
# Toy embedding matrix: one 3-dimensional vector per token of the example sentence (6 tokens).
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

#### Calculate Attention Scores

In [5]:
# The second token ("journey") acts as the query.
query = inputs[1]
attn_scores_2 = torch.empty(len(inputs))

# One dot product per input row gives that token's attention score w.r.t. the query.
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = query.dot(x_i)
print(attn_scores_2)



# # Reference only: the dot product written out by hand.
# # A dot product multiplies two vectors element-wise and then sums the products.
# attn_scores_2_brute = torch.empty(inputs.shape[0])
# for i, x_i in enumerate(inputs):                # at i = 0 - x_i = [0.43, 0.15, 0.89], query = [0.55, 0.87, 0.66]
#     temp_attn = 0
#     for j, x_i_i in enumerate(x_i):             # at i = 0 and j = 0 , x_i_i = 0.43, query[j] = query[0] = 0.55
#         temp_attn += query[j] * x_i_i
#     attn_scores_2_brute[i] = temp_attn
# print(attn_scores_2_brute)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


#### Normalize the attention scores

In [6]:
# Simplest normalisation: divide every score by the total so the weights sum to 1.
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [7]:
def softmax_naive(x):
    """Return the softmax of ``x`` taken along dimension 0.

    Every element is exponentiated and divided by the sum of the exponentials
    over dim 0. No maximum is subtracted first, so very large inputs can
    overflow inside ``torch.exp``.
    """
    exponentials = torch.exp(x)
    return exponentials / exponentials.sum(dim=0)
# Normalise the scores of the second token with the hand-written softmax.
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


In [8]:
# Same normalisation with PyTorch's built-in softmax along dimension 0.
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


#### Calculate Context Vector

In [9]:
# The context vector for the second token is the attention-weighted sum of all input vectors.
query = inputs[1]
context_vec_2 = torch.zeros_like(query)
for i, x_i in enumerate(inputs):
    context_vec_2 += x_i * attn_weights_2[i]
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


#### Putting it together

In [10]:
# Explicit double-loop version, kept for reference: one dot product per pair of tokens.
# attn_scores_tm = torch.empty(6,6)
# for i, x_i in enumerate(inputs):
#     for j, x_j in enumerate(inputs):
#         attn_scores_tm[i,j] = torch.dot(x_i, x_j)

# print(attn_scores_tm)



# The better way of doing it is a single matrix multiplication:
# entry (i, j) of inputs @ inputs.T is the dot product of token i with token j.

attn_scores = inputs @ inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [11]:
# dim=-1 applies the softmax along the last dimension, i.e. across the columns of each row,
# so every row of attn_weights sums to 1.
attn_weights = torch.softmax(attn_scores, dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [12]:
# Row i of the result is the attention-weighted sum of all input vectors for token i.
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


## Self-Attention with trainable weights

In [13]:
# Same six 3-dimensional token embeddings as defined earlier in the notebook.
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # Your (x^1)
[0.55, 0.87, 0.66], # journey (x^2)
[0.57, 0.85, 0.64], # starts (x^3)
[0.22, 0.58, 0.33], # with (x^4)
[0.77, 0.25, 0.10], # one (x^5)
[0.05, 0.80, 0.55]] # step (x^6)
)

In [14]:
x_2 = inputs[1]  # second token ("journey")
d_in = inputs.shape[1]  # embedding size of each input token
d_out = 2  # size of the projected query/key/value vectors
print(x_2)
print(d_in)
print(d_out)

tensor([0.5500, 0.8700, 0.6600])
3
2


#### Computing the attention weights step by step

In [15]:
# Seed the RNG so the randomly drawn projection matrices are reproducible.
torch.manual_seed(123)
# requires_grad=False disables gradient tracking for these (d_in, d_out) matrices.
W_query = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)


In [16]:
# Project the second token into query, key and value space.
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [17]:
# Keys and values for every token at once: (6, 3) @ (3, 2) -> (6, 2).
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [18]:
# Attention score of the second token with itself: dot product of its query with its own key.
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


In [19]:
# Scores of the second token's query against every key in a single matrix product.
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [20]:
# Scaled dot-product attention: divide the scores by sqrt(d_k), the square root of the
# key dimension, before applying the softmax.
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [21]:
# Context vector for the second token: attention-weighted sum of the value vectors.
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


#### Self-attention python class

In [22]:
import torch.nn as nn


class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention backed by raw ``nn.Parameter`` matrices.

    The key, query and value projections are ``(d_in, d_out)`` matrices drawn
    from ``torch.rand``, created in that order.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        weight_shape = (d_in, d_out)
        self.W_k = nn.Parameter(torch.rand(*weight_shape))
        self.W_q = nn.Parameter(torch.rand(*weight_shape))
        self.W_v = nn.Parameter(torch.rand(*weight_shape))

    def forward(self, x):
        """Return ``softmax(Q K^T / sqrt(d_out)) V`` for the 2-D input ``x``.

        In the notebook example ``x`` is (6, 3), so the projections are (6, 2)
        and the score matrix is (6, 6).
        """
        k = x @ self.W_k
        v = x @ self.W_v
        q = x @ self.W_q

        # Scale by the square root of the key dimension before normalising each row.
        scale = k.shape[-1] ** 0.5
        weights = torch.softmax((q @ k.T) / scale, dim=-1)

        return weights @ v

In [23]:
# Seed before constructing so the torch.rand-initialised weights are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2947, 0.7956],
        [0.3015, 0.8132],
        [0.3010, 0.8120],
        [0.2925, 0.7902],
        [0.2863, 0.7737],
        [0.2979, 0.8043]], grad_fn=<MmBackward0>)


In [24]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention using ``nn.Linear`` projections.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors.
        qkv_bias: Whether the three projection layers carry a bias term.
            Defaults to ``True``: the previous code passed the string
            ``'qkv_bias'``, which is truthy and therefore always enabled the
            bias, so this default keeps existing results unchanged. Pass
            ``False`` for bias-free projections.
    """

    def __init__(self, d_in, d_out, qkv_bias=True):
        super().__init__()
        # Creation order (key, query, value) is kept so seeded initialisation is unchanged.
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return one context vector per token.

        Args:
            x: Tensor of shape (num_tokens, d_in) or (batch, num_tokens, d_in).

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension ``d_out``.
        """
        keys = self.W_k(x)
        queries = self.W_q(x)
        values = self.W_v(x)

        # Swap only the last two dimensions so an optional batch dimension is preserved.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Scale by sqrt(d_out) and normalise over the keys so each query row sums to 1.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vectors = attn_weights @ values
        return context_vectors

# Seed, build the nn.Linear-based module and run it on the 6-token example.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.3377, -0.2826],
        [-0.3369, -0.2830],
        [-0.3369, -0.2829],
        [-0.3367, -0.2833],
        [-0.3372, -0.2825],
        [-0.3365, -0.2835]], grad_fn=<MmBackward0>)


## Self Attention with Causal Masking and Dropout

#### Using simpler masking (torch.tril)

In [25]:
# Reuse the query/key projections of sa_v2 to get the unmasked attention scores and weights.
queries = sa_v2.W_q(inputs)
keys = sa_v2.W_k(inputs)
attn_scores = queries @ keys.T
print(attn_scores)
# Scale by the square root of the key dimension, then normalise each row.
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[-0.4559, -0.5804, -0.5755, -0.4223, -0.3861, -0.4831],
        [-0.3016, -0.4029, -0.4003, -0.3148, -0.2953, -0.3471],
        [-0.2871, -0.3840, -0.3816, -0.3005, -0.2820, -0.3311],
        [-0.4552, -0.6044, -0.6003, -0.4682, -0.4379, -0.5185],
        [-0.1038, -0.1428, -0.1421, -0.1161, -0.1104, -0.1255],
        [-0.5841, -0.7729, -0.7676, -0.5957, -0.5562, -0.6614]],
       grad_fn=<MmBackward0>)
tensor([[0.1698, 0.1555, 0.1560, 0.1739, 0.1784, 0.1665],
        [0.1716, 0.1598, 0.1600, 0.1700, 0.1724, 0.1662],
        [0.1714, 0.1601, 0.1604, 0.1698, 0.1721, 0.1662],
        [0.1736, 0.1562, 0.1566, 0.1720, 0.1757, 0.1660],
        [0.1690, 0.1644, 0.1645, 0.1675, 0.1682, 0.1664],
        [0.1751, 0.1532, 0.1538, 0.1736, 0.1786, 0.1658]],
       grad_fn=<SoftmaxBackward0>)


In [26]:
# Lower-triangular mask: 1 on and below the diagonal, 0 above it (the future positions)
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

# Multiply the attention weights by the mask to zero out the future positions
masked_simple = attn_weights * mask_simple
display(masked_simple)

# Renormalise so that every row sums to 1 again
row_sums = masked_simple.sum(dim=-1, keepdim=True)
print(f"row sums shape: {row_sums.shape}")
print(row_sums)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


tensor([[0.1698, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1716, 0.1598, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1714, 0.1601, 0.1604, 0.0000, 0.0000, 0.0000],
        [0.1736, 0.1562, 0.1566, 0.1720, 0.0000, 0.0000],
        [0.1690, 0.1644, 0.1645, 0.1675, 0.1682, 0.0000],
        [0.1751, 0.1532, 0.1538, 0.1736, 0.1786, 0.1658]],
       grad_fn=<MulBackward0>)

row sums shape: torch.Size([6, 1])
tensor([[0.1698],
        [0.3314],
        [0.4919],
        [0.6583],
        [0.8336],
        [1.0000]], grad_fn=<SumBackward1>)


tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5179, 0.4821, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3485, 0.3255, 0.3260, 0.0000, 0.0000, 0.0000],
        [0.2636, 0.2372, 0.2379, 0.2612, 0.0000, 0.0000],
        [0.2027, 0.1972, 0.1973, 0.2010, 0.2018, 0.0000],
        [0.1751, 0.1532, 0.1538, 0.1736, 0.1786, 0.1658]],
       grad_fn=<DivBackward0>)

#### Using -inf (torch.triu)

In [83]:
# Upper-triangular mask with the diagonal excluded: 1 marks the future positions.
# NOTE(review): the recorded output of this cell is 3-D, i.e. it was produced after
# `attn_scores` had been rebound to the batched scores further down; restart the kernel
# and run all cells in order to refresh it.
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(mask)
# Put -inf at the masked positions so they receive zero weight after the softmax.
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[[-0.0696,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.0655, -0.0646,    -inf,    -inf,    -inf,    -inf],
         [-0.0593, -0.0528, -0.0505,    -inf,    -inf,    -inf],
         [-0.0476, -0.0653, -0.0640, -0.0372,    -inf,    -inf],
         [ 0.0680,  0.1742,  0.1745,  0.0978,  0.1304,    -inf],
         [-0.1090, -0.1804, -0.1783, -0.1022, -0.0891, -0.1296]],

        [[-0.0696,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.0655, -0.0646,    -inf,    -inf,    -inf,    -inf],
         [-0.0593, -0.0528, -0.0505,    -inf,    -inf,    -inf],
         [-0.0476, -0.0653, -0.0640, -0.0372,    -inf,    -inf],
         [ 0.0680,  0.1742,  0.1745,  0.0978,  0.1304,    -inf],
         [-0.1090, -0.1804, -0.1783, -0.1022, -0.0891, -0.1296]]],
       grad

In [28]:
# Normalise along the last dimension (over the keys) so every query row sums to 1.
# dim=-1 is used instead of dim=1: both agree for a 2-D score matrix, but with a leading
# batch dimension dim=1 would normalise over the query rows instead of the keys.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5179, 0.4821, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3485, 0.3255, 0.3260, 0.0000, 0.0000, 0.0000],
        [0.2636, 0.2372, 0.2379, 0.2612, 0.0000, 0.0000],
        [0.2027, 0.1972, 0.1973, 0.2010, 0.2018, 0.0000],
        [0.1751, 0.1532, 0.1538, 0.1736, 0.1786, 0.1658]],
       grad_fn=<SoftmaxBackward0>)


#### Dropout

In [29]:
torch.manual_seed(123)
# With p=0.5 the surviving entries are scaled by 1 / (1 - 0.5) = 2, as the all-ones example shows.
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))
# Re-seed so the same dropout pattern is applied to the attention weights.
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6971, 0.6509, 0.6520, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4745, 0.4758, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3944, 0.0000, 0.4019, 0.0000, 0.0000],
        [0.0000, 0.3064, 0.3075, 0.3473, 0.3571, 0.0000]],
       grad_fn=<MulBackward0>)


#### Batch Implementation

In [None]:
# Stack the example twice along a new leading dimension to form a batch of two.
print(f"Inputs shape: {inputs.shape}")
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

Inputs shape: torch.Size([6, 3])
torch.Size([2, 6, 3])


#### Putting all together

In [75]:
# Stack the same input twice to create a batch
batched_inputs = torch.stack((inputs, inputs), dim=0)
d_in = batched_inputs.shape[2]
d_out = 2
context_length = batched_inputs.shape[1]

print(f"d_in: {d_in} \n d_out: {d_out} \n context_length: {context_length}")

# Bias-free linear projections for queries, keys and values.
# NOTE(review): there is no torch.manual_seed call in this cell, so the initial weights
# depend on whatever RNG state earlier cells left behind.
w_query = nn.Linear(d_in, d_out, bias=False)
w_key = nn.Linear(d_in, d_out, bias=False)
w_value = nn.Linear(d_in, d_out, bias=False)
print(f"Dimensions of w metrices are: {w_query}")

# Dropout layer: during training each element is zeroed with probability 0.5
dropout = nn.Dropout(0.5)

# Causal-attention mask: 1 above the diagonal marks the future tokens, so that each token
# can only attend to itself and to earlier tokens
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
print(f"Shape of mask is : {mask.shape}")

d_in: 3 
 d_out: 2 
 context_length: 6
Dimensions of w metrices are: Linear(in_features=3, out_features=2, bias=False)
Shape of mask is : torch.Size([6, 6])


In [76]:
# The printed weight has shape (out_features, in_features) = (2, 3); bias is None because bias=False.
print("Weights of w_query:\n", w_query.weight)
print("Bias of w_query:\n", w_query.bias)

Weights of w_query:
 Parameter containing:
tensor([[-0.3976,  0.1673,  0.2912],
        [-0.3153,  0.0842, -0.2289]], requires_grad=True)
Bias of w_query:
 None


In [77]:
b, num_tokens, d_in = batched_inputs.shape

# Project the batched inputs into key, value and query space with the three linear layers.
# For a bias-free nn.Linear this is the same as multiplying by the transposed weight matrix.
keys = w_key(batched_inputs)
values = w_value(batched_inputs)
queries = w_query(batched_inputs)
# display((batched_inputs @ w_key.weight.T)==(w_key(batched_inputs))) # True

print(keys.shape)
print(values.shape)
print(queries.shape)

torch.Size([2, 6, 2])
torch.Size([2, 6, 2])
torch.Size([2, 6, 2])


In [78]:
# transpose(1, 2) swaps only the last two dimensions; the batch dimension stays first.
print(keys.shape)
print(keys.transpose(1,2).shape)

torch.Size([2, 6, 2])
torch.Size([2, 2, 6])


In [79]:
# mask.bool() marks with True the positions that will be filled with -torch.inf.
# The slice crops the mask to the current number of tokens.
mask.bool()[:num_tokens, :num_tokens]

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True],
        [False, False, False, False, False,  True],
        [False, False, False, False, False, False]])

In [80]:
# Note that only the 2nd and 3rd dimensions of the keys are transposed; the batch dimension is left in place
attn_scores = queries @ keys.transpose(1,2) # [2,6,2] @ [2, 2, 6] --> [2, 6, 6]
print(f"attn_scores shape: {attn_scores.shape}")

# In-place fill (trailing underscore); the (6, 6) mask broadcasts over the batch dimension
attn_scores.masked_fill_(mask.bool(), -torch.inf)
attn_scores

attn_scores shape: torch.Size([2, 6, 6])


tensor([[[-0.0696,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.0655, -0.0646,    -inf,    -inf,    -inf,    -inf],
         [-0.0593, -0.0528, -0.0505,    -inf,    -inf,    -inf],
         [-0.0476, -0.0653, -0.0640, -0.0372,    -inf,    -inf],
         [ 0.0680,  0.1742,  0.1745,  0.0978,  0.1304,    -inf],
         [-0.1090, -0.1804, -0.1783, -0.1022, -0.0891, -0.1296]],

        [[-0.0696,    -inf,    -inf,    -inf,    -inf,    -inf],
         [-0.0655, -0.0646,    -inf,    -inf,    -inf,    -inf],
         [-0.0593, -0.0528, -0.0505,    -inf,    -inf,    -inf],
         [-0.0476, -0.0653, -0.0640, -0.0372,    -inf,    -inf],
         [ 0.0680,  0.1742,  0.1745,  0.0978,  0.1304,    -inf],
         [-0.1090, -0.1804, -0.1783, -0.1022, -0.0891, -0.1296]]],
       grad_fn=<MaskedFillBackward0>)

In [81]:
# Softmax over dim=-1, i.e. across the keys of each query row, so every row sums to 1
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

display(attn_weights[0])
# Applying dropout: surviving weights are rescaled, so rows no longer sum to 1
attn_weights = dropout(attn_weights)
display(attn_weights[0])

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4998, 0.5002, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3321, 0.3337, 0.3342, 0.0000, 0.0000, 0.0000],
        [0.2510, 0.2479, 0.2481, 0.2529, 0.0000, 0.0000],
        [0.1915, 0.2064, 0.2064, 0.1956, 0.2001, 0.0000],
        [0.1693, 0.1609, 0.1612, 0.1701, 0.1717, 0.1668]],
       grad_fn=<SelectBackward0>)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0003, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6643, 0.6673, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5021, 0.4958, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4128, 0.0000, 0.3911, 0.0000, 0.0000],
        [0.0000, 0.3219, 0.0000, 0.0000, 0.0000, 0.3337]],
       grad_fn=<SelectBackward0>)

#### Causal Attention with dropout

In [84]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention with dropout on the attention weights.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors.
        context_length: Maximum sequence length; sets the size of the causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in: int, d_out: int, context_length: int, dropout: float, qkv_bias: bool = False):
        super().__init__()
        # Trainable projection matrices (creation order kept so seeded results are unchanged)
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Dropout on the attention weights; only active in training mode
        self.dropout = nn.Dropout(dropout)

        # Causal mask: 1 above the diagonal marks the future positions.
        # Registered as a buffer so it follows the module across .to(device) / .cuda();
        # persistent=False keeps it out of state_dict, so saved checkpoints are unaffected.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, input):
        """Return the causally masked context vectors.

        Args:
            input: Tensor of shape (batch, num_tokens, d_in) or (num_tokens, d_in),
                with num_tokens <= context_length.

        Returns:
            Tensor with the same leading dimensions as ``input`` and last dimension ``d_out``.
        """
        num_tokens = input.shape[-2]

        queries = self.W_query(input)
        keys = self.W_key(input)
        values = self.W_value(input)

        # Transpose only the last two dimensions of the keys so the batch dimension is preserved
        attn_scores = queries @ keys.transpose(-2, -1)

        # Crop the mask to the actual sequence length so inputs shorter than
        # context_length work, then fill the future positions with -inf in place
        # (a trailing underscore denotes an in-place operation in PyTorch).
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(causal_mask, -torch.inf)

        # Normalise over the keys (last dimension) so each query row sums to 1
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        # Randomly drop attention weights (training mode only)
        attn_weights = self.dropout(attn_weights)

        # Weighted sum of the value vectors
        context_vector = attn_weights @ values
        return context_vector

In [85]:
# Seed, build a batch of two identical sequences and run causal attention with dropout disabled (0.0).
torch.manual_seed(123)
print(f"Inputs shape: {inputs.shape}")
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

Inputs shape: torch.Size([6, 3])
torch.Size([2, 6, 3])
context_vecs.shape: torch.Size([2, 6, 2])


## Multi-Head Attention

### MHA Wrapper using self attention class

In [37]:
class SelfAttention_rev(nn.Module):
    """Causally masked scaled dot-product self-attention.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors.
        context_length: Maximum sequence length; sets the size of the causal mask.
        qkv_bias: Keyword-only. Whether the projections carry a bias term.
            Defaults to ``True``: the previous code passed the string
            ``'qkv_bias'``, which is truthy and therefore always enabled the
            bias, so this default keeps existing results unchanged.
    """

    def __init__(self, d_in, d_out, context_length, *, qkv_bias=True):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Causal mask: 1 above the diagonal marks the future positions. Registered as a
        # non-persistent buffer so it moves with the module but stays out of state_dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

        # NOTE(review): this layer is constructed but never applied in forward();
        # left that way so existing outputs stay deterministic -- confirm whether
        # dropout on the attention weights was intended here.
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return the causally masked context vectors.

        Args:
            x: Tensor of shape (num_tokens, d_in) or (batch, num_tokens, d_in),
                with num_tokens <= context_length.

        Returns:
            Tensor with the same leading dimensions as ``x`` and last dimension ``d_out``.
        """
        num_tokens = x.shape[-2]

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap only the last two dimensions so an optional batch dimension is preserved.
        attn_scores = queries @ keys.transpose(-2, -1)

        # Crop the mask to the actual sequence length before hiding the future positions.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        masked_attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by the square root of the key dimension and normalise over the keys.
        attn_weights = torch.softmax(masked_attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vectors = attn_weights @ values

        return context_vectors

In [None]:
# NOTE(review): this rebinds `sa_v2`, which earlier in the notebook held a
# SelfAttention_v2 instance; from here on it refers to a SelfAttention_rev module.
torch.manual_seed(789)
sa_v2 = SelfAttention_rev(d_in, d_out, inputs.shape[0])
print(sa_v2(inputs))

In [None]:
class SelfAttention_rev(nn.Module):
    """Batched causal self-attention with dropout on the attention weights.

    Args:
        d_in: Size of each input token embedding.
        d_out: Size of the projected query/key/value vectors.
        context_length: Maximum sequence length; sets the size of the causal mask.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.dropout = nn.Dropout(dropout)

        # Causal mask: 1 above the diagonal marks the future positions. Registered as a
        # non-persistent buffer so it moves with the module but stays out of state_dict.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        """Return the causally masked context vectors.

        Args:
            x: Tensor of shape (batch, num_tokens, d_in), with num_tokens <= context_length.

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        _, num_tokens, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)
        # Use the module's own mask (not a notebook-level global), cropped to the
        # actual sequence length, and fill the future positions with -inf in place.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        # Scale by the square root of the key dimension and normalise over the keys.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        # Apply the configured dropout (a no-op in eval mode or when dropout == 0.0).
        attn_weights = self.dropout(attn_weights)

        context_vectors = attn_weights @ values

        return context_vectors




### MHA optimized

tensor([[-0.1378, -0.4154],
        [-0.2589, -0.3110],
        [-0.2976, -0.2762],
        [-0.3150, -0.2884],
        [-0.3409, -0.2551],
        [-0.3426, -0.2714]], grad_fn=<MmBackward0>)
