In [36]:
class Attention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        heads: number of attention heads; nn.MultiheadAttention requires it to
            divide EMBED_DIMENSION.
        EMBED_DIMENSION: token embedding width (input and output size).
    """
    def __init__(self, heads, EMBED_DIMENSION):
        super().__init__()
        self.heads = heads
        self.attn = nn.MultiheadAttention(EMBED_DIMENSION, heads, batch_first=True)
        # NOTE(review): nn.MultiheadAttention already applies its own learned
        # Q/K/V input projections, so these bias-free projections are
        # redundant (two stacked linear maps). They are kept so that parameter
        # names and the state_dict layout stay unchanged.
        self.Q = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.K = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.V = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)

    def forward(self, x):
        """Return the self-attention output for ``x``.

        Because the attention module is built with ``batch_first=True``, ``x``
        is laid out as (batch, seq, EMBED_DIMENSION), and the result has the
        same shape. The attention weights are discarded so the output can be
        added directly to a residual stream.
        """
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)

        attnout, _ = self.attn(Q, K, V)
        return attnout

In [None]:
class TransBlock(nn.Module):
    """Pre-LayerNorm transformer encoder block.

    The block has an attention sub-layer and a feed-forward (MLP) sub-layer.
    Each one is preceded by its own LayerNorm and wrapped in a residual
    connection.

    Args:
        heads: number of attention heads passed to ``Attention``.
        EMBED_DIMENSION: token embedding width (input and output size).
        fflsize: hidden width of the feed-forward sub-layer.
    """
    def __init__(self, heads, EMBED_DIMENSION, fflsize):
        super().__init__()
        self.fnorm = nn.LayerNorm(EMBED_DIMENSION)  # norm before attention
        self.snorm = nn.LayerNorm(EMBED_DIMENSION)  # norm before the MLP
        self.attn = Attention(heads, EMBED_DIMENSION)
        self.ffl = nn.Sequential(
            nn.Linear(EMBED_DIMENSION, fflsize),
            nn.GELU(),
            nn.Linear(fflsize, EMBED_DIMENSION),
        )

    def forward(self, x):
        """Apply ``x + attn(LN(x))``, then ``x + ffl(LN(x))``.

        LayerNorm is applied *before* each sub-layer (pre-LN). This layout is
        used by GPT-2 and by the ViT paper itself (Dosovitskiy et al., 2020,
        Eq. 2-3), and it tends to stabilise gradients compared with the
        post-LN layout of the original Transformer (Vaswani et al., 2017).
        """
        x = x + self.attn(self.fnorm(x))
        x = x + self.ffl(self.snorm(x))
        return x

In [34]:
"""
Vision Transformer Class to create a vision transformer model
"""
class VisionTransformer(nn.Module):
    def __init__(self, imgsize, patchsize, fflscale, nblocks):
        """Build the patch projection, class token and positional embedding.

        Args:
            imgsize: ``(N, C, W, H)``, i.e. batch size, channels, width and
                height. W comes before H, matching ``createPatches``.
            patchsize: ``(wsize, hsize)``, the patch width and height.
            fflscale: accepted but not used by this constructor.
            nblocks: accepted but not used by this constructor.

        Raises:
            ValueError: if W or H is not divisible by the patch size, or if C
                or H differ from the notebook-level globals ``C`` and ``H``.
        """
        super().__init__()
        (self.N, self.C, self.W, self.H) = imgsize
        (self.wsize, self.hsize) = patchsize
        # check for errors with sizing
        if (self.W % self.wsize != 0) or (self.H % self.hsize != 0):
            raise ValueError("patchsize is not appropriate")
        # NOTE(review): this compares against notebook-level globals C and H
        # (presumably defined in an earlier cell) and does not check W.
        # Confirm the intent.
        if (self.C != C) or (self.H != H):
            raise ValueError("given sizes do not match")
        # components
        # NOTE(review): fflscale / nblocks are not used yet; no TransBlock
        # layers are built here.
        self.proj = nn.Linear(self.C * self.wsize * self.hsize, EMBED_DIMENSION)
        self.clstoken = nn.Parameter(torch.zeros(1, 1, EMBED_DIMENSION))
        Np = (self.W // self.wsize) * (self.H // self.hsize)
        # 10000 is the sinusoidal base from Vaswani et al. (2017). The ViT
        # paper itself uses learned 1-D position embeddings instead.
        # type(self).embedding works whether embedding is a plain function in
        # the class body or a @staticmethod.
        posembed = type(self).embedding(Np + 1, EMBED_DIMENSION, freq=10000)
        # A buffer follows .to(device)/.cuda() with the module. The leading
        # axis of size 1 broadcasts over any batch size in forward.
        # persistent=False keeps it out of the state_dict, since it is
        # recomputed on construction.
        self.register_buffer("posembed", posembed.unsqueeze(0), persistent=False)
    
    def createPatches(self, imgs):
        size = (self.N, self.C, self.W // self.wsize, self.wsize, self.H // self.hsize, self.hsize)
        perm = (0, 2, 4, 1, 3, 5) #bring col, row index of patch to front
        flat = (1, 2) #flatten (col, row) index into col*row entry index for patches
        imgs = imgs.reshape(size).permute(perm).flatten(*flat)
        return imgs #in format Nimgs, Npatches, C, Wpatch, Hpatch
    
    def flattenPatches(self, imgs): #takes input (N, Npatches, C, W, H)
        return imgs.flatten(2, 4)
    
    def embedding(npatches, EMBED_DIMENSION, freq):
        posembed = torch.zeros(npatches, EMBED_DIMENSION)
        for i in range(npatches):
            for j in range(EMBED_DIMENSION):
                if j % 2 == 0:
                    posembed[i][j] = np.sin(i/(freq**(j/EMBED_DIMENSION)))
                else:
                    posembed[i][j] = np.cos(i/(freq**((j-1)/EMBED_DIMENSION)))
        return posembed
    
    def forward(self, imgs, prepatched=True): #assume size checking done by createPatches
        """Turn patches into tokens, prepend the class token and add positional embeddings.

        Args:
            imgs: when ``prepatched`` is True, already-flattened patches.
                Presumably (N, Npatches, C*wsize*hsize) to match ``self.proj``;
                TODO confirm. Otherwise raw images, which go through
                ``createPatches`` and ``flattenPatches`` first.
            prepatched: whether ``imgs`` has already been patched and flattened.

        NOTE(review): in the visible code the body stops after computing ``x``.
        No transformer blocks are applied and there is no return, so the
        "Transformer" stage looks unfinished. Confirm.
        """
        if not prepatched:
            imgs = self.createPatches(imgs) #create patches
            imgs = self.flattenPatches(imgs) #flatten patch C,W,H into one array
        """Linear Projection and Positional Embedding"""
        tokens = self.proj(imgs) #perform linear projection
        # N (batch) is used to tile the class token; Np and P are not used below
        N, Np, P = tokens.shape
        clstoken = self.clstoken.repeat(N, 1, 1)
        tokens = torch.cat([clstoken, tokens], dim=1) #concat the class token
        # relies on self.posembed broadcasting against (N, Np+1, P)
        x = tokens + self.posembed #add positional encoding
        """Transformer"""
        