In [1]:
# Colab only: expose Google Drive under /content/drive so the course folder
# and the checkpoint stored there are reachable from this runtime.
from google.colab import drive

drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os

# Run the rest of the notebook from the course folder on the mounted Drive.
_PROJECT_DIR = "/content/drive/MyDrive/CS444"
os.chdir(_PROJECT_DIR)

In [None]:
# Install dependencies (uncomment on a fresh runtime; %pip targets the kernel's own environment).
#%pip install torch
#%pip install torchvision
#%pip install timm

In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use("ggplot")

import torch
import torch.nn as nn
import torchvision.transforms as transforms

import timm

import gc
import os
import time
import random
from datetime import datetime

from PIL import Image
from tqdm.notebook import tqdm
from sklearn import model_selection, metrics


In [4]:
# Use the first CUDA GPU when one is visible to PyTorch, otherwise the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
device

device(type='cuda', index=0)

## MNIST: fine-tuning a pretrained ViT-Base/16 for 40 epochs

In [5]:
# Checkpoint file on the mounted Drive that ViTBase16 loads when pretrained=True.
MODEL_PATH = "/content/drive/MyDrive/CS444/jx_vit_base_p16_224-80ecf9dd.pth"

In [6]:
# Training configuration shared by the cells below.
IMG_SIZE = 224    # images are resized to IMG_SIZE x IMG_SIZE before entering the model
BATCH_SIZE = 16   # samples per DataLoader batch (train and test)
# NOTE(review): the training log at the end of this notebook shows the loss
# hovering around 2.0 for the first ~400 batches with this rate -- a smaller
# value may be worth trying; confirm before changing.
LR = 1e-2         # learning rate handed to the Adam optimizer
N_EPOCHS = 40     # number of epochs passed to fit_tpu

In [13]:
def _build_mnist_transform():
    """Return the preprocessing pipeline used for both the train and test split.

    Steps, in order: resize to (IMG_SIZE, IMG_SIZE), convert to a tensor, then
    tile the tensor three times along its first dimension with
    ``repeat(3, 1, 1)``.
    """
    steps = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        # presumably turns the single-channel MNIST image into a 3-channel
        # input for the ViT backbone -- assumes a (1, H, W) tensor here
        transforms.Lambda(lambda img: img.repeat(3, 1, 1)),
    ]
    return transforms.Compose(steps)


# Two separate (but identical) pipelines; no random augmentation is applied.
transforms_train = _build_mnist_transform()
transforms_test = _build_mnist_transform()

In [14]:
import torchvision

# Directory where torchvision stores (and, if missing, downloads) MNIST.
DOWNLOAD_PATH = '/data/mnist'

# Training split: reshuffled every epoch.
train_set = torchvision.datasets.MNIST(
    DOWNLOAD_PATH, train=True, download=True, transform=transforms_train
)
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=BATCH_SIZE, shuffle=True
)

# Test split: used for validation by fit_tpu. Evaluation does not benefit from
# shuffling, so the loader iterates in a fixed order; this keeps evaluation
# reproducible and does not consume random state between training epochs.
test_set = torchvision.datasets.MNIST(
    DOWNLOAD_PATH, train=False, download=True, transform=transforms_test
)
test_loader = torch.utils.data.DataLoader(
    test_set, batch_size=BATCH_SIZE, shuffle=False
)

In [9]:
# Alternative dataset (disabled): CIFAR-10 loaders, kept for reference only.
#import torchvision.datasets as datasets
#from torch.utils.data import DataLoader
#train_set = datasets.CIFAR10(root='./../datasets', train=True, download=True, transform=transforms_train)
#test_set = datasets.CIFAR10(root='./../datasets', train=False, download=True, transform=transforms_test)
#train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
#test_loader = DataLoader(test_set, BATCH_SIZE*2, num_workers=4, pin_memory=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./../datasets/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./../datasets/cifar-10-python.tar.gz to ./../datasets


KeyboardInterrupt: ignored

In [9]:
class ViTBase16(nn.Module):
    """Image classifier wrapping timm's ``vit_base_patch16_224`` backbone.

    The backbone's classification head is replaced by a new ``nn.Linear`` with
    ``n_classes`` outputs. The class also carries its own single-epoch
    training and validation loops.

    Args:
        n_classes: Number of outputs of the replacement head.
        pretrained: When True, backbone weights are loaded from the
            module-level ``MODEL_PATH`` checkpoint before the head is replaced.
    """

    def __init__(self, n_classes, pretrained=False):
        super().__init__()

        # timm's own weight download is never used; weights come from MODEL_PATH.
        self.model = timm.create_model("vit_base_patch16_224", pretrained=False)
        if pretrained:
            # map_location="cpu" lets the checkpoint load on a CPU-only runtime
            # whatever device it was saved from; callers move the model with
            # .to(device) afterwards.
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            state_dict = torch.load(MODEL_PATH, map_location="cpu")
            self.model.load_state_dict(state_dict)

        # Replace the head after loading so the checkpoint's head shape still matches.
        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        """Return the backbone's output (with the replaced head) for batch ``x``."""
        return self.model(x)

    def train_one_epoch(self, train_loader, criterion, optimizer, device):
        """Run one optimisation pass over ``train_loader``.

        Args:
            train_loader: Sized iterable yielding ``(data, target)`` batches.
            criterion: Callable ``criterion(output, target)`` returning the loss.
            optimizer: Optimizer over this module's parameters.
            device: Device (``torch.device`` or device string) the batches are
                moved to before the forward pass.

        Returns:
            ``(mean_loss, mean_accuracy)``: the per-batch loss and accuracy
            averaged over ``len(train_loader)`` batches, as 0-dim tensors that
            carry no autograd history.
        """
        epoch_loss = 0.0
        epoch_accuracy = 0.0

        self.model.train()
        for i, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = self(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            accuracy = (output.argmax(dim=1) == target).float().mean()
            # Accumulate a detached copy: adding `loss` itself would chain every
            # batch's autograd history onto the running total.
            epoch_loss += loss.detach()
            epoch_accuracy += accuracy

            if i % 20 == 0:
                print(f"\tBATCH {i+1}/{len(train_loader)} - LOSS: {loss.item()}")

        return epoch_loss / len(train_loader), epoch_accuracy / len(train_loader)

    def validate_one_epoch(self, test_loader, criterion, device):
        """Evaluate on ``test_loader`` without updating any weights.

        Args:
            test_loader: Sized iterable yielding ``(data, target)`` batches.
            criterion: Callable ``criterion(output, target)`` returning the loss.
            device: Device (``torch.device`` or device string) the batches are
                moved to before the forward pass.

        Returns:
            ``(mean_loss, mean_accuracy)``: the per-batch loss and accuracy
            averaged over ``len(test_loader)`` batches, as 0-dim tensors.
        """
        valid_loss = 0.0
        valid_accuracy = 0.0

        self.model.eval()
        # Gradients are not needed anywhere in evaluation.
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)

                output = self(data)
                loss = criterion(output, target)
                accuracy = (output.argmax(dim=1) == target).float().mean()

                valid_loss += loss
                valid_accuracy += accuracy

        return valid_loss / len(test_loader), valid_accuracy / len(test_loader)


In [10]:
def fit_tpu(
    model, epochs, device, criterion, optimizer, train_loader, test_loader=None
):
    """Train ``model`` for ``epochs`` epochs and collect per-epoch metrics.

    Despite the name, nothing TPU-specific happens here: ``device`` is only
    forwarded to the model's own epoch methods.

    Args:
        model: Object exposing ``train_one_epoch(train_loader, criterion,
            optimizer, device)`` and ``validate_one_epoch(test_loader,
            criterion, device)``, each returning ``(loss, accuracy)``.
        epochs: Number of epochs to run (epochs are numbered from 1).
        device: Forwarded unchanged to the model's epoch methods.
        criterion: Loss callable, forwarded to the model.
        optimizer: Optimizer, forwarded to ``train_one_epoch``.
        train_loader: Training batches.
        test_loader: Optional validation batches; validation is skipped when
            this is None.

    Returns:
        Dict with the per-epoch lists ``"train_loss"``, ``"valid_losses"``,
        ``"train_acc"`` and ``"valid_acc"``. The validation lists stay empty
        when ``test_loader`` is None.
    """
    # Lowest validation loss seen so far. float("inf") replaces the np.Inf
    # alias, which no longer exists in NumPy 2.0.
    valid_loss_min = float("inf")

    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    for epoch in range(1, epochs + 1):
        gc.collect()

        print(f"{'='*50}")
        print(f"EPOCH {epoch} - TRAINING...")
        train_loss, train_acc = model.train_one_epoch(
            train_loader, criterion, optimizer, device
        )
        print(
            f"\n\t[TRAIN] EPOCH {epoch} - LOSS: {train_loss}, ACCURACY: {train_acc}\n"
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        gc.collect()

        if test_loader is None:
            continue

        print(f"EPOCH {epoch} - VALIDATING...")
        valid_loss, valid_acc = model.validate_one_epoch(
            test_loader, criterion, device
        )
        print(f"\t[VALID] LOSS: {valid_loss}, ACCURACY: {valid_acc}\n")
        valid_losses.append(valid_loss)
        valid_accs.append(valid_acc)
        gc.collect()

        # Only a new best updates the running minimum. Epoch 1 sets the
        # baseline silently because there is no earlier value to compare with.
        if valid_loss <= valid_loss_min:
            if epoch != 1:
                print(
                    "Validation loss decreased ({:.4f} --> {:.4f}).  Saving model ...".format(
                        valid_loss_min, valid_loss
                    )
                )
                # NOTE: no checkpoint is written here; the original save call
                # (xm.save(model.state_dict(), 'best_model.pth')) is disabled.
            valid_loss_min = valid_loss

    return {
        "train_loss": train_losses,
        "valid_losses": valid_losses,
        "train_acc": train_accs,
        "valid_acc": valid_accs,
    }

In [11]:
# Ten-way classifier (one output per MNIST digit) initialised from the
# checkpoint at MODEL_PATH.
model = ViTBase16(pretrained=True, n_classes=10)

In [15]:
# Loss function, device placement and optimizer for the run.
criterion = nn.CrossEntropyLoss()
model.to(device)

lr = LR
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print("INITIALIZING TRAINING")
start_time = datetime.now()
print(f"Start Time: {start_time}")

# Arguments follow fit_tpu's parameter order:
# model, epochs, device, criterion, optimizer, train_loader, test_loader.
logs = fit_tpu(
    model, N_EPOCHS, device, criterion, optimizer, train_loader, test_loader
)

print(f"Execution time: {datetime.now() - start_time}")

INITIALIZING TRAINING
Start Time: 2022-05-12 18:01:51.143416
EPOCH 1 - TRAINING...
	BATCH 1/3750 - LOSS: 2.2490615844726562
	BATCH 21/3750 - LOSS: 2.7702512741088867
	BATCH 41/3750 - LOSS: 1.9304150342941284
	BATCH 61/3750 - LOSS: 2.0892181396484375
	BATCH 81/3750 - LOSS: 1.778267741203308
	BATCH 101/3750 - LOSS: 2.038327932357788
	BATCH 121/3750 - LOSS: 1.8042110204696655
	BATCH 141/3750 - LOSS: 2.1231768131256104
	BATCH 161/3750 - LOSS: 2.247530698776245
	BATCH 181/3750 - LOSS: 1.9126079082489014
	BATCH 201/3750 - LOSS: 1.871691107749939
	BATCH 221/3750 - LOSS: 1.751727819442749
	BATCH 241/3750 - LOSS: 1.8868409395217896
	BATCH 261/3750 - LOSS: 1.8334274291992188
	BATCH 281/3750 - LOSS: 2.15313720703125
	BATCH 301/3750 - LOSS: 1.9897031784057617
	BATCH 321/3750 - LOSS: 1.792114496231079
	BATCH 341/3750 - LOSS: 2.077228307723999
	BATCH 361/3750 - LOSS: 1.7914998531341553
	BATCH 381/3750 - LOSS: 1.9459574222564697
	BATCH 401/3750 - LOSS: 1.9408681392669678
	BATCH 421/3750 - LOSS: 1.901

KeyboardInterrupt: ignored