In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
mask=torch.full([10,10],float("-inf"))

In [11]:
def self_attention(q,k,v,mask=None):
    d_k=q.size()[-1]
    scaled=torch.matmul(q,k.transpose(-1,-2))/math.sqrt(d_k)
    if mask is not None:
        scaled+=mask
    attention=F.softmax(scaled)
    out=torch.matmul(attention,v)
    return attention,out

class MultiheadAttention(nn.Module):
    def __init__(self,input_dim,d_model,num_heads):
        super().__init__()
        self.d_model=d_model
        self.num_heads=num_heads
        self.input_dim=input_dim
        self.head_dim=d_model//num_heads
        self.qkv_layer=nn.Linear(input_dim,3*d_model)
        self.linear=nn.Linear(d_model,d_model)
    def forward(self,x,mask=None):
        batch_size,sequence_length,input_dim=x.size()
        qkv=self.qkv_layer(x)
        qkv=qkv.reshape(batch_size,sequence_length,self.num_heads,3*self.head_dim)
        qkv=qkv.permute(0,2,1,3)
        q,k,v=qkv.chunk(3,dim=-1)
        attention,values=self_attention(q,k,v,mask)
        values=values.reshape(batch_size,sequence_length,self.num_heads*self.head_dim)
        return self.linear(values)

In [14]:
class PositionalEncoding(nn.Module):
    def __init__(self,max_sequence_length,d_model):
        super().__init__()
        self.max_sequence_length=max_sequence_length
        self.d_model=d_model
    def forward(self):
        even_i=torch.arange(0,self.max_sequence_length,2).float()
        denominator=torch.pow(10000,even_i/self.d_model).float()
        pos=torch.arange(self.max_sequence_length,dtype=torch.float).reshape(self.max_sequence_length,1)
        even_PE=torch.sin(pos/denominator)
        odd_PE=torch.cos(pos/denominator)
        stacked=torch.stack([even_PE,odd_PE],dim=2)
        return torch.flatten(stacked,start_dim=1,end_dim=1)


In [None]:
class LayerNormalization(nn.Module):
    def __init__(self,parameter_shape,eps=1e-5):
        super().__init__()
        self.parameter_shape=parameter_shape
        self.eps=eps
        self.gamma=nn.Parameter(torch.ones(parameter_shape))
        self.beta=nn.Paramtere(torch.zeros(parameter_shape))
    def forward(self,inputs):
        dims=[-(i+1) for i in range(len(self.parameter_shape))]
        mean=inputs.mean(dims=dims,keepdim=True)
        var=((inputs-mean)**2).mean(dim=dims,keepdim=True)
        std=(var+self.eps).sqrt()
        y=(inputs-mean)/std
        return self.gamma*y+self.beta

In [12]:
input_dim=512
d_model=512
num_heads=8

batch_size=30
sequence_length=5
x=torch.randn((batch_size,sequence_length,input_dim))
model=MultiheadAttention(input_dim,d_model,num_heads)
out=model.forward(x)

  attention=F.softmax(scaled)
