In [37]:
"""
Imports Here
"""
import numpy as np
import torch
import torch.nn as nn

In [38]:
class Attention(nn.Module):
    """Self-attention wrapper around nn.MultiheadAttention.

    The input is passed through three bias-free linear layers (Q, K, V) and the
    results are fed to nn.MultiheadAttention as query, key and value.

    NOTE(review): nn.MultiheadAttention already applies its own learned
    query/key/value input projections, so Q/K/V here are an extra linear stage
    in front of them. They are kept so existing parameter names and saved
    checkpoints still line up.

    Args:
        heads: number of attention heads; nn.MultiheadAttention requires it to
            divide EMBED_DIMENSION.
        EMBED_DIMENSION: size of the last (feature) dimension of the input.
    """

    def __init__(self, heads, EMBED_DIMENSION):
        super().__init__()
        self.heads = heads
        self.attn = nn.MultiheadAttention(EMBED_DIMENSION, heads, batch_first=True)
        self.Q = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.K = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.V = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)

    def forward(self, x):
        """Return the attention output for x, with the same shape as x.

        For batched input x is (batch, seq, EMBED_DIMENSION), because the
        attention layer is built with batch_first=True. Only the output tensor
        is returned; attention weights are not computed.
        """
        # need_weights=False skips building the head-averaged attention map,
        # which was previously computed and then discarded.
        attnout, _ = self.attn(self.Q(x), self.K(x), self.V(x), need_weights=False)
        return attnout

In [43]:
class TransBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block: attention then MLP, each with a residual.

    Args:
        heads: attention heads passed to Attention.
        EMBED_DIMENSION: token feature size.
        fflsize: hidden width of the feed-forward (MLP) sub-layer.
    """

    def __init__(self, heads, EMBED_DIMENSION, fflsize):
        super().__init__()
        self.fnorm = nn.LayerNorm(EMBED_DIMENSION)  # applied before attention
        self.snorm = nn.LayerNorm(EMBED_DIMENSION)  # applied before the MLP
        self.attn = Attention(heads, EMBED_DIMENSION)
        self.ffl = nn.Sequential(
            nn.Linear(EMBED_DIMENSION, fflsize),
            nn.GELU(),
            nn.Linear(fflsize, EMBED_DIMENSION)
        )

    def forward(self, x):
        """Compute x + Attn(LN(x)), then x + MLP(LN(x)).

        LayerNorm is applied at the input of each sub-layer (pre-LN), as in the
        ViT paper and GPT-2. The original Transformer normalised after the
        residual add (post-LN). Pre-LN is reported to give better-behaved
        gradients and more stable training.
        """
        # self.attn returns the attention output tensor directly. Indexing it
        # with [0] would keep only the first sample of the batch and broadcast
        # that onto every sample in the residual add.
        x = x + self.attn(self.fnorm(x))
        x = x + self.ffl(self.snorm(x))
        return x

In [None]:
"""
Inception module for efficient 7x7 convolution
"""
class Inception(nn.Module):
    def __init__(self, dimin, dimout):
        super().__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dimin, dimout[0], 1, stride=(1,1)),
            nn.Conv2d(dimout[0], dimout[0], 3, stride=(1,1), padding=1),
            nn.Conv2d(dimout[0], dimout[0], 3, stride=(1,1), padding=1)
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dimin, dimout[1]), 1, stride=(1,1),
            nn.Conv2d(dimout[1], dimout[1], 3, stride=(1,1), padding=1)
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=(1,1), padding=1),
            nn.Conv2d(dimin, dimout[2], 1, stride=(1,1))
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dimin, dimout[3], 1, stride=(1,1))
        )
    def forward(self, imgs)
        x1 = self.branch1(imgs)
        x2 = self.branch2(imgs)
        x3 = self.branch3(imgs)
        x4 = self.branch4(imgs)
        return torch.cat([x1, x2, x3, x4], dim=1)

In [44]:
"""
Vision Transformer Class to create a vision transformer model
"""
class VisionTransformer(nn.Module):
    def __init__(self, classes=2, inputsize=(1,1,1), heads=2, fflscale=2, nblocks=1):
        super().__init__()
        (self.N, self.Np, self.P) = inputsize
        """components"""
        self.proj = nn.Linear(self.P, EMBED_DIMENSION)
        self.clstoken = nn.Parameter(torch.zeros(1, 1, EMBED_DIMENSION))
        self.posembed = self.embedding(self.Np+1, EMBED_DIMENSION, freq=10000) #10000 is described in ViT paper
        self.posembed = self.posembed.repeat(self.N, 1, 1)
        self.transformer = nn.Sequential(
            *((TransBlock(heads, EMBED_DIMENSION, int(fflscale*EMBED_DIMENSION)),)*nblocks)
        )
        self.classifier = nn.Sequential(
            nn.LayerNorm(EMBED_DIMENSION),
            nn.Linear(EMBED_DIMENSION, classes)
        )
    
    def embedding(npatches, EMBED_DIMENSION, freq):
        posembed = torch.zeros(npatches, EMBED_DIMENSION)
        for i in range(npatches):
            for j in range(EMBED_DIMENSION):
                if j % 2 == 0:
                    posembed[i][j] = np.sin(i/(freq**(j/EMBED_DIMENSION)))
                else:
                    posembed[i][j] = np.cos(i/(freq**((j-1)/EMBED_DIMENSION)))
        return posembed
    
    def forward(self, imgs): #assume size checking done by createPatches
        """Linear Projection and Positional Embedding"""
        tokens = self.proj(imgs) #perform linear projection
        clstoken = self.clstoken.repeat(self.N, 1, 1)
        tokens = torch.cat([clstoken, tokens], dim=1) #concat the class token
        x = tokens + self.posembed #add positional encoding
        """Transformer"""
        x = self.transformer(x)
        """Classification"""
        y = x[0]
        return self.classifier(y)