In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from pathlib import Path

In [2]:
# Toy input: a (6, 3) float tensor. The attention cells below treat each row
# as one token and each column as one embedding feature.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],
        [0.55, 0.87, 0.66],
        [0.57, 0.85, 0.64],
        [0.22, 0.58, 0.33],
        [0.77, 0.25, 0.10],
        [0.05, 0.80, 0.55],
    ]
)

In [3]:
# Build a toy batch of two identical sequences by stacking `inputs` along a new
# leading axis: (6, 3) -> (2, 6, 3), read as (batch, num_tokens, d_in).
batch = torch.stack([inputs, inputs])  # torch.stack defaults to dim=0
print(batch)
print(batch.shape)

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [4]:
# NOTE(review): superseded first draft, kept for reference only; the working
# implementation is the `CausalAttention` class defined in a later cell.
# Known problems in this draft:
#   - `keys.T` only gives the intended (d_out, num_tokens) transpose for 2-D
#     input; for a batched 3-D tensor it reverses ALL dims. Use
#     `keys.transpose(-2, -1)` instead.
#   - `torch.triu(..., diagonal=1)` is 1 ABOVE the diagonal (future tokens), so
#     `attention_weights * mask` keeps the future positions and zeroes the
#     past ones, i.e. the mask is inverted.
#   - The mask is applied after softmax, so the rows no longer sum to 1.
#     Mask the scores with -inf before the softmax instead.
# class CausalAttention(nn.Module):
#     def __init__(self, d_in, d_out, qkv_bias, context_length, dropout) -> None:
#         super().__init__()
#         self.d_in = d_in
#         self.d_out = d_out
#         self.qkv_bias = qkv_bias
#         self.context_length = context_length
#         self.dropout = nn.Dropout(dropout)
#         self.wq = nn.Linear(self.d_in,self.d_out,bias=self.qkv_bias)
#         self.wk = nn.Linear(self.d_in,self.d_out,bias=self.qkv_bias)
#         self.wv = nn.Linear(self.d_in,self.d_out,bias=self.qkv_bias)

#     def forward(self, inputs):
#         query   = self.wq(inputs)
#         keys = self.wk(inputs)
#         values= self.wv(inputs)

#         attention_scores = query @ keys.T
#         d_k = keys.shape[-1]
#         attention_weights = torch.softmax(attention_scores/d_k **0.5,dim = -1)
#         mask = torch.triu(torch.ones(self.context_length, self.context_length),diagonal=1)
#         masked_attention_weights = attention_weights * mask
#         dropout_masked_weights = self.dropout(masked_attention_weights)
#         context_vector = dropout_masked_weights @ values
#         return context_vector

        
        


In [5]:
# Compare the `.T` shorthand with an explicit `transpose(0, 1)`: for a 2-D
# tensor both swap rows and columns, (6, 3) -> (3, 6).
inputs_t = inputs.transpose(0, 1)
for item in (inputs, inputs.shape, inputs.T.shape, inputs_t, inputs_t.shape):
    print(item)

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])
torch.Size([6, 3])
torch.Size([3, 6])
tensor([[0.4300, 0.5500, 0.5700, 0.2200, 0.7700, 0.0500],
        [0.1500, 0.8700, 0.8500, 0.5800, 0.2500, 0.8000],
        [0.8900, 0.6600, 0.6400, 0.3300, 0.1000, 0.5500]])
torch.Size([3, 6])


In [6]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position may only attend to itself and earlier positions: scores for
    later positions are set to -inf before the softmax, so they receive zero
    attention weight.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections (and of the output).
        context_length: maximum number of tokens per sequence; sets the size
            of the pre-built causal mask.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.context_length = context_length
        self.wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular ones (above the diagonal) mark the "future" positions
        # to hide. Registered as a buffer so it follows .to(device) and is saved
        # in state_dict, but is not a trainable parameter.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Return context vectors for ``x``.

        Args:
            x: tensor of shape (..., num_tokens, d_in); both batched
                (batch, num_tokens, d_in) and unbatched (num_tokens, d_in)
                inputs are accepted.

        Returns:
            Tensor of shape (..., num_tokens, d_out).

        Raises:
            ValueError: if num_tokens exceeds ``context_length``.
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        queries = self.wq(x)
        keys = self.wk(x)
        values = self.wv(x)

        # Swap only the last two axes so this works with or without a batch dim.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Shorter sequences use the top-left corner of the pre-built mask.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores = attention_scores.masked_fill(causal_mask, -torch.inf)
        # -inf stays -inf after scaling, so masked positions get zero weight.
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attention_weights = self.dropout(attention_weights)
        return attention_weights @ values

In [7]:
# Seed before building the module so the nn.Linear weight initialisation
# (and therefore the printed context vectors below) is reproducible.
torch.manual_seed(123)
d_in, d_out = 3, 2                 # input feature size, projection size
context_length = batch.shape[1]    # tokens per sequence in the toy batch
causal_attention = CausalAttention(
    d_in, d_out, context_length, dropout=0  # dropout=0 makes dropout a no-op
)

In [8]:
# Run both sequences through the attention layer; the result has shape
# (batch, num_tokens, d_out).
context_vectors = causal_attention(batch)
print(context_vectors.shape, context_vectors, sep="\n")

torch.Size([2, 6, 2])
tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)
