## Basic Self Attention Mechanism
This notebook implements the most basic (single-head) self-attention mechanism.

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [13]:
class BasicSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are learned linear projections of the same
    input ``x``. Each output position is a weighted sum of the projected
    values, with weights ``softmax(Q K^T / sqrt(d_k))``.

    Args:
        d_model: Feature dimension of the input and output. Because this is
            single-head (not multi-head) attention, the key dimension d_k
            equals d_model.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).
            mask: Optional boolean tensor broadcastable to
                (batch_size, seq_len, seq_len). ``True`` marks key positions
                a query may attend to; ``False`` positions are excluded. For
                causal attention pass e.g.
                ``torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))``.
                A query row that is masked out entirely produces NaN, since a
                softmax over only ``-inf`` values is undefined.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model): the
            attention-weighted sum of the projected values.
        """
        queries = self.query_proj(x)  # (batch_size, seq_len, d_model)
        keys = self.key_proj(x)  # (batch_size, seq_len, d_model)
        values = self.value_proj(x)  # (batch_size, seq_len, d_model)

        # Scale by sqrt(d_k) so the logits' magnitude does not grow with the
        # key dimension and push the softmax into its saturated regime.
        d_k = keys.size(-1)
        logits = queries @ keys.transpose(-2, -1) / math.sqrt(d_k)  # (batch_size, seq_len, seq_len)
        if mask is not None:
            # Disallowed positions get -inf so they receive zero softmax weight.
            logits = logits.masked_fill(~mask.to(torch.bool), float("-inf"))
        weights = torch.softmax(logits, dim=-1)  # (batch_size, seq_len, seq_len)
        return weights @ values  # (batch_size, seq_len, d_model)


# Demo layer with a 10-dimensional model/feature size.
basicSelfAttention = BasicSelfAttention(10)



In [14]:
# Random demo batch: 3 sequences of length 5 with 10 features each.
input_sample = torch.randn(3, 5, 10)
out = basicSelfAttention(input_sample)