In [1]:
import torch
import torch.nn.functional as F

In [2]:
# torch.device("cuda:0"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Simple Self Attention without trainable weights

In [3]:
# input embeddings
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
).to(device=device)

In [4]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]], device='cuda:0')

In [5]:
input_query = inputs[1]
input_query

tensor([0.5500, 0.8700, 0.6600], device='cuda:0')

In [6]:
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900], device='cuda:0')

In [7]:
torch.dot(input_query, input_1)

tensor(0.9544, device='cuda:0')

In [8]:
inputs.shape

torch.Size([6, 3])

In [9]:
input_query.shape

torch.Size([3])

In [10]:
input_query.unsqueeze(1).shape

torch.Size([3, 1])

### Attention scores
- dot product each token with the other tokens to get the similarity score (attention score)
- norm the scores using softmax

In [11]:
attention_scores = torch.matmul(inputs, input_query.unsqueeze(1))
print(attention_scores)
print(f'attention scores shape: {attention_scores.shape}')
attention_scores = attention_scores.squeeze(1)
print(f'attention scores unsqueezed shape: {attention_scores.shape}')
print(attention_scores)

tensor([[0.9544],
        [1.4950],
        [1.4754],
        [0.8434],
        [0.7070],
        [1.0865]], device='cuda:0')
attention scores shape: torch.Size([6, 1])
attention scores unsqueezed shape: torch.Size([6])
tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865], device='cuda:0')


In [12]:
norm_attention_scores = F.softmax(attention_scores, dim=0)
print(norm_attention_scores)
print(f'shape of norm attention scores {norm_attention_scores.shape}')
print(f'sum of norm attention scores: {torch.sum(norm_attention_scores)}')

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581], device='cuda:0')
shape of norm attention scores torch.Size([6])
sum of norm attention scores: 0.9999999403953552


### Context Vector
- weighted sum of each input * attention score

In [13]:
print(f'\ninputs.shape {inputs.shape}')
print(inputs)
print(f'\nnorm_attention_scores.shape {norm_attention_scores.shape}')
print(norm_attention_scores)

print(f'\nnorm_attention_scores.shape after unsqueeze {norm_attention_scores.unsqueeze(-1).shape}')

element_wise_score_multiplied_inputs = (norm_attention_scores.unsqueeze(-1) * inputs)
print(f'\nelement wise broadcast of attention weights of size {norm_attention_scores.unsqueeze(-1).shape} on inputs.shape {inputs.shape} = shape {element_wise_score_multiplied_inputs.shape}')
print(element_wise_score_multiplied_inputs)

print(f'\ncontext vector - weighted sum of inputs over the normed attention scores. Same size as the input vector size')
context_vector = element_wise_score_multiplied_inputs.sum(dim=0)
print(context_vector)


inputs.shape torch.Size([6, 3])
tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]], device='cuda:0')

norm_attention_scores.shape torch.Size([6])
tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581], device='cuda:0')

norm_attention_scores.shape after unsqueeze torch.Size([6, 1])

element wise broadcast of attention weights of size torch.Size([6, 1]) on inputs.shape torch.Size([6, 3]) = shape torch.Size([6, 3])
tensor([[0.0596, 0.0208, 0.1233],
        [0.1308, 0.2070, 0.1570],
        [0.1330, 0.1983, 0.1493],
        [0.0273, 0.0719, 0.0409],
        [0.0833, 0.0270, 0.0108],
        [0.0079, 0.1265, 0.0870]], device='cuda:0')

context vector - weighted sum of inputs over the normed attention scores. Same size as the input vector size
tensor([0.4419, 0.6515, 0.5683], device='cuda:0')


In [15]:
print(f'inputs.shape {inputs.shape}')
# print(f'inputs.shape.transpose {torch.transpose(inputs, inputs.shape[1], inputs.shape[0]).shape}')

inputs.shape torch.Size([6, 3])


In [18]:
# torch.matmul(inputs, input_query.unsqueeze(1))
input_query.unsqueeze(1).shape

torch.Size([3, 1])

In [19]:
Q, K, V = inputs.unsqueeze(0), inputs.unsqueeze(0), inputs.unsqueeze(0)
print(f'\n(Q) {Q.shape}')
Kt = torch.transpose(K, -2, -1)
print(f'\n(Kt) {K.shape}')
QKt = torch.matmul(Q, Kt)
print(f'\nMatmul taken to get the dot product (QKt) {QKt.shape}')
print(QKt)
Softmax_QKt = F.softmax(QKt, dim=-2)
print(f'\nSoftmax_QKt in mid dimension {Softmax_QKt.shape}')
print(Softmax_QKt)

context_vectors = torch.matmul(Softmax_QKt, V)
print(f'\nContext vectors shape {context_vectors.shape}, one for each time step and also, of the same dim as input token embeddings')
print(context_vectors)



(Q) torch.Size([1, 6, 3])

(Kt) torch.Size([1, 6, 3])

Matmul taken to get the dot product (QKt) torch.Size([1, 6, 6])
tensor([[[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
         [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
         [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
         [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
         [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
         [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]]], device='cuda:0')

Softmax_QKt in mid dimension torch.Size([1, 6, 6])
tensor([[[0.2098, 0.1385, 0.1390, 0.1435, 0.1526, 0.1385],
         [0.2006, 0.2379, 0.2369, 0.2074, 0.1958, 0.2184],
         [0.1981, 0.2333, 0.2326, 0.2046, 0.1975, 0.2128],
         [0.1242, 0.1240, 0.1242, 0.1462, 0.1367, 0.1420],
         [0.1220, 0.1082, 0.1108, 0.1263, 0.1879, 0.0988],
         [0.1452, 0.1581, 0.1565, 0.1720, 0.1295, 0.1896]]], device='cuda:0')

Context vectors shape torch.Size([1, 6, 3]), one for each time step and also, of 

In [16]:
def self_attention(inputs):
    Q = inputs
    # first dim should be the batch dimension
    if len(inputs.shape) == 2:
        Q = inputs.unsqueeze(0) # (B, C, D)
    K, V = Q,Q # (B, C, D)
    QKt = torch.matmul(Q, torch.transpose(K, -2, -1)) # (B, C, C)
    soft_QKt = F.softmax(QKt, dim=-1) # (B, C, C)

    context_vectors = torch.matmul(soft_QKt, V) # (B, C, D)

    return context_vectors


In [17]:
self_attention(inputs)

tensor([[[0.4421, 0.5931, 0.5790],
         [0.4419, 0.6515, 0.5683],
         [0.4431, 0.6496, 0.5671],
         [0.4304, 0.6298, 0.5510],
         [0.4671, 0.5910, 0.5266],
         [0.4177, 0.6503, 0.5645]]], device='cuda:0')

# Simple Self Attention with trainable weights

In [51]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

In [52]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))
W_key = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))
W_value = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))

print(f'\nShape of W query, key, value => (d_embedding, d_out) {W_query.shape}')


Shape of W query, key, value => (d_embedding, d_out) torch.Size([3, 2])


In [64]:
# Add a batch dimension
inputs = inputs.unsqueeze(0)

In [70]:
K.shape
Q.shape

torch.Size([1, 6, 2])

In [74]:
Q = inputs @ W_query
K = inputs @ W_key
V = inputs @ W_value
print(f'\nmatmul inputs of shape {inputs.shape} with Q,K,V weight matrices of shape {W_query.shape} resulting in tensors of shapes {Q.shape}')
print(Q)

QKt = Q @ K.transpose(-2, -1)
soft_QKt = F.softmax(QKt / K.shape[-1]**0.5, dim=-1)
print(f'\nmatmul QKt with shape {QKt.shape} (Scalar Attention scores)')
print(QKt)
print(f'\nApplied softmax')
print(soft_QKt)

QKtV = soft_QKt @ V
print(f'\nmatmul normed attention scores {soft_QKt.shape} with V {V.shape} to get context_vectors, tensor of shape {QKtV.shape}')
print(QKtV)


matmul inputs of shape torch.Size([1, 6, 3]) with Q,K,V weight matrices of shape torch.Size([3, 2]) resulting in tensors of shapes torch.Size([1, 6, 2])
tensor([[[0.0689, 1.1501],
         [0.2340, 1.2777],
         [0.2308, 1.2754],
         [0.1453, 0.6264],
         [0.1075, 0.8745],
         [0.1842, 0.6994]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)

matmul QKt with shape torch.Size([1, 6, 6]) (Scalar Attention scores)
tensor([[[0.8082, 1.1077, 1.0858, 0.6298, 0.3861, 0.8811],
         [1.0512, 1.4657, 1.4388, 0.8304, 0.5491, 1.1434],
         [1.0465, 1.4589, 1.4320, 0.8266, 0.5459, 1.1384],
         [0.5451, 0.7642, 0.7505, 0.4325, 0.2925, 0.5925],
         [0.6682, 0.9246, 0.9070, 0.5246, 0.3356, 0.7276],
         [0.6301, 0.8861, 0.8705, 0.5012, 0.3434, 0.6845]]], device='cuda:0',
       grad_fn=<UnsafeViewBackward0>)

Applied softmax
tensor([[[0.1632, 0.2016, 0.1985, 0.1438, 0.1211, 0.1718],
         [0.1592, 0.2135, 0.2095, 0.1362, 0.1116, 0.1700],
         [0.1593,

In [77]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out, device=device))

    def forward(self, x):

        if len(x.shape) == 2:
            x = x.squeeze(0)
        
        Q = x @ self.W_query
        K = x @ self.W_key
        V = x @ self.W_value
        
        QKt = Q @ K.transpose(-2,-1)
        soft_QKt = F.softmax(QKt / K.shape[-1]**0.5, dim=-1)
        
        QKtV = soft_QKt @ V
        return QKtV

In [79]:
torch.manual_seed(123)
self_attention_layer1 = SelfAttention(inputs.shape[-1], 2)
print(self_attention_layer1(inputs))

tensor([[[0.7488, 0.3065],
         [0.7599, 0.3101],
         [0.7597, 0.3101],
         [0.7365, 0.3013],
         [0.7424, 0.3038],
         [0.7405, 0.3027]]], device='cuda:0', grad_fn=<UnsafeViewBackward0>)
