In [63]:
"""numpy and torch"""
import numpy as np
import torch

"""PIL"""
from PIL import Image

"""torchvision and utils"""
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import make_grid
from torchvision.utils import save_image

"""os"""
import os


"""Get image to tensor"""
transform = transforms.Compose([
    transforms.PILToTensor()
])


def _load_labelled_images(class_dirs):
    """Load every image in two class directories into one tensor with labels.

    Args:
        class_dirs: sequence of exactly two directory paths. Images from the
            first directory are labelled 1, images from the second 0.

    Returns:
        (images, labels, counts) where ``images`` is the stack of all image
        tensors scaled by 1/255, ``labels`` is a float64 tensor of ones followed
        by zeros, and ``counts`` is ``[n_first, n_second]``.
    """
    first_dir, second_dir = class_dirs
    images, counts = [], []
    for class_dir in (first_dir, second_dir):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # sample order reproducible across runs and machines.
        filenames = sorted(os.listdir(class_dir))
        for filename in filenames:
            # The context manager closes the file handle once the pixels are read.
            with Image.open(os.path.join(class_dir, filename)) as img:
                # Scale to [0, 1]; assumes 8-bit images -- TODO confirm.
                # Inputs are data, not parameters, so no gradient tracking is
                # requested (the old `require_grad` attribute was a misspelt no-op).
                images.append(transform(img).float() / 255)
        counts.append(len(filenames))
    labels = torch.from_numpy(
        np.concatenate((np.ones(counts[0]), np.zeros(counts[1])), axis=0)
    )
    return torch.stack(images), labels, counts


"""training data"""
trainDIRs = ['../../../AD_NC/train/AD/', '../../../AD_NC/train/NC']
xtrain, ytrain, size = _load_labelled_images(trainDIRs)

"""testing data"""
testDIRs = ['../../../AD_NC/test/AD/', '../../../AD_NC/test/NC']
xtest, ytest, size = _load_labelled_images(testDIRs)

In [79]:
def createPatches(imgs, wsize, hsize):
    """Cut a batch of images into non-overlapping patches.

    Args:
        imgs: 4-D tensor of shape (N, C, W, H). ``W`` and ``H`` are just the
            names used here for dimensions 2 and 3.
            NOTE(review): tensors produced by torchvision's PILToTensor are
            (C, H, W), so for those ``W`` is the image height -- naming only,
            the slicing is symmetric.
        wsize: patch extent along dimension 2; must be positive and divide W.
        hsize: patch extent along dimension 3; must be positive and divide H.

    Returns:
        Tensor of shape (N, (W // wsize) * (H // hsize), C, wsize, hsize).
        Patches are ordered with the dimension-3 patch index varying fastest.

    Raises:
        ValueError: if a patch size is not positive or does not evenly divide
            the corresponding image dimension.
    """
    N, C, W, H = imgs.shape #number imgs, channels, width, height
    if wsize <= 0 or hsize <= 0:
        raise ValueError(f"patch size must be positive, got ({wsize}, {hsize})")
    if W % wsize != 0 or H % hsize != 0:
        raise ValueError(
            f"patch size ({wsize}, {hsize}) does not evenly divide image size ({W}, {H})"
        )
    # Split each spatial axis into (patch index, offset inside the patch).
    size = (N, C, W // wsize, wsize, H // hsize, hsize)
    perm = (0, 2, 4, 1, 3, 5) #bring col, row index of patch to front
    flat = (1, 2) #flatten (col, row) index into col*row entry index for patches
    return imgs.reshape(size).permute(perm).flatten(*flat)

In [80]:
# Split every training image into 16x16 patches.
patches = createPatches(imgs=xtrain, wsize=16, hsize=16)

In [104]:
# Compare shapes before/after patching: first for the full batch, then for one image.
print("whole thing", xtrain.shape, patches.shape, sep="\n")
print("individual image", xtrain[0].shape, patches[0].shape, sep="\n")

whole thing
torch.Size([21520, 1, 240, 256])
torch.Size([21520, 240, 1, 16, 16])
individual image
torch.Size([1, 240, 256])
torch.Size([240, 1, 16, 16])


In [109]:
import torch.nn as nn

def flattenPatches(imgs): #takes input (N, Npatches, C, W, H)
    """Merge dimensions 2 through 4 of ``imgs`` into one axis.

    For a 5-D input (N, Npatches, C, W, H) the result is
    (N, Npatches, C*W*H): one flat feature vector per patch.
    """
    flattened = imgs.flatten(start_dim=2, end_dim=4)
    return flattened

In [111]:
# One flat feature vector per patch.
flattenedpatches = flattenPatches(imgs=patches)
print(flattenedpatches.size())

torch.Size([21520, 240, 256])


In [118]:
"""projecting the patches to tokens"""
EMBED_DIMENSION = 123
wsize = hsize = 16
N, C, W, H = xtrain.size()
# Linear map from one flattened patch (C*wsize*hsize values) to one token.
proj = nn.Linear(in_features=C * wsize * hsize, out_features=EMBED_DIMENSION)
tokens = proj(flattenedpatches)

torch.Size([21520, 240, 123])


In [119]:
# Token tensor layout: (N, Ntokens, EMBED_DIMENSION)
print(tokens.size())

torch.Size([21520, 240, 123])


In [None]:
"""adding the class tokens"""
# Learnable class token, initialised to zeros, with shape (1, 1, EMBED_DIMENSION).
clstoken = nn.Parameter(torch.zeros((1, 1, EMBED_DIMENSION)))

In [125]:
"""positional embedding"""
def embedding(npatches, EMBED_DIMENSION, freq):
    """Build a sinusoidal positional-embedding table.

    Entry (i, j) is ``sin(i / freq**(j / EMBED_DIMENSION))`` for even ``j`` and
    ``cos(i / freq**((j - 1) / EMBED_DIMENSION))`` for odd ``j``, so each
    sin/cos pair shares the exponent of its even column.

    Args:
        npatches: number of positions (rows).
        EMBED_DIMENSION: embedding width (columns).
        freq: base of the wavelength progression; must be positive.

    Returns:
        Tensor of shape (npatches, EMBED_DIMENSION) in torch's default dtype.

    Raises:
        ValueError: if ``freq`` is not positive.
    """
    if freq <= 0:
        raise ValueError(f"freq must be positive, got {freq}")
    # Compute in float64 (as the scalar formula did) and cast once at the end.
    positions = torch.arange(npatches, dtype=torch.float64).unsqueeze(1)
    columns = torch.arange(EMBED_DIMENSION)
    is_even = columns % 2 == 0
    # j for even columns, j - 1 for odd ones.
    exponents = (columns - columns % 2).to(torch.float64) / EMBED_DIMENSION
    angles = positions / torch.pow(freq, exponents)
    posembed = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return posembed.to(torch.get_default_dtype())

In [126]:
# Smoke-test the positional embedding on a tiny 4x6 table with base 10.
embed = embedding(npatches=4, EMBED_DIMENSION=6, freq=10)
print(embed)

tensor([[ 0.0000,  1.0000,  0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.4477,  0.8942,  0.2138,  0.9769],
        [ 0.9093, -0.4161,  0.8006,  0.5992,  0.4177,  0.9086],
        [ 0.1411, -0.9900,  0.9841,  0.1774,  0.6023,  0.7983]])


In [None]:
"""
Vision Transformer Class to create a vision transformer model
"""
class VisionTransformer(nn.Module):
    """Vision-transformer building blocks: patch extraction, projection, class token.

    Args:
        imgsize: ``(C, W, H)`` of the expected input images, using this
            notebook's names for tensor dimensions 1, 2 and 3.
        patchsize: ``(wsize, hsize)`` patch extent along dimensions 2 and 3.
        embed_dim: token width. When ``None`` the notebook-level
            ``EMBED_DIMENSION`` constant is looked up at construction time.
    """
    def __init__(self, imgsize, patchsize, embed_dim=None):
        super().__init__()
        # Plain tuple unpacking; `a, b = *seq` is a SyntaxError.
        self.C, self.W, self.H = imgsize
        self.wsize, self.hsize = patchsize
        if embed_dim is None:
            embed_dim = EMBED_DIMENSION
        # components
        # The projection consumes one flattened patch (C*wsize*hsize values),
        # not a whole image.
        self.proj = nn.Linear(self.C*self.wsize*self.hsize, embed_dim)
        self.clstoken = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def createPatches(self, imgs, wsize, hsize):
        """Cut a batch of images into non-overlapping patches.

        Args:
            imgs: tensor of shape (N, C, W, H) matching the ``imgsize`` given
                to the constructor.
            wsize: patch extent along dimension 2; must divide W.
            hsize: patch extent along dimension 3; must divide H.

        Returns:
            Tensor of shape (N, (W // wsize) * (H // hsize), C, wsize, hsize).

        Raises:
            ValueError: if the patch size does not divide the image size, or
                the image size differs from the one given to the constructor.
        """
        N, C, W, H = imgs.shape
        if (W % wsize != 0) or (H % hsize != 0):
            raise ValueError("patchsize is not appropriate")
        if (self.C != C) or (self.W != W) or (self.H != H):
            raise ValueError("given sizes do not match")
        # Split each spatial axis into (patch index, offset inside the patch).
        size = (N, C, W // wsize, wsize, H // hsize, hsize)
        perm = (0, 2, 4, 1, 3, 5) #bring col, row index of patch to front
        flat = (1, 2) #flatten (col, row) index into col*row entry index for patches
        imgs = imgs.reshape(size).permute(perm).flatten(*flat)
        return imgs #in format Nimgs, Npatches, C, Wpatch, Hpatch

    def forward(self, imgs): #assume size checking done by createPatches
        # NOTE(review): only the patch-extraction step is visible here; the rest
        # of the forward pass and its return value are not shown -- confirm.
        patches = self.createPatches(imgs, self.wsize, self.hsize)