## Model layers: attention, normalization, and encoder building blocks

In [1]:
import torch
from torch import nn

In [3]:
class selfAttentionDot(nn.Module):
    """Scaled dot-product attention: softmax(q @ k^T / sqrt(d)) @ v.

    This module creates no learnable parameters. Any query/key/value
    projections must be applied by the caller before calling it.
    """
    def __init__(self):
        super(selfAttentionDot, self).__init__()

    def forward(self, q, k, v):
        """Attend from every query position over all key positions.

        Args:
            q (torch.Tensor): queries, shape (..., seq_q, d).
            k (torch.Tensor): keys, shape (..., seq_k, d).
            v (torch.Tensor): values, shape (..., seq_k, d_v).

        Returns:
            torch.Tensor: attention output, shape (..., seq_q, d_v).
        """
        depth = q.shape[-1]
        # (..., seq_q, d) @ (..., d, seq_k) -> (..., seq_q, seq_k).
        # `depth` is a Python int, so scale with `** 0.5`; torch.sqrt only
        # accepts tensors.
        scores = torch.matmul(q, k.transpose(-2, -1)) / (depth ** 0.5)
        # Normalize over the key axis so each query's weights sum to 1
        # (softmax takes a single int dim, not a tuple).
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, v)  # (..., seq_q, d_v)
        
        
        
        

In [90]:
class LayerNorm(nn.Module):
    """LayerNorm from https://arxiv.org/pdf/1607.06450.pdf

    Normalizes every position independently over its feature (last)
    dimension, then applies a learned per-feature gain and bias.

    Args:
        feat_dim (int): size of the last dimension of the inputs.
        eps (float): added to the variance before the square root so that a
            constant feature vector does not cause a division by zero.
    """
    def __init__(self, feat_dim, eps=1e-5):
        super(LayerNorm, self).__init__()

        # Per-feature affine parameters, broadcast over all leading dims.
        self.bias = nn.Parameter(torch.zeros(feat_dim))
        self.gain = nn.Parameter(torch.ones(feat_dim))

        self.input_shape = feat_dim
        self.eps = eps

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): tensor of shape (..., feat_dim), e.g.
                (batch, seq, feat_dim)

        Returns:
            torch.Tensor: layer-normalized output, same shape as ``inputs``
        """
        # Statistics are taken per position over the feature axis. keepdim
        # lets them broadcast back against `inputs`; without it, a reduced
        # tensor would be aligned with the feature axis instead.
        mean = inputs.mean(dim=-1, keepdim=True)
        centered = inputs - mean
        var = centered.pow(2).mean(dim=-1, keepdim=True)  # biased variance
        norm = centered / torch.sqrt(var + self.eps)
        return self.gain * norm + self.bias
        
        
        

In [None]:
class TransformerEncoder(nn.Module):
    """Single Transformer encoder block with post-residual normalization.

    The block applies single-head dot-product self-attention, a residual
    connection and layer normalization. It then applies a position-wise
    linear feed-forward layer, a second residual connection and layer
    normalization.

    Args:
        size (int): feature dimension of the inputs. The query/key/value
            projections and the feed-forward layer all map size -> size.
    """
    def __init__(self, size):
        super(TransformerEncoder, self).__init__()

        self.keys = nn.Linear(size, size)
        self.query = nn.Linear(size, size)
        self.values = nn.Linear(size, size)

        # NOTE(review): this is a single Linear with no nonlinearity. The usual
        # Transformer feed-forward is Linear -> activation -> Linear. It is
        # kept as-is to preserve parameter names and behavior; confirm intent.
        self.feedForward = nn.Linear(size, size)

        # selfAttentionDot's constructor takes no arguments.
        self.selfAttention = selfAttentionDot()

        self.layerNorm1 = LayerNorm(size)
        self.layerNorm2 = LayerNorm(size)

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): tensor of shape (batch, seq, size)

        Returns:
            torch.Tensor: encoded output, same shape as ``inputs``
        """
        # nn.Linear operates on the last dimension of any-rank input, so no
        # flatten/reshape is needed around the projections.
        q = self.query(inputs)
        k = self.keys(inputs)
        v = self.values(inputs)

        # Pass arguments in the attention module's declared (q, k, v) order.
        attention = self.selfAttention(q, k, v)
        res = self.layerNorm1(inputs + attention)

        out = self.layerNorm2(res + self.feedForward(res))
        return out
                