In [41]:
import torch
# Toy embeddings: one 3-dimensional vector per token of the sentence
# "Your journey starts with one step". Row i holds x^(i+1).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

In [42]:
# Unnormalized attention scores for the 2nd token: dot product of the
# query x^2 with every input embedding (including itself).
query = inputs[1]
attn_scores_2 = torch.empty(inputs.shape[0])
for i in range(inputs.shape[0]):
    attn_scores_2[i] = query.dot(inputs[i])
print(attn_scores_2)

# The loop is just a vector-matrix product: x^2 @ inputs^T
t = inputs.transpose(0, 1)
inputs[1] @ t

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [43]:
# Simple normalization: divide every score by the total so the weights sum to 1
attn_weights_2_tmp = attn_scores_2 / torch.sum(attn_scores_2)
print(attn_weights_2_tmp.sum())

# Softmax is the more common normalization: it keeps every weight positive
# and makes the weights sum to 1.
def softmax_naive(x):
    """Softmax along dim 0, computed literally as exp(x) / sum(exp(x)).

    Illustrative only: exp() overflows to inf for large inputs (roughly
    x > 88 in float32), which turns the result into NaN.
    """
    exps = torch.exp(x)
    return exps / exps.sum(dim=0)
attn_weights_2_naive = softmax_naive(attn_scores_2)
print(attn_weights_2_naive)
print(attn_weights_2_naive.sum())

# In practice prefer PyTorch's built-in softmax: it is optimized and does not
# overflow on large scores the way the naive exp/sum version can.
attn_weights_2 = attn_scores_2.softmax(dim=0)

tensor(1.0000)
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
tensor(1.)


In [44]:
# Context vector z^(2): the attention-weighted sum of all input vectors,
# using the 2nd token (x^2) as the query.
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i in range(inputs.shape[0]):
    context_vec_2 += attn_weights_2[i] * inputs[i]
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [45]:
# Now follow the same steps for every input embedding rather than just one:
# build the full (num_tokens x num_tokens) matrix of pairwise dot-product scores.
input_length = inputs.shape[0]
embedding_length = inputs.shape[1]

# Explicit double loop: attn_scores_loop[i, j] = <x^(i+1), x^(j+1)>
attn_scores_loop = torch.empty(input_length, input_length)
for i in range(input_length):
    for j in range(input_length):
        attn_scores_loop[i, j] = torch.dot(inputs[i], inputs[j])

# The same result as a single matrix multiplication
attn_scores = inputs @ inputs.T

# Previously the loop result was silently overwritten; verify the two agree.
assert torch.allclose(attn_scores_loop, attn_scores)

In [46]:
# Normalize each row of scores into attention weights (each row sums to 1)
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights)
print(inputs)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])


In [47]:
# Row i of the result is the attention-weighted sum of all inputs for query i
all_context_vecs = torch.matmul(attn_weights, inputs)
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [48]:
# TRAINABLE ATTENTION ON SINGLE INPUT
# Compute only z^(2), the context vector for the 2nd input, using
# query/key/value projection matrices of shape (d_in, d_out).
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

# Seed so the random projection matrices are reproducible; gradient tracking
# is switched off for these demo weights. Created in query, key, value order.
torch.manual_seed(123)
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)



In [49]:
# Project x^2 into query, key and value space
query_2 = torch.matmul(x_2, W_query)
key_2 = torch.matmul(x_2, W_key)
value_2 = torch.matmul(x_2, W_value)
print(query_2)
W_key


tensor([0.4306, 1.4551])


Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])

In [50]:
# Even z^(2) alone needs the keys and values of *every* input, not just x^2
keys, values = inputs @ W_key, inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [51]:
# Score of query 2 against its own key (omega_22)
keys_2 = keys[1]
print(keys_2)
attn_score_22 = torch.dot(query_2, keys_2)
attn_score_22

tensor([0.4433, 1.1419])


tensor(1.8524)

In [52]:
# Scores of query 2 against every input: omega_2j = q^(2) . k^(j), where
# q^(2) = x^(2) @ W_query and k^(j) = x^(j) @ W_key is the key vector of input j.
attn_scores_2 = torch.matmul(query_2, keys.transpose(0, 1))
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])

In [53]:
# Convert attention scores to attention weights with softmax, first dividing
# by sqrt(d_k), where d_k is the key dimension (scaled dot-product attention).
d_k = keys.shape[-1]
attn_weights_2 = attn_scores_2.div(d_k ** 0.5).softmax(dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


In [54]:
# z^(2): the attention-weighted sum of the value vectors
context_vec_2 = torch.matmul(attn_weights_2, values)
context_vec_2

tensor([0.3061, 0.8210])

In [55]:
# Query vectors for all inputs at once (row i is q^(i+1))
queries = torch.matmul(inputs, W_query)
queries

tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])

In [56]:
import torch.nn as nn   
class SelfAttention_v1(nn.Module):
    """Scaled dot-product self-attention with raw ``nn.Parameter`` projections.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
        qkv_bias: accepted for signature parity with ``SelfAttention_v2`` but
            ignored -- these projections never include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # torch.rand initialisation; shape (d_in, d_out) so that ``x @ W``
        # projects the last dimension of x. Created in query, key, value order.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return context vectors for ``x`` of shape (..., num_tokens, d_in).

        The result has shape (..., num_tokens, d_out); any leading dimensions
        are treated as batch dimensions.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # omega: every query against every key. Transpose only the last two
        # dims -- ``.T`` reverses *all* dims and breaks batched (3-D) input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [57]:
# Seed before construction so the torch.rand weight initialisation is reproducible
torch.manual_seed(789)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[0.7584, 0.9454],
        [0.7860, 0.9871],
        [0.7858, 0.9868],
        [0.7589, 0.9410],
        [0.7648, 0.9542],
        [0.7631, 0.9471]], grad_fn=<MmBackward0>)


In [58]:
import torch.nn as nn   
class SelfAttention_v2(nn.Module):
    """Scaled dot-product self-attention using ``nn.Linear`` projections.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
        qkv_bias: whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Created in query, key, value order (matters for seeded initialisation).
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return context vectors for ``x`` of shape (..., num_tokens, d_in).

        The result has shape (..., num_tokens, d_out); any leading dimensions
        are treated as batch dimensions.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        # omega: every query against every key. Transpose only the last two
        # dims -- ``.T`` reverses *all* dims and breaks batched (3-D) input.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [59]:
# Same seed as for v1, but nn.Linear initialises its weights differently,
# so both the outputs and the parameter values differ from SelfAttention_v1.
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in=d_in, d_out=d_out)
print(sa_v2(inputs))
print([*sa_v2.parameters()])
print([*sa_v1.parameters()])


tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
[Parameter containing:
tensor([[ 0.3161,  0.4568,  0.5118],
        [-0.1683, -0.3379, -0.0918]], requires_grad=True), Parameter containing:
tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]], requires_grad=True), Parameter containing:
tensor([[ 0.2526, -0.1415, -0.1962],
        [ 0.5191, -0.0852, -0.2043]], requires_grad=True)]
[Parameter containing:
tensor([[0.7737, 0.8956],
        [0.9433, 0.3543],
        [0.2074, 0.4205]], requires_grad=True), Parameter containing:
tensor([[0.8514, 0.0926],
        [0.7051, 0.6848],
        [0.2748, 0.0579]], requires_grad=True), Parameter containing:
tensor([[0.7187, 0.3775],
        [0.3301, 0.9496],
        [0.4262, 0.3230]], requires_grad=True)]


In [60]:
# nn.Linear initialises weights differently from torch.rand, so the two classes
# differ out of the box. Copying v2's weights into v1 shows that the two
# implementations compute the same function. nn.Linear stores its weight as
# (d_out, d_in), so it is transposed to match v1's (d_in, d_out) parameters.
torch.manual_seed(123)  # make this cell reproducible on Restart & Run All
sa1 = SelfAttention_v1(d_in, d_out)
sa2 = SelfAttention_v2(d_in, d_out)
# copy_ under no_grad writes into sa1's own storage. The previous `.data = t(...)`
# made sa1 share memory with sa2's weights through a transposed view.
with torch.no_grad():
    sa1.W_query.copy_(sa2.W_query.weight.T)
    sa1.W_value.copy_(sa2.W_value.weight.T)
    sa1.W_key.copy_(sa2.W_key.weight.T)
print(sa1(inputs))
print(sa2(inputs))
assert torch.allclose(sa1(inputs), sa2(inputs))


tensor([[0.1839, 0.0178],
        [0.1815, 0.0205],
        [0.1818, 0.0202],
        [0.1826, 0.0191],
        [0.1875, 0.0144],
        [0.1799, 0.0218]], grad_fn=<MmBackward0>)
tensor([[0.1839, 0.0178],
        [0.1815, 0.0205],
        [0.1818, 0.0202],
        [0.1826, 0.0191],
        [0.1875, 0.0144],
        [0.1799, 0.0218]], grad_fn=<MmBackward0>)


In [61]:
# Masked (causal) attention: each token may only attend to itself and earlier
# tokens, i.e. the attention-weight matrix must be lower triangular.

# First compute the attention weights normally
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)

# Zero out everything above the diagonal; torch.tril does this directly...
mask = torch.tril(torch.ones(attn_weights.shape))
masked_simple = torch.tril(attn_weights)
# ...and multiplying by a lower-triangular mask of ones is equivalent.
assert torch.equal(masked_simple, attn_weights * mask)

# Renormalise so each row sums to 1 again
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

In [62]:
# Shortcut: mask the *scores* with -inf before softmax. Because exp(-inf) == 0,
# softmax gives those positions zero weight and renormalises the rest itself.
# torch.triu(..., diagonal=1) marks the strictly upper-triangular (future)
# positions; masked_fill writes the given value wherever the mask is True.
mask = torch.ones(attn_weights.shape).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [63]:
# Dropout randomly zeroes attention weights to reduce overfitting. With p=0.5,
# surviving entries are scaled by 1 / (1 - 0.5) = 2 (see the 2.0 in row 0).
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
dropout(attn_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)

In [64]:
# Add causal attention to the previous implementation.
# Also accept batches of input (any leading batch dims, e.g. 3-D (b, tokens, d_in)).
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention with dropout on the weights.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value projections and of the output.
        context_length: maximum number of tokens supported; sizes the mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.context_length = context_length
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark future positions to hide. A buffer
        # (not a Parameter) so it follows the module across devices but is not trained.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Return context vectors for ``x`` of shape (..., num_tokens, d_in).

        The result has shape (..., num_tokens, d_out).

        Raises:
            ValueError: if ``x`` has fewer than 2 dims, or more tokens than
                ``context_length``.
        """
        if x.dim() < 2:
            raise ValueError(f"expected input with at least 2 dims, got shape {tuple(x.shape)}")
        num_tokens = x.shape[-2]
        if num_tokens > self.context_length:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {self.context_length}"
            )
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.transpose(-2, -1)
        # Use the top-left slice of the mask so shorter sequences also work
        attn_scores = attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec

In [65]:
# Build a batch of two identical sequences: shape (2, num_tokens, d_in)
torch.manual_seed(123)
batch = torch.stack([inputs, inputs], dim=0)
context_length = batch.shape[1]
print(d_in, d_out, context_length)
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

3 2 6
context_vecs.shape: torch.Size([2, 6, 2])


In [26]:
# Multi-head attention
# Most basic form: run CausalAttention several times and concatenate the results.
class MultiHeadAttentionWrapper(nn.Module):
    """Naive multi-head attention built from independent CausalAttention heads.

    Each of the ``num_heads`` heads maps (..., num_tokens, d_in) to
    (..., num_tokens, d_out); the head outputs are concatenated along the last
    dimension, giving (..., num_tokens, num_heads * d_out).
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(CausalAttention(d_in, d_out, context_length, dropout, qkv_bias))
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        outputs = [head(x) for head in self.heads]
        return torch.cat(outputs, dim=-1)

In [66]:
# More efficient: compute all heads in parallel from one set of projections.
class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention with all heads computed in parallel.

    The d_out-wide query/key/value projections are split into ``num_heads``
    heads of size ``d_out // num_heads``. Each head attends independently, and
    the results are concatenated and mixed by ``out_proj``.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: total output size; must be divisible by ``num_heads``.
        context_length: maximum number of tokens supported; sizes the mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value ``nn.Linear`` layers have a bias.
    """

    def __init__(self, d_in, d_out,
                 context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # width of each head's slice of d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # mixes the concatenated head outputs
        self.dropout = nn.Dropout(dropout)
        # 1s strictly above the diagonal mark future positions to hide
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Return context vectors for ``x`` of shape (b, num_tokens, d_in).

        The result has shape (b, num_tokens, d_out).

        Raises:
            ValueError: if ``num_tokens`` exceeds ``context_length``.
        """
        b, num_tokens, _ = x.shape
        if num_tokens > self.mask.shape[0]:
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {self.mask.shape[0]}"
            )
        keys = self.W_key(x)        # (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Split d_out into heads: (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # Move heads ahead of tokens: (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Per-head scores: (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Merge heads back into d_out; contiguous() is needed because view()
        # cannot be applied to the transposed (non-contiguous) tensor.
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec


In [39]:
# Run the parallel multi-head module on the two-sequence batch
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
print("batch shape", batch.shape)
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, dropout=0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


batch shape torch.Size([2, 6, 3])
v shape torch.Size([2, 6, 2])
keys shape torch.Size([2, 6, 2, 1])
attn_scores shape torch.Size([2, 2, 6, 6])
tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
