In [None]:
# installing libraries
# !pip install torch
# !pip install torchvision
# !pip install timm

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.style.use("ggplot")

import torch
import torch.nn as nn

import timm

import gc
import os
import time
import random
from datetime import datetime

from PIL import Image
from tqdm.notebook import tqdm
from sklearn import model_selection, metrics


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

## CIFAR Final - 40 eps

In [3]:
MODEL_PATH = ("jx_vit_base_p16_224-80ecf9dd.pth")

In [4]:
# For parallelization in TPUs
os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

In [5]:
IMG_SIZE = 224
BATCH_SIZE = 16
LR = 0.01
N_EPOCHS = 20

In [6]:
import torchvision
from torchvision.transforms import ToTensor
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

tt = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        #transforms.Resize(224),
        transforms.ToTensor()
    ]
)

train_set = datasets.CIFAR10(root='./../datasets', train=True, download=True, transform=tt)
test_set = datasets.CIFAR10(root='./../datasets', train=False, download=True, transform=tt)

train_loader = DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, BATCH_SIZE*2, num_workers=4, pin_memory=True)


Files already downloaded and verified
Files already downloaded and verified


In [7]:
class ViTBase16(nn.Module):
    def __init__(self, n_classes, pretrained=False):

        super(ViTBase16, self).__init__()

        self.model = timm.create_model("vit_base_patch16_224", pretrained=False)
        if pretrained:
            self.model.load_state_dict(torch.load(MODEL_PATH))

        self.model.head = nn.Linear(self.model.head.in_features, n_classes)

    def forward(self, x):
        x = self.model(x)
        return x

    def train_one_epoch(self, train_loader, criterion, optimizer, device):
        # keep track of training loss
        epoch_loss = 0.0
        epoch_accuracy = 0.0

        ###################
        # train the model #
        ###################
        self.model.train()
        for i, (data, target) in enumerate(train_loader):
            # move tensors to GPU if CUDA is available
            if device.type == "cuda":
                data, target = data.cuda(), target.cuda()

            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = self.forward(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # Calculate Accuracy
            accuracy = (output.argmax(dim=1) == target).float().mean()
            # update training loss and accuracy
            epoch_loss += loss
            epoch_accuracy += accuracy
            optimizer.step()

            if i % 20 == 0:
                    print(f"\tBATCH {i+1}/{len(train_loader)} - LOSS: {loss}")
                

        return epoch_loss / len(train_loader), epoch_accuracy / len(train_loader)

    def validate_one_epoch(self, test_loader, criterion, device):
        # keep track of validation loss
        valid_loss = 0.0
        valid_accuracy = 0.0

        ######################
        # validate the model #
        ######################
        self.model.eval()
        for data, target in test_loader:
            # move tensors to GPU if CUDA is available
            if device.type == "cuda":
                data, target = data.cuda(), target.cuda()

            with torch.no_grad():
                # forward pass: compute predicted outputs by passing inputs to the model
                output = self.model(data)
                # calculate the batch loss
                loss = criterion(output, target)
                # Calculate Accuracy
                accuracy = (output.argmax(dim=1) == target).float().mean()
                # update average validation loss and accuracy
                valid_loss += loss
                valid_accuracy += accuracy

        return valid_loss / len(test_loader), valid_accuracy / len(test_loader)


In [8]:
def fit_tpu(
    model, epochs, device, criterion, optimizer, train_loader, test_loader=None
):

    valid_loss_min = np.Inf  # track change in validation loss

    # keeping track of losses as it happen
    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    for epoch in range(1, epochs + 1):
        gc.collect()

        print(f"{'='*50}")
        print(f"EPOCH {epoch} - TRAINING...")
        train_loss, train_acc = model.train_one_epoch(
            train_loader, criterion, optimizer, device
        )
        print(
            f"\n\t[TRAIN] EPOCH {epoch} - LOSS: {train_loss}, ACCURACY: {train_acc}\n"
        )
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        gc.collect()

        if test_loader is not None:
            gc.collect()
            print(f"EPOCH {epoch} - VALIDATING...")
            valid_loss, valid_acc = model.validate_one_epoch(
                test_loader, criterion, device
            )
            print(f"\t[VALID] LOSS: {valid_loss}, ACCURACY: {valid_acc}\n")
            valid_losses.append(valid_loss)
            valid_accs.append(valid_acc)
            gc.collect()

            # save model if validation loss has decreased
            if valid_loss <= valid_loss_min and epoch != 1:
                print(
                    "Validation loss decreased ({:.4f} --> {:.4f}).  Saving model ...".format(
                        valid_loss_min, valid_loss
                    )
                )
            #                 xm.save(model.state_dict(), 'best_model.pth')

            valid_loss_min = valid_loss

    return {
        "train_loss": train_losses,
        "valid_losses": valid_losses,
        "train_acc": train_accs,
        "valid_acc": valid_accs,
    }

In [9]:
model = ViTBase16(n_classes=10, pretrained=True)

In [None]:
criterion = nn.CrossEntropyLoss()
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
model.to(device)

lr = LR 
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

print(f"INITIALIZING TRAINING")
start_time = datetime.now()
print(f"Start Time: {start_time}")

logs = fit_tpu(
        model=model,
        epochs=N_EPOCHS,
        device=device,
        criterion=criterion,
        optimizer=optimizer,
        train_loader=train_loader,
        test_loader=test_loader,
    )

print(f"Execution time: {datetime.now() - start_time}")

INITIALIZING TRAINING
Start Time: 2022-05-12 18:43:24.471594
EPOCH 1 - TRAINING...
	BATCH 1/3125 - LOSS: 2.126641273498535
	BATCH 21/3125 - LOSS: 3.1973066329956055
	BATCH 41/3125 - LOSS: 2.4380807876586914
	BATCH 61/3125 - LOSS: 2.471529006958008
	BATCH 81/3125 - LOSS: 2.3762385845184326
	BATCH 101/3125 - LOSS: 2.4663960933685303
	BATCH 121/3125 - LOSS: 2.5387799739837646
	BATCH 141/3125 - LOSS: 2.404280662536621
	BATCH 161/3125 - LOSS: 2.2756927013397217
	BATCH 181/3125 - LOSS: 2.358534097671509
	BATCH 201/3125 - LOSS: 2.365492820739746
	BATCH 221/3125 - LOSS: 2.3388683795928955
	BATCH 241/3125 - LOSS: 2.387812614440918
	BATCH 261/3125 - LOSS: 2.3788771629333496
	BATCH 281/3125 - LOSS: 2.260262966156006
	BATCH 301/3125 - LOSS: 2.431842565536499
	BATCH 321/3125 - LOSS: 2.253547191619873
	BATCH 341/3125 - LOSS: 2.2058823108673096
	BATCH 361/3125 - LOSS: 2.349489688873291
	BATCH 381/3125 - LOSS: 2.456937313079834
	BATCH 401/3125 - LOSS: 2.3546061515808105
	BATCH 421/3125 - LOSS: 2.25374



	[VALID] LOSS: 2.1372365951538086, ACCURACY: 0.19748401641845703

EPOCH 2 - TRAINING...




	BATCH 1/3125 - LOSS: 1.9689531326293945
	BATCH 21/3125 - LOSS: 2.3453121185302734
	BATCH 41/3125 - LOSS: 2.1453938484191895
	BATCH 61/3125 - LOSS: 2.023744583129883
	BATCH 81/3125 - LOSS: 2.1231613159179688
	BATCH 101/3125 - LOSS: 1.9703654050827026
	BATCH 121/3125 - LOSS: 2.054694414138794
	BATCH 141/3125 - LOSS: 2.1391022205352783
	BATCH 161/3125 - LOSS: 2.415559768676758
	BATCH 181/3125 - LOSS: 2.1005570888519287
	BATCH 201/3125 - LOSS: 2.204759120941162
	BATCH 221/3125 - LOSS: 2.247696876525879
	BATCH 241/3125 - LOSS: 2.385873556137085
	BATCH 261/3125 - LOSS: 2.1997921466827393
	BATCH 281/3125 - LOSS: 2.2351326942443848
	BATCH 301/3125 - LOSS: 2.2362613677978516
	BATCH 321/3125 - LOSS: 2.0449774265289307
	BATCH 341/3125 - LOSS: 2.38157320022583
	BATCH 361/3125 - LOSS: 2.2367632389068604
	BATCH 381/3125 - LOSS: 2.501051902770996
	BATCH 401/3125 - LOSS: 2.1639344692230225
	BATCH 421/3125 - LOSS: 2.0467679500579834
	BATCH 441/3125 - LOSS: 2.3104918003082275
	BATCH 461/3125 - LOSS: 2.




	[TRAIN] EPOCH 2 - LOSS: 2.188119649887085, ACCURACY: 0.1722399890422821

EPOCH 2 - VALIDATING...




	[VALID] LOSS: 2.171830415725708, ACCURACY: 0.1677316278219223

EPOCH 3 - TRAINING...




	BATCH 1/3125 - LOSS: 2.3745546340942383
	BATCH 21/3125 - LOSS: 2.244563102722168
	BATCH 41/3125 - LOSS: 2.336925983428955
	BATCH 61/3125 - LOSS: 2.2836453914642334
	BATCH 81/3125 - LOSS: 2.2522213459014893
	BATCH 101/3125 - LOSS: 2.144573450088501
	BATCH 121/3125 - LOSS: 2.227341651916504
	BATCH 141/3125 - LOSS: 2.185807704925537
	BATCH 161/3125 - LOSS: 2.1957695484161377
	BATCH 181/3125 - LOSS: 2.188368082046509
	BATCH 201/3125 - LOSS: 2.1474924087524414
	BATCH 221/3125 - LOSS: 2.1971352100372314
	BATCH 241/3125 - LOSS: 2.339203119277954
	BATCH 261/3125 - LOSS: 2.1915078163146973
	BATCH 281/3125 - LOSS: 2.0587217807769775
	BATCH 301/3125 - LOSS: 2.1432790756225586
	BATCH 321/3125 - LOSS: 2.223937749862671
	BATCH 341/3125 - LOSS: 2.2269089221954346
	BATCH 361/3125 - LOSS: 2.313445568084717
	BATCH 381/3125 - LOSS: 2.2066919803619385
	BATCH 401/3125 - LOSS: 2.0784616470336914
	BATCH 421/3125 - LOSS: 2.597271203994751
	BATCH 441/3125 - LOSS: 2.1830990314483643
	BATCH 461/3125 - LOSS: 2.1




	[TRAIN] EPOCH 3 - LOSS: 2.1714088916778564, ACCURACY: 0.17875999212265015

EPOCH 3 - VALIDATING...




	[VALID] LOSS: 2.112957000732422, ACCURACY: 0.1994808316230774

Validation loss decreased (2.1718 --> 2.1130).  Saving model ...
EPOCH 4 - TRAINING...




	BATCH 1/3125 - LOSS: 2.0368189811706543
	BATCH 21/3125 - LOSS: 2.089770793914795
	BATCH 41/3125 - LOSS: 2.070692300796509
	BATCH 61/3125 - LOSS: 2.0995466709136963
	BATCH 81/3125 - LOSS: 2.001901149749756
	BATCH 101/3125 - LOSS: 2.3299968242645264
	BATCH 121/3125 - LOSS: 2.2642364501953125
	BATCH 141/3125 - LOSS: 2.0891497135162354
	BATCH 161/3125 - LOSS: 2.029125452041626
	BATCH 181/3125 - LOSS: 2.315627098083496
	BATCH 201/3125 - LOSS: 1.9504667520523071
	BATCH 221/3125 - LOSS: 2.156087875366211
	BATCH 241/3125 - LOSS: 2.299520254135132
