In [3]:
import torch
import torch.nn as nn
import math

________________________________________________________________

In [6]:
class SingleHeadAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to query, key and value vectors by one fused
    linear layer, attention is computed over the sequence dimension, and
    the result is projected back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        bias: Whether the two linear layers carry a bias term.
    """

    def __init__(self, hidden_size: int, bias: bool=True):
        # ``super()`` must be called to obtain the proxy object; calling
        # ``super.__init__()`` on the type itself raises TypeError and leaves
        # nn.Module uninitialised.
        super().__init__()

        # The head width is arbitrarily fixed to a quarter of the hidden size.
        # It is stored so that forward() splits with the same width the
        # layers were built with.
        self.head_size = hidden_size // 4

        # Fused linear layer producing query, key and value in one matmul;
        # the output width is head_size * 3 to merge Wq, Wk and Wv.
        self.Wqkv = nn.Linear(hidden_size, self.head_size * 3, bias=bias)

        # Projection layer mapping the attention output back to the
        # original token width.
        self.proj = nn.Linear(self.head_size, hidden_size, bias=bias)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``X``.

        Args:
            X: Tensor of shape ``(B, S, C)`` where B is the batch size,
                S the sequence length and C the input dimension
                (``hidden_size``).

        Returns:
            Tensor of shape ``(B, S, hidden_size)``.
        """
        # B = batch size, S = sequence length; the channel count is implied
        # by the layers, so only the leading dimensions are needed here.
        B, S, _ = X.shape

        # Split the fused projection into q, k, v, each (B, S, head_size).
        q, k, v = self.Wqkv(X).reshape(B, S, 3, self.head_size).unbind(dim=2)

        # Dot product of q with k transposed -> (B, S, S) scores.
        attn = q @ k.transpose(-2, -1)
        # Scale by sqrt(d_k) before the softmax.
        attn = attn / math.sqrt(k.size(-1))
        # Normalise each query's scores over the key positions.
        attn = attn.softmax(dim=-1)
        # Weighted sum of the values -> (B, S, head_size).
        attn = attn @ v

        return self.proj(attn)

_________________________________________