In [2]:
import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision import transforms, datasets
import os
from argparse import Namespace
import numpy as np
import matplotlib.pyplot as plt
from torch.nn import CrossEntropyLoss


from vit import ViT
from util import *



  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Root folder of the dataset; the split sub-folders are joined onto it later.
dataDir = './data/'

In [4]:
# Build the datasets and loaders for every split, then the model, optimizer and loss.
# NOTE(review): `dataTransforms` and `args` are not defined in any cell visible here;
# presumably they come from `from util import *` or from a cell that is no longer
# shown. Confirm this cell still runs after Restart Kernel -> Run All.
splits = ['train', 'val', 'test']

# One ImageFolder per split, rooted at <dataDir>/<split>, each with its own transform.
imageDatasets = {
    x: datasets.ImageFolder(os.path.join(dataDir, x), dataTransforms[x])
    for x in splits
}

# Only the training batches are shuffled. The evaluation splits keep a fixed order so
# that per-batch validation/test curves and progress lines are comparable between
# epochs and between runs; epoch-level metrics do not depend on the order.
dataloaders = {
    'train': DataLoader(imageDatasets['train'], batch_size=args.batchSize, shuffle=True, num_workers=1),
    'val': DataLoader(imageDatasets['val'], batch_size=args.batchSize, shuffle=False, num_workers=1),
    'test': DataLoader(imageDatasets['test'], batch_size=args.batchSize, shuffle=False, num_workers=0),
}
datasetSizes = {x: len(imageDatasets[x]) for x in splits}
classNames = imageDatasets['train'].classes  # class list taken from the training split

# The classifier head is sized from the number of classes in the training split.
# NOTE(review): image_size=256 assumes dataTransforms produces 256x256 inputs - TODO confirm.
model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=len(classNames),
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
    emb_dropout=0.1,
).to(device)

# Optimize only the parameters that require gradients. Only the learning rate is
# overridden; betas, eps and weight_decay stay at Adam's defaults (overrides tried
# earlier: betas=(args.beta1, args.beta2), eps=1e-8, weight_decay=args.weight_decay).
trainableParams = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainableParams, lr=args.initLr)
criterion = CrossEntropyLoss()

In [20]:
# Train for args.epochs epochs and evaluate on the validation split after each one.
# Running loss/accuracy curves are handed to plotTwo: one figure per training epoch
# (./plots/e{epoch}_train.png) and one for validation at the very end (./plots/eval.png).

# plotTwo is given paths under ./plots; create the folder up front so that a missing
# directory cannot abort the run only after a full epoch of training has finished.
os.makedirs('./plots', exist_ok=True)

# NOTE(review): the validation history receives one running-average point per batch and
# is never reset, so the final figure concatenates the per-batch curves of all epochs.
# If one point per epoch is what is wanted, append only the last value of each epoch.
valLossList, valAccList = [], []
for epoch in range(args.epochs):
    # ---- Train ----
    model.train()
    trainLossSum = 0.0
    correct, total = 0, 0
    lossList, accList = [], []
    for i, (x, y) in enumerate(dataloaders['train']):
        x, y = x.to(device), y.to(device)
        # Clear gradients before the forward/backward pass so nothing left over from
        # an interrupted earlier run of this cell is added to this step.
        optimizer.zero_grad()
        yHat = model(x)
        loss = criterion(yHat, y)
        loss.backward()
        optimizer.step()

        # Running statistics. The criterion returns the mean loss of the batch, so it
        # is weighted by the batch size; a smaller final batch then counts
        # proportionally instead of as much as a full one.
        batchSize = yHat.size(0)
        trainLossSum += loss.item() * batchSize
        total += batchSize
        _, predicted = yHat.max(1)
        correct += predicted.eq(y).sum().item()
        avgLoss = trainLossSum / total
        avgAcc = correct / total
        lossList.append(avgLoss)
        accList.append(avgAcc)
        if i % 100 == 0:
            print(f"Train: {i}/{len(dataloaders['train'])} - AccuLoss: {avgLoss:.3f} | AccuAcc: {avgAcc:.3f} ({correct}/{total})")
    # One training figure per epoch.
    plotTwo(lossList, accList, f"Epoch {epoch} Loss and Acc", f"./plots/e{epoch}_train.png")

    # ---- Val ----
    model.eval()
    valLossSum = 0.0
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for i, (x, y) in enumerate(dataloaders['val']):
            x, y = x.to(device), y.to(device)
            yHat = model(x)
            loss = criterion(yHat, y)

            # Same sample-weighted running statistics as in the training phase.
            batchSize = yHat.size(0)
            valLossSum += loss.item() * batchSize
            total += batchSize
            _, predicted = yHat.max(1)
            correct += predicted.eq(y).sum().item()
            avgLoss = valLossSum / total
            avgAcc = correct / total

            valLossList.append(avgLoss)
            valAccList.append(avgAcc)
            if i % 100 == 0:
                print(f"Val: {i}/{len(dataloaders['val'])} - AccuLoss: {avgLoss:.3f} | AccuAcc: {avgAcc:.3f} ({correct}/{total})")
# Single validation figure covering every epoch, drawn once training has finished.
plotTwo(valLossList, valAccList, "Validation Loss and Acc", "./plots/eval.png")

Train: 0/512 - AccuLoss: 1.728 | AccuAcc: 0.500 (4/8)
Train: 100/512 - AccuLoss: 0.950 | AccuAcc: 0.545 (440/808)
Train: 200/512 - AccuLoss: 0.942 | AccuAcc: 0.541 (870/1608)
Train: 300/512 - AccuLoss: 0.948 | AccuAcc: 0.537 (1294/2408)
Train: 400/512 - AccuLoss: 0.943 | AccuAcc: 0.536 (1720/3208)
Train: 500/512 - AccuLoss: 0.932 | AccuAcc: 0.543 (2175/4008)
Val: 0/129 - AccuLoss: 1.230 | AccuAcc: 0.375 (3/8)
Val: 100/129 - AccuLoss: 0.940 | AccuAcc: 0.575 (465/808)
Train: 0/512 - AccuLoss: 1.193 | AccuAcc: 0.375 (3/8)
Train: 100/512 - AccuLoss: 0.906 | AccuAcc: 0.551 (445/808)
Train: 200/512 - AccuLoss: 0.889 | AccuAcc: 0.576 (927/1608)
Train: 300/512 - AccuLoss: 0.882 | AccuAcc: 0.578 (1393/2408)
Train: 400/512 - AccuLoss: 0.898 | AccuAcc: 0.563 (1806/3208)
Train: 500/512 - AccuLoss: 0.897 | AccuAcc: 0.561 (2250/4008)
Val: 0/129 - AccuLoss: 0.610 | AccuAcc: 0.625 (5/8)
Val: 100/129 - AccuLoss: 0.839 | AccuAcc: 0.600 (485/808)
Train: 0/512 - AccuLoss: 0.994 | AccuAcc: 0.625 (5/8)
Trai