In [4]:
"""numpy and torch"""
import numpy as np
import torch

"""PIL"""
from PIL import Image

"""torchvision and utils"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
from torchvision.utils import save_image

"""os"""
import os


"""Get image to tensor"""
transform = transforms.Compose([
    transforms.PILToTensor()
])

"""Loading data into arrays"""
def load_split(dirs):
    """Load every image in two class directories as float tensors scaled by 1/255.

    Args:
        dirs: two directory paths. Images in ``dirs[0]`` get label 1 (AD) and
            images in ``dirs[1]`` get label 0 (NC).

    Returns:
        (images, labels): ``images`` is ``torch.stack`` of the per-image tensors
        (all images must share one shape); ``labels`` is a float64 tensor of
        1s followed by 0s, aligned row-for-row with ``images``.

    Raises:
        FileNotFoundError: if the directories contain no files at all.
    """
    images, size = [], [0, 0]
    for i, DIR in enumerate(dirs):
        # sorted() makes sample order reproducible; os.listdir order is arbitrary
        for filename in sorted(os.listdir(DIR)):
            # the context manager closes the file handle once pixels are read
            with Image.open(os.path.join(DIR, filename)) as img:
                tensor = transform(img).float()
            # inputs are data, not parameters: no requires_grad needed here.
            # NOTE(review): /255 assumes 8-bit images — confirm for this dataset.
            images.append(tensor / 255)
            size[i] += 1
    if not images:
        raise FileNotFoundError(f"no images found in {dirs}")
    labels = torch.from_numpy(np.concatenate((np.ones(size[0]), np.zeros(size[1])), axis=0))
    return torch.stack(images), labels

"""training data"""
trainDIRs = ['../../../AD_NC/train/AD/', '../../../AD_NC/train/NC']
xtrain, ytrain = load_split(trainDIRs)

"""testing data"""
testDIRs = ['../../../AD_NC/test/AD/', '../../../AD_NC/test/NC']
xtest, ytest = load_split(testDIRs)

In [5]:
def createPatches(imgs, wsize, hsize):
    """Split a batch of images into non-overlapping wsize x hsize patches.

    Args:
        imgs: tensor of shape (N, C, W, H). W must be divisible by wsize and
            H by hsize, otherwise the reshape raises.
        wsize: patch extent along dim 2.
        hsize: patch extent along dim 3.

    Returns:
        Tensor of shape (N, (W // wsize) * (H // hsize), C, wsize, hsize).
        Patches are ordered with the dim-3 patch index varying fastest.
    """
    n_imgs, channels, width, height = imgs.shape
    # split dim 2 into (patch col, offset) and dim 3 into (patch row, offset)
    grid = imgs.reshape(n_imgs, channels, width // wsize, wsize, height // hsize, hsize)
    # move the two patch-grid axes in front of the channel axis
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    # merge (patch col, patch row) into a single patch index
    return grid.flatten(1, 2)

In [6]:
# split every training image into 16x16 patches -> (N, Npatches, C, 16, 16)
patches = createPatches(xtrain, 16, 16)

In [7]:
# Sanity-check patch extraction: whole batch first, then a single sample.
print("whole thing")
print(xtrain.shape, patches.shape, sep="\n")
print("individual image")
print(xtrain[0].shape, patches[0].shape, sep="\n")

whole thing
torch.Size([21520, 1, 240, 256])
torch.Size([21520, 240, 1, 16, 16])
individual image
torch.Size([1, 240, 256])
torch.Size([240, 1, 16, 16])


In [15]:
import torch.nn as nn

def flattenPatches(imgs):
    """Collapse each patch's (C, W, H) block into one feature vector.

    Takes input of shape (N, Npatches, C, W, H) and returns a tensor of shape
    (N, Npatches, C*W*H).
    """
    return torch.flatten(imgs, start_dim=2, end_dim=4)

In [16]:
# merge each patch's (C, 16, 16) block into one vector -> (N, Npatches, C*16*16)
flattenedpatches = flattenPatches(patches)
print(flattenedpatches.shape)

torch.Size([21520, 240, 256])


In [17]:
"""projecting the patches to tokens"""
# NOTE(review): if this width is passed to nn.MultiheadAttention it must be
# divisible by the head count; 123 = 3 * 41, so only 1, 3, 41 or 123 heads
# would be valid — confirm the intended value.
EMBED_DIMENSION = 123
wsize, hsize = 16, 16 # patch size; should match the createPatches(xtrain, 16, 16) call
N, C, W, H = xtrain.shape
# learnable linear map from one flattened patch (C*wsize*hsize values) to a token
proj = nn.Linear(C*wsize*hsize, EMBED_DIMENSION)
# proj's weights require grad, so this also records an autograd graph over the
# whole training set; result is (N, Npatches, EMBED_DIMENSION)
tokens = proj(flattenedpatches)

In [18]:
print(tokens.shape) #of the form N, Ntokens, EMBED_DIMENSION (one token per patch, no class token yet)

torch.Size([21520, 240, 123])


In [19]:
"""adding the class tokens"""
# learnable (1, 1, EMBED_DIMENSION) class token, zero-initialised; this cell only
# creates it — it is not concatenated to `tokens` here
clstoken = nn.Parameter(torch.zeros(1, 1, EMBED_DIMENSION))

In [20]:
"""positional embedding"""
def embedding(npatches, EMBED_DIMENSION, freq):
    """Fixed sinusoidal positional encoding.

    Each even column j of row i holds sin(i / freq**(j / EMBED_DIMENSION)).
    The following odd column j + 1 holds the cosine of the same angle, so each
    sin/cos pair shares one frequency.

    Args:
        npatches: number of positions (rows) to encode.
        EMBED_DIMENSION: encoding width (columns); may be odd.
        freq: base of the geometric frequency progression (10000 in
            "Attention Is All You Need").

    Returns:
        Tensor of shape (npatches, EMBED_DIMENSION) in torch's default dtype.
    """
    # vectorised broadcast of (npatches, 1) positions against (EMBED_DIMENSION,)
    # columns, instead of a Python double loop with per-element tensor writes
    positions = np.arange(npatches, dtype=np.float64)[:, None]
    dims = np.arange(EMBED_DIMENSION)
    # even column j and odd column j + 1 both use exponent j / EMBED_DIMENSION
    exponents = (dims - dims % 2) / EMBED_DIMENSION
    angles = positions / np.power(float(freq), exponents)
    posembed = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))
    return torch.from_numpy(posembed).to(torch.get_default_dtype())

In [26]:
# test the positional embedding: check the shape, and that row 0 (position 0)
# is sin(0) = 0 in even columns and cos(0) = 1 in odd columns
embed = embedding(4, 6, 10)
print(embed.shape)
assert embed.shape == (4, 6)
assert torch.equal(embed[0], torch.tensor([0., 1., 0., 1., 0., 1.]))

torch.Size([4, 6])


In [34]:
"""
Vision Transformer Class to create a vision transformer model
"""
class VisionTransformer(nn.Module):
    """Patch-embedding front end of a Vision Transformer (encoder not yet added).

    Splits images into non-overlapping patches and linearly projects each
    flattened patch to a token. It then prepends a learnable class token and
    adds a fixed sinusoidal positional encoding.

    Args:
        imgsize: (N, C, W, H) of the image batch. Only C, W and H constrain the
            inputs; N is kept as the ``N`` attribute for reference.
        patchsize: (wsize, hsize) patch extent along dims 2 and 3.
        embed_dim: token width; defaults to the notebook-level EMBED_DIMENSION.

    Raises:
        ValueError: if W or H is not divisible by the matching patch size.
    """
    def __init__(self, imgsize, patchsize, embed_dim=None):
        super().__init__()
        (self.N, self.C, self.W, self.H) = imgsize
        (self.wsize, self.hsize) = patchsize
        if embed_dim is None:
            embed_dim = EMBED_DIMENSION  # notebook-level default keeps old call sites working
        self.embed_dim = embed_dim
        # check for errors with sizing, using the constructor args (not notebook globals)
        if (self.W % self.wsize != 0) or (self.H % self.hsize != 0):
            raise ValueError("patchsize is not appropriate")
        # components
        self.proj = nn.Linear(self.C*self.wsize*self.hsize, embed_dim)
        self.clstoken = nn.Parameter(torch.zeros(1, 1, embed_dim))
        Np = (self.W // self.wsize) * (self.H // self.hsize)
        # 10000 is the base from "Attention Is All You Need"; the ViT paper itself
        # uses learned position embeddings. Shape (1, Np+1, D) broadcasts over any
        # batch size instead of materialising N copies, and as a buffer it follows
        # the module across .to(device) calls.
        self.register_buffer(
            "posembed",
            self.embedding(Np + 1, embed_dim, freq=10000).unsqueeze(0),
            persistent=False,
        )

    def createPatches(self, imgs):
        """Split (batch, C, W, H) images into (batch, Npatches, C, wsize, hsize) patches.

        Raises:
            ValueError: if the per-image shape differs from the constructor's (C, W, H).
        """
        if tuple(imgs.shape[1:]) != (self.C, self.W, self.H):
            raise ValueError("given sizes do not match")
        # use the actual batch size so any batch works, not only the constructor's N
        size = (imgs.shape[0], self.C, self.W // self.wsize, self.wsize, self.H // self.hsize, self.hsize)
        perm = (0, 2, 4, 1, 3, 5) #bring col, row index of patch to front
        flat = (1, 2) #flatten (col, row) index into col*row entry index for patches
        return imgs.reshape(size).permute(perm).flatten(*flat)

    def flattenPatches(self, imgs): #takes input (N, Npatches, C, W, H)
        """Collapse each patch's (C, W, H) block into a single vector."""
        return imgs.flatten(2, 4)

    @staticmethod
    def embedding(npatches, EMBED_DIMENSION, freq):
        """Sinusoidal encoding of shape (npatches, EMBED_DIMENSION).

        Even column j holds sin(i / freq**(j / EMBED_DIMENSION)); the following
        odd column holds the cosine of the same angle.
        """
        positions = np.arange(npatches, dtype=np.float64)[:, None]
        dims = np.arange(EMBED_DIMENSION)
        # even column j and odd column j + 1 share exponent j / EMBED_DIMENSION
        exponents = (dims - dims % 2) / EMBED_DIMENSION
        angles = positions / np.power(float(freq), exponents)
        posembed = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))
        return torch.from_numpy(posembed).to(torch.get_default_dtype())

    def forward(self, imgs, prepatched=True):
        """Turn images into positionally-encoded tokens.

        Args:
            imgs: if ``prepatched`` is True, already-flattened patches of shape
                (batch, Npatches, C*wsize*hsize); otherwise raw images of shape
                (batch, C, W, H), validated by createPatches.
            prepatched: whether ``imgs`` has already been patched and flattened.

        Returns:
            Tensor of shape (batch, Npatches + 1, embed_dim). Index 0 along dim 1
            is the class token.
        """
        if not prepatched:
            imgs = self.createPatches(imgs) #create patches
            imgs = self.flattenPatches(imgs) #flatten patch C,W,H into one array
        # Linear Projection and Positional Embedding
        tokens = self.proj(imgs) #perform linear projection
        clstoken = self.clstoken.expand(tokens.shape[0], -1, -1)  # view, no copy
        tokens = torch.cat([clstoken, tokens], dim=1) #concat the class token
        tokens = tokens + self.posembed #add positional encoding (broadcast over batch)
        # Transformer
        # TODO: transformer encoder blocks and classification head not implemented yet
        return tokens
        

In [35]:
patchsize = (16, 16)
ViT = VisionTransformer(xtrain.shape, patchsize)
# Smoke test on the full training set. Call the module itself (not .forward)
# so nn.Module hooks run; no_grad avoids recording an autograd graph over
# every training image.
with torch.no_grad():
    ViT(xtrain, prepatched=False)

tokens shape: torch.Size([21520, 240, 123])
cls shape: torch.Size([21520, 1, 123])
tokens shape: torch.Size([21520, 241, 123])
tokens+embed shape: torch.Size([21520, 241, 123])


In [36]:
class Attention(nn.Module):
    """Multi-head self-attention over a batch-first token sequence.

    Args:
        heads: number of attention heads; nn.MultiheadAttention requires
            EMBED_DIMENSION to be divisible by it.
        EMBED_DIMENSION: token width (input and output).
    """
    def __init__(self, heads, EMBED_DIMENSION):
        super().__init__()
        self.heads = heads
        self.attn = nn.MultiheadAttention(EMBED_DIMENSION, heads, batch_first=True)
        # NOTE(review): nn.MultiheadAttention already applies its own learned
        # Q/K/V input projections, so these extra bias-free layers are redundant
        # (a product of two linear maps) — consider removing them.
        self.Q = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.K = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)
        self.V = nn.Linear(EMBED_DIMENSION, EMBED_DIMENSION, bias=False)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq, EMBED_DIMENSION).

        Returns:
            The attention output, same shape as ``x``. It is returned so callers
            can add it as a residual.
        """
        Q = self.Q(x)
        K = self.K(x)
        V = self.V(x)
        # the attention weights are discarded, so skip computing them
        attnout, _ = self.attn(Q, K, V, need_weights=False)
        return attnout

In [None]:
class TransBlock(nn.Module):
    """One pre-norm Transformer encoder block.

    Computes x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        heads: number of attention heads passed to Attention.
        EMBED_DIMENSION: token width (input and output).
        fflsize: hidden width of the feed-forward MLP.
    """
    def __init__(self, heads, EMBED_DIMENSION, fflsize):
        super().__init__()
        self.fnorm = nn.LayerNorm(EMBED_DIMENSION)
        self.snorm = nn.LayerNorm(EMBED_DIMENSION)
        self.attn = Attention(heads, EMBED_DIMENSION)
        self.ffl = nn.Sequential(
            nn.Linear(EMBED_DIMENSION, fflsize),
            nn.GELU(),
            nn.Linear(fflsize, EMBED_DIMENSION)
        )

    def forward(self, x):
        """
        Pre-norm residual block: LayerNorm is applied before attention and before
        the MLP. This is the placement used by the ViT paper. It differs from the
        post-norm layout of the original Transformer ("Attention Is All You
        Need"). Pre-norm is also used by GPT-2-style LLMs and tends to stabilise
        gradients in deep stacks.

        self.attn must return a tensor shaped like its input for the residual add.
        """
        x = x + self.attn(self.fnorm(x))
        x = x + self.ffl(self.snorm(x))
        return x