<a href="https://colab.research.google.com/github/svetaU/Attention/blob/main/attnClassifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import math
try:
  import einops
except ModuleNotFoundError: 
  !pip install --quiet einops
from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
#try:
#    import pytorch_lightning as pl
#except ModuleNotFoundError: 
#    !pip install --quiet pytorch-lightning>=1.5
#    import pytorch_lightning as pl

# Attention module

In [2]:
class NiptMultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single bias-free linear layer produces queries, keys and values for
    all heads at once; a second bias-free linear layer projects the
    concatenated head outputs back to ``embed_dim``.

    Args:
        embed_dim: Size of the last dimension of the input (and output).
        heads: Number of attention heads.
        dim_head: Width of each head. Defaults to ``embed_dim // heads``.
    """

    def __init__(self, embed_dim, heads=2, dim_head=None):
        super().__init__()
        self.dim_head = embed_dim // heads if dim_head is None else dim_head
        inner_dim = self.dim_head * heads
        self.heads = heads
        self.to_qvk = nn.Linear(embed_dim, inner_dim * 3, bias=False)
        # Project back to embed_dim so the output can be added to the input
        # in a residual connection, even when dim_head * heads != embed_dim.
        self.last_linear = nn.Linear(inner_dim, embed_dim, bias=False)
        self.scale_factor = self.dim_head ** -0.5

        self._init_weights()

    def _init_weights(self):
        """Initialise both projection matrices with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.to_qvk.weight)
        nn.init.xavier_uniform_(self.last_linear.weight)

    def forward(self, x, mask=None, return_attention=False):
        """Apply self-attention to ``x``.

        Args:
            x: Rank-3 tensor; the last dimension must be ``embed_dim``.
                The first two dimensions are treated as (batch, tokens).
            mask: Optional boolean tensor of shape ``(tokens, tokens)``.
                Positions where it is ``True`` are filled with ``-inf``
                before the softmax, i.e. they receive zero attention.
                A row that is entirely ``True`` yields NaNs.
            return_attention: When true, also return the attention weights.

        Returns:
            The projected output with the same shape as ``x``; or a tuple
            ``(output, attention)`` where ``attention`` has shape
            ``(batch, heads, tokens, tokens)`` if ``return_attention`` is set.
        """
        assert x.dim() == 3
        batch, tokens, _ = x.shape

        qkv = self.to_qvk(x)
        # The fused projection is laid out as (dim_head, 3, heads) along the
        # last axis; split it and move the q/k/v selector to the front so
        # each of q, k, v has shape (batch, heads, tokens, dim_head).
        qkv = qkv.reshape(batch, tokens, self.dim_head, 3, self.heads)
        q, k, v = qkv.permute(3, 0, 4, 1, 2).unbind(dim=0)

        scaled_dot_prod = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale_factor
        if mask is not None:
            assert mask.shape == scaled_dot_prod.shape[2:]
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask, float('-inf'))

        attention = torch.softmax(scaled_dot_prod, dim=-1)
        values = torch.einsum('bhij,bhjd->bhid', attention, v)
        # Concatenate the heads: (batch, heads, tokens, d) -> (batch, tokens, heads * d).
        values = values.permute(0, 2, 1, 3).reshape(
            batch, tokens, self.heads * self.dim_head
        )
        output = self.last_linear(values)
        if return_attention:
            return output, attention
        return output

# Encoder Block

In [3]:
class NiptEncoderBlock(nn.Module):
    """Post-norm encoder block: self-attention followed by a feed-forward net.

    Each sub-layer is wrapped in a residual connection and followed by a
    ``LayerNorm`` over the last dimension.

    Args:
        embed_dim: Size of the last dimension of the input (and output).
        heads: Number of attention heads, forwarded to the attention layer.
        dim_head: Per-head width, forwarded to the attention layer.
        dim_linear_block: Hidden width of the feed-forward network.
        dropout: Dropout probability applied to the attention output and
            inside the feed-forward network.
    """

    def __init__(self, embed_dim, heads=2, dim_head=None, dim_linear_block=1024, dropout=0.0):
        super().__init__()
        # NOTE: the attention layer is a regular constructor call; passing
        # ``self`` here would collide with the ``embed_dim`` keyword.
        self.attn_layer = NiptMultiHeadSelfAttention(
            embed_dim=embed_dim, heads=heads, dim_head=dim_head
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop = nn.Dropout(dropout)
        self.linear_net = nn.Sequential(
            nn.Linear(embed_dim, dim_linear_block),
            nn.Dropout(dropout),
            nn.ReLU(inplace=True),
            nn.Linear(dim_linear_block, embed_dim),
        )

    def forward(self, x, mask=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor whose last dimension is ``embed_dim``.
            mask: Optional attention mask, passed through unchanged to the
                attention layer.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Attention sub-layer: dropout -> residual add -> layer norm.
        y = self.norm1(self.drop(self.attn_layer(x, mask)) + x)
        # Feed-forward sub-layer: residual add -> layer norm.
        return self.norm2(self.linear_net(y) + y)

# Classifier module

In [4]:
class NiptAttentionClassifier(nn.Module):
    """Stack of encoder blocks topped with a two-way softmax head.

    Args:
        embed_dim: Size of the last dimension of the input.
        num_layers: Number of ``NiptEncoderBlock`` instances to stack.
        **block_args: Extra keyword arguments forwarded to every block.
    """

    def __init__(self, embed_dim, num_layers=6, **block_args):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(NiptEncoderBlock(embed_dim, **block_args))
        self.layers = nn.ModuleList(blocks)
        self.fc = nn.Linear(embed_dim, 2)

    def forward(self, x, mask=None):
        """Return per-position probabilities over the two classes.

        Args:
            x: Input tensor whose last dimension is ``embed_dim``.
            mask: Optional attention mask handed to every encoder block.

        Returns:
            A tensor shaped like ``x`` except that the last dimension has
            size 2 and sums to one.
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        logits = self.fc(hidden)
        return torch.softmax(logits, dim=-1)