In [1]:
# Install required packages
!pip install linformer
!pip install vit_pytorch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
import torch.utils.data as data
import torchvision
from torchvision import transforms



In [2]:
# Check if CUDA is available
print(torch.cuda.is_available())

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Set seed for reproducibility
torch.manual_seed(142)

True


<torch._C.Generator at 0x7fa3fbe0bb30>

In [3]:
# Hyperparameters
batch_size = 64
epochs = 100
lr = 0.0001
gamma = 0.7
IMG_SIZE = 200
patch_size = 20
num_classes = 2

In [4]:
# Transforms for image resizing and normalization
'''transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor()
])'''
'''transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])'''

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor()
])


In [5]:
# Define paths
# train_dir = '/content/drive/MyDrive/dataset/training'
# val_dir = '/content/drive/MyDrive/dataset/test'
# test_dir = '/content/drive/MyDrive/dataset/validation'


train_dir = 'training_files/target_training_datasets/CHEMBL286/dataset/training'
val_dir = 'training_files/target_training_datasets/CHEMBL286/dataset/validation'
test_dir = 'training_files/target_training_datasets/CHEMBL286/dataset/test'


# Load datasets
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
valid_ds = torchvision.datasets.ImageFolder(val_dir, transform=transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=transform)

# Data loaders
train_loader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = data.DataLoader(valid_ds, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = data.DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=0)

# Linear Transformer
efficient_transformer = Linformer(dim=256, seq_len=(IMG_SIZE // patch_size) ** 2 + 1, depth=24, heads=16, k=128)

# Vision Transformer Model
model = ViT(
    dim=256,
    image_size=IMG_SIZE,
    patch_size=patch_size,
    num_classes=num_classes,
    transformer=efficient_transformer,
    channels=3,
).to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Learning rate scheduler
scheduler = StepLR(optimizer, step_size=10, gamma=gamma)

# Training loop
for epoch in range(epochs):
    model.train()
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data, label = data.to(device), label.to(device)
            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_accuracy:.4f}")

# Save the model
PATH = f"epochs_{epochs}img{IMG_SIZE}patch{patch_size}lr{lr}.pt"
torch.save(model.state_dict(), PATH)

# Load saved model
efficient_transformer = Linformer(dim=128, seq_len=(IMG_SIZE // patch_size) ** 2 + 1, depth=12, heads=8, k=64)
model = ViT(image_size=IMG_SIZE, patch_size=patch_size, num_classes=num_classes, dim=128, transformer=efficient_transformer, channels=3).to(device)
model.load_state_dict(torch.load(PATH))

# Function to calculate overall accuracy
def overall_accuracy(model, test_loader, criterion):
    model.eval()
    y_proba = []
    y_truth = []
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, label in tqdm(test_loader):
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += criterion(output, label.long()).item()
            for index, i in enumerate(output):
                y_proba.append(i[1].item())
                y_truth.append(label[index].item())
                if torch.argmax(i) == label[index]:
                    correct += 1
                total += 1
    accuracy = correct / total
    y_proba_out = np.array(y_proba)
    y_truth_out = np.array(y_truth)
    return test_loss, accuracy, y_proba_out, y_truth_out

# Evaluate model on test data
loss, acc, y_proba, y_truth = overall_accuracy(model, test_loader, criterion)

print(f"Test Accuracy: {acc:.4f}")

# Plot confusion matrix
cm = confusion_matrix(y_truth, np.argmax(y_proba.reshape(-1, 1), axis=1))
print(cm)

  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 1/100, Loss: 0.8855, Accuracy: 0.5063, Val Loss: 0.6849, Val Accuracy: 0.5878


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 2/100, Loss: 0.6736, Accuracy: 0.6023, Val Loss: 0.6809, Val Accuracy: 0.5913


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 3/100, Loss: 0.6742, Accuracy: 0.5996, Val Loss: 0.6778, Val Accuracy: 0.5948


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 4/100, Loss: 0.6711, Accuracy: 0.5979, Val Loss: 0.6715, Val Accuracy: 0.6001


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 5/100, Loss: 0.6691, Accuracy: 0.5979, Val Loss: 0.6622, Val Accuracy: 0.6501


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 6/100, Loss: 0.6544, Accuracy: 0.6196, Val Loss: 0.6527, Val Accuracy: 0.6345


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 7/100, Loss: 0.6441, Accuracy: 0.6408, Val Loss: 0.6614, Val Accuracy: 0.6585


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 8/100, Loss: 0.6142, Accuracy: 0.6739, Val Loss: 0.6404, Val Accuracy: 0.6307


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 9/100, Loss: 0.6023, Accuracy: 0.6943, Val Loss: 0.6452, Val Accuracy: 0.6386


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 10/100, Loss: 0.5911, Accuracy: 0.6877, Val Loss: 0.6436, Val Accuracy: 0.6241


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 11/100, Loss: 0.5804, Accuracy: 0.7046, Val Loss: 0.6213, Val Accuracy: 0.6745


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 12/100, Loss: 0.5504, Accuracy: 0.7257, Val Loss: 0.5967, Val Accuracy: 0.6800


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 13/100, Loss: 0.5393, Accuracy: 0.7404, Val Loss: 0.6341, Val Accuracy: 0.6727


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 14/100, Loss: 0.5179, Accuracy: 0.7412, Val Loss: 0.6357, Val Accuracy: 0.6776


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 15/100, Loss: 0.4870, Accuracy: 0.7858, Val Loss: 0.5969, Val Accuracy: 0.6970


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 16/100, Loss: 0.4608, Accuracy: 0.7939, Val Loss: 0.5564, Val Accuracy: 0.7422


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 17/100, Loss: 0.4286, Accuracy: 0.8016, Val Loss: 0.6341, Val Accuracy: 0.6727


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 18/100, Loss: 0.4105, Accuracy: 0.8114, Val Loss: 0.6516, Val Accuracy: 0.6928


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 19/100, Loss: 0.3745, Accuracy: 0.8393, Val Loss: 0.6123, Val Accuracy: 0.7085


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 20/100, Loss: 0.3625, Accuracy: 0.8427, Val Loss: 0.6841, Val Accuracy: 0.7085


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 21/100, Loss: 0.3673, Accuracy: 0.8483, Val Loss: 0.7121, Val Accuracy: 0.6553


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 22/100, Loss: 0.3177, Accuracy: 0.8680, Val Loss: 0.7095, Val Accuracy: 0.7210


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 23/100, Loss: 0.2862, Accuracy: 0.8832, Val Loss: 0.7353, Val Accuracy: 0.6970


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 24/100, Loss: 0.2601, Accuracy: 0.8901, Val Loss: 0.7771, Val Accuracy: 0.6956


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 25/100, Loss: 0.2787, Accuracy: 0.8905, Val Loss: 0.7199, Val Accuracy: 0.7227


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 26/100, Loss: 0.2708, Accuracy: 0.8987, Val Loss: 0.7350, Val Accuracy: 0.6868


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 27/100, Loss: 0.2261, Accuracy: 0.9135, Val Loss: 0.8822, Val Accuracy: 0.6893


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 28/100, Loss: 0.2220, Accuracy: 0.9168, Val Loss: 0.7347, Val Accuracy: 0.7063


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 29/100, Loss: 0.2111, Accuracy: 0.9177, Val Loss: 0.7894, Val Accuracy: 0.7022


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 30/100, Loss: 0.1919, Accuracy: 0.9341, Val Loss: 0.8801, Val Accuracy: 0.6803


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 31/100, Loss: 0.1800, Accuracy: 0.9284, Val Loss: 0.7658, Val Accuracy: 0.6905


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 32/100, Loss: 0.1357, Accuracy: 0.9522, Val Loss: 1.1183, Val Accuracy: 0.6995


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 33/100, Loss: 0.1237, Accuracy: 0.9481, Val Loss: 0.9919, Val Accuracy: 0.7227


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 34/100, Loss: 0.1521, Accuracy: 0.9473, Val Loss: 0.8475, Val Accuracy: 0.7530


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 35/100, Loss: 0.1150, Accuracy: 0.9579, Val Loss: 0.8856, Val Accuracy: 0.7745


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 36/100, Loss: 0.1068, Accuracy: 0.9588, Val Loss: 1.0477, Val Accuracy: 0.7227


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 37/100, Loss: 0.0946, Accuracy: 0.9696, Val Loss: 0.9141, Val Accuracy: 0.7147


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 38/100, Loss: 0.0918, Accuracy: 0.9687, Val Loss: 0.9542, Val Accuracy: 0.7202


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 39/100, Loss: 0.0769, Accuracy: 0.9745, Val Loss: 1.1553, Val Accuracy: 0.7335


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 40/100, Loss: 0.0801, Accuracy: 0.9695, Val Loss: 1.0263, Val Accuracy: 0.7739


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 41/100, Loss: 0.0700, Accuracy: 0.9794, Val Loss: 1.1983, Val Accuracy: 0.7217


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 42/100, Loss: 0.0690, Accuracy: 0.9786, Val Loss: 1.1625, Val Accuracy: 0.7001


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 43/100, Loss: 0.0675, Accuracy: 0.9794, Val Loss: 1.2111, Val Accuracy: 0.7178


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 44/100, Loss: 0.0729, Accuracy: 0.9712, Val Loss: 1.1931, Val Accuracy: 0.7079


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 45/100, Loss: 0.0490, Accuracy: 0.9786, Val Loss: 1.4385, Val Accuracy: 0.6768


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 46/100, Loss: 0.0725, Accuracy: 0.9786, Val Loss: 1.2164, Val Accuracy: 0.7380


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 47/100, Loss: 0.0528, Accuracy: 0.9810, Val Loss: 1.2816, Val Accuracy: 0.7143


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 48/100, Loss: 0.0717, Accuracy: 0.9778, Val Loss: 0.9952, Val Accuracy: 0.7418


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 49/100, Loss: 0.0670, Accuracy: 0.9753, Val Loss: 1.1870, Val Accuracy: 0.7088


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 50/100, Loss: 0.0484, Accuracy: 0.9802, Val Loss: 1.2624, Val Accuracy: 0.7171


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 51/100, Loss: 0.0654, Accuracy: 0.9778, Val Loss: 1.0385, Val Accuracy: 0.7321


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 52/100, Loss: 0.0485, Accuracy: 0.9836, Val Loss: 1.0292, Val Accuracy: 0.7130


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 53/100, Loss: 0.0348, Accuracy: 0.9893, Val Loss: 1.4026, Val Accuracy: 0.7223


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 54/100, Loss: 0.0380, Accuracy: 0.9827, Val Loss: 1.2678, Val Accuracy: 0.7428


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 55/100, Loss: 0.0326, Accuracy: 0.9910, Val Loss: 1.2343, Val Accuracy: 0.7411


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 56/100, Loss: 0.0617, Accuracy: 0.9770, Val Loss: 1.0464, Val Accuracy: 0.7422


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 57/100, Loss: 0.0485, Accuracy: 0.9819, Val Loss: 1.3427, Val Accuracy: 0.6925


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 58/100, Loss: 0.0368, Accuracy: 0.9868, Val Loss: 1.0382, Val Accuracy: 0.7498


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 59/100, Loss: 0.0449, Accuracy: 0.9819, Val Loss: 1.1635, Val Accuracy: 0.7245


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 60/100, Loss: 0.0202, Accuracy: 0.9959, Val Loss: 1.4032, Val Accuracy: 0.7217


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 61/100, Loss: 0.0429, Accuracy: 0.9844, Val Loss: 1.2820, Val Accuracy: 0.7290


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 62/100, Loss: 0.0209, Accuracy: 0.9942, Val Loss: 1.2013, Val Accuracy: 0.7428


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 63/100, Loss: 0.0261, Accuracy: 0.9934, Val Loss: 1.6703, Val Accuracy: 0.6876


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 64/100, Loss: 0.0213, Accuracy: 0.9918, Val Loss: 1.5006, Val Accuracy: 0.7397


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 65/100, Loss: 0.0255, Accuracy: 0.9901, Val Loss: 1.5528, Val Accuracy: 0.7227


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 66/100, Loss: 0.0334, Accuracy: 0.9893, Val Loss: 1.3344, Val Accuracy: 0.7307


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 67/100, Loss: 0.0235, Accuracy: 0.9934, Val Loss: 1.2168, Val Accuracy: 0.7352


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 68/100, Loss: 0.0102, Accuracy: 0.9967, Val Loss: 1.4311, Val Accuracy: 0.7300


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 69/100, Loss: 0.0280, Accuracy: 0.9918, Val Loss: 1.2584, Val Accuracy: 0.7565


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 70/100, Loss: 0.0319, Accuracy: 0.9885, Val Loss: 1.3538, Val Accuracy: 0.7141


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 71/100, Loss: 0.0252, Accuracy: 0.9910, Val Loss: 1.5692, Val Accuracy: 0.7116


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 72/100, Loss: 0.0253, Accuracy: 0.9909, Val Loss: 1.3583, Val Accuracy: 0.7067


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 73/100, Loss: 0.0226, Accuracy: 0.9901, Val Loss: 1.4313, Val Accuracy: 0.7255


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 74/100, Loss: 0.0271, Accuracy: 0.9909, Val Loss: 1.4540, Val Accuracy: 0.7352


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 75/100, Loss: 0.0187, Accuracy: 0.9951, Val Loss: 1.5186, Val Accuracy: 0.7231


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 76/100, Loss: 0.0242, Accuracy: 0.9934, Val Loss: 1.2111, Val Accuracy: 0.7561


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 77/100, Loss: 0.0207, Accuracy: 0.9934, Val Loss: 1.4165, Val Accuracy: 0.7293


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 78/100, Loss: 0.0342, Accuracy: 0.9868, Val Loss: 1.5540, Val Accuracy: 0.6934


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 79/100, Loss: 0.0397, Accuracy: 0.9885, Val Loss: 1.1613, Val Accuracy: 0.7248


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 80/100, Loss: 0.0237, Accuracy: 0.9934, Val Loss: 1.2811, Val Accuracy: 0.7405


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 81/100, Loss: 0.0204, Accuracy: 0.9926, Val Loss: 1.3760, Val Accuracy: 0.7311


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 82/100, Loss: 0.0174, Accuracy: 0.9951, Val Loss: 1.5276, Val Accuracy: 0.7088


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 83/100, Loss: 0.0187, Accuracy: 0.9951, Val Loss: 1.3543, Val Accuracy: 0.7418


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 84/100, Loss: 0.0130, Accuracy: 0.9975, Val Loss: 1.4646, Val Accuracy: 0.7303


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 85/100, Loss: 0.0053, Accuracy: 0.9984, Val Loss: 1.3651, Val Accuracy: 0.7428


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 86/100, Loss: 0.0124, Accuracy: 0.9967, Val Loss: 1.4649, Val Accuracy: 0.7512


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 87/100, Loss: 0.0172, Accuracy: 0.9967, Val Loss: 1.5958, Val Accuracy: 0.7428


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 88/100, Loss: 0.0136, Accuracy: 0.9942, Val Loss: 1.5323, Val Accuracy: 0.7245


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 89/100, Loss: 0.0144, Accuracy: 0.9951, Val Loss: 1.3879, Val Accuracy: 0.7473


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 90/100, Loss: 0.0176, Accuracy: 0.9950, Val Loss: 1.3035, Val Accuracy: 0.7571


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 91/100, Loss: 0.0134, Accuracy: 0.9959, Val Loss: 1.5051, Val Accuracy: 0.7213


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 92/100, Loss: 0.0272, Accuracy: 0.9934, Val Loss: 1.3253, Val Accuracy: 0.7547


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 93/100, Loss: 0.0163, Accuracy: 0.9934, Val Loss: 1.1967, Val Accuracy: 0.7686


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 94/100, Loss: 0.0245, Accuracy: 0.9926, Val Loss: 1.3074, Val Accuracy: 0.7516


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 95/100, Loss: 0.0124, Accuracy: 0.9967, Val Loss: 1.2428, Val Accuracy: 0.7356


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 96/100, Loss: 0.0149, Accuracy: 0.9951, Val Loss: 1.1836, Val Accuracy: 0.7512


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 97/100, Loss: 0.0155, Accuracy: 0.9942, Val Loss: 1.3483, Val Accuracy: 0.7409


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 98/100, Loss: 0.0084, Accuracy: 0.9975, Val Loss: 1.4181, Val Accuracy: 0.7272


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 99/100, Loss: 0.0136, Accuracy: 0.9958, Val Loss: 1.4238, Val Accuracy: 0.7227


  0%|          | 0/19 [00:00<?, ?it/s]

Epoch 100/100, Loss: 0.0176, Accuracy: 0.9942, Val Loss: 1.5795, Val Accuracy: 0.7241


RuntimeError: Error(s) in loading state_dict for ViT:
	Unexpected key(s) in state_dict: "transformer.net.layers.12.0.fn.proj_k", "transformer.net.layers.12.0.fn.proj_v", "transformer.net.layers.12.0.fn.to_q.weight", "transformer.net.layers.12.0.fn.to_k.weight", "transformer.net.layers.12.0.fn.to_v.weight", "transformer.net.layers.12.0.fn.to_out.weight", "transformer.net.layers.12.0.fn.to_out.bias", "transformer.net.layers.12.0.norm.weight", "transformer.net.layers.12.0.norm.bias", "transformer.net.layers.12.1.fn.w1.weight", "transformer.net.layers.12.1.fn.w1.bias", "transformer.net.layers.12.1.fn.w2.weight", "transformer.net.layers.12.1.fn.w2.bias", "transformer.net.layers.12.1.norm.weight", "transformer.net.layers.12.1.norm.bias", "transformer.net.layers.13.0.fn.proj_k", "transformer.net.layers.13.0.fn.proj_v", "transformer.net.layers.13.0.fn.to_q.weight", "transformer.net.layers.13.0.fn.to_k.weight", "transformer.net.layers.13.0.fn.to_v.weight", "transformer.net.layers.13.0.fn.to_out.weight", "transformer.net.layers.13.0.fn.to_out.bias", "transformer.net.layers.13.0.norm.weight", "transformer.net.layers.13.0.norm.bias", "transformer.net.layers.13.1.fn.w1.weight", "transformer.net.layers.13.1.fn.w1.bias", "transformer.net.layers.13.1.fn.w2.weight", "transformer.net.layers.13.1.fn.w2.bias", "transformer.net.layers.13.1.norm.weight", "transformer.net.layers.13.1.norm.bias", "transformer.net.layers.14.0.fn.proj_k", "transformer.net.layers.14.0.fn.proj_v", "transformer.net.layers.14.0.fn.to_q.weight", "transformer.net.layers.14.0.fn.to_k.weight", "transformer.net.layers.14.0.fn.to_v.weight", "transformer.net.layers.14.0.fn.to_out.weight", "transformer.net.layers.14.0.fn.to_out.bias", "transformer.net.layers.14.0.norm.weight", "transformer.net.layers.14.0.norm.bias", "transformer.net.layers.14.1.fn.w1.weight", "transformer.net.layers.14.1.fn.w1.bias", "transformer.net.layers.14.1.fn.w2.weight", "transformer.net.layers.14.1.fn.w2.bias", "transformer.net.layers.14.1.norm.weight", "transformer.net.layers.14.1.norm.bias", "transformer.net.layers.15.0.fn.proj_k", "transformer.net.layers.15.0.fn.proj_v", "transformer.net.layers.15.0.fn.to_q.weight", "transformer.net.layers.15.0.fn.to_k.weight", "transformer.net.layers.15.0.fn.to_v.weight", "transformer.net.layers.15.0.fn.to_out.weight", "transformer.net.layers.15.0.fn.to_out.bias", "transformer.net.layers.15.0.norm.weight", "transformer.net.layers.15.0.norm.bias", "transformer.net.layers.15.1.fn.w1.weight", "transformer.net.layers.15.1.fn.w1.bias", "transformer.net.layers.15.1.fn.w2.weight", "transformer.net.layers.15.1.fn.w2.bias", "transformer.net.layers.15.1.norm.weight", "transformer.net.layers.15.1.norm.bias", "transformer.net.layers.16.0.fn.proj_k", "transformer.net.layers.16.0.fn.proj_v", "transformer.net.layers.16.0.fn.to_q.weight", "transformer.net.layers.16.0.fn.to_k.weight", "transformer.net.layers.16.0.fn.to_v.weight", "transformer.net.layers.16.0.fn.to_out.weight", "transformer.net.layers.16.0.fn.to_out.bias", "transformer.net.layers.16.0.norm.weight", "transformer.net.layers.16.0.norm.bias", "transformer.net.layers.16.1.fn.w1.weight", "transformer.net.layers.16.1.fn.w1.bias", "transformer.net.layers.16.1.fn.w2.weight", "transformer.net.layers.16.1.fn.w2.bias", "transformer.net.layers.16.1.norm.weight", "transformer.net.layers.16.1.norm.bias", "transformer.net.layers.17.0.fn.proj_k", "transformer.net.layers.17.0.fn.proj_v", "transformer.net.layers.17.0.fn.to_q.weight", "transformer.net.layers.17.0.fn.to_k.weight", "transformer.net.layers.17.0.fn.to_v.weight", "transformer.net.layers.17.0.fn.to_out.weight", "transformer.net.layers.17.0.fn.to_out.bias", "transformer.net.layers.17.0.norm.weight", "transformer.net.layers.17.0.norm.bias", "transformer.net.layers.17.1.fn.w1.weight", "transformer.net.layers.17.1.fn.w1.bias", "transformer.net.layers.17.1.fn.w2.weight", "transformer.net.layers.17.1.fn.w2.bias", "transformer.net.layers.17.1.norm.weight", "transformer.net.layers.17.1.norm.bias", "transformer.net.layers.18.0.fn.proj_k", "transformer.net.layers.18.0.fn.proj_v", "transformer.net.layers.18.0.fn.to_q.weight", "transformer.net.layers.18.0.fn.to_k.weight", "transformer.net.layers.18.0.fn.to_v.weight", "transformer.net.layers.18.0.fn.to_out.weight", "transformer.net.layers.18.0.fn.to_out.bias", "transformer.net.layers.18.0.norm.weight", "transformer.net.layers.18.0.norm.bias", "transformer.net.layers.18.1.fn.w1.weight", "transformer.net.layers.18.1.fn.w1.bias", "transformer.net.layers.18.1.fn.w2.weight", "transformer.net.layers.18.1.fn.w2.bias", "transformer.net.layers.18.1.norm.weight", "transformer.net.layers.18.1.norm.bias", "transformer.net.layers.19.0.fn.proj_k", "transformer.net.layers.19.0.fn.proj_v", "transformer.net.layers.19.0.fn.to_q.weight", "transformer.net.layers.19.0.fn.to_k.weight", "transformer.net.layers.19.0.fn.to_v.weight", "transformer.net.layers.19.0.fn.to_out.weight", "transformer.net.layers.19.0.fn.to_out.bias", "transformer.net.layers.19.0.norm.weight", "transformer.net.layers.19.0.norm.bias", "transformer.net.layers.19.1.fn.w1.weight", "transformer.net.layers.19.1.fn.w1.bias", "transformer.net.layers.19.1.fn.w2.weight", "transformer.net.layers.19.1.fn.w2.bias", "transformer.net.layers.19.1.norm.weight", "transformer.net.layers.19.1.norm.bias", "transformer.net.layers.20.0.fn.proj_k", "transformer.net.layers.20.0.fn.proj_v", "transformer.net.layers.20.0.fn.to_q.weight", "transformer.net.layers.20.0.fn.to_k.weight", "transformer.net.layers.20.0.fn.to_v.weight", "transformer.net.layers.20.0.fn.to_out.weight", "transformer.net.layers.20.0.fn.to_out.bias", "transformer.net.layers.20.0.norm.weight", "transformer.net.layers.20.0.norm.bias", "transformer.net.layers.20.1.fn.w1.weight", "transformer.net.layers.20.1.fn.w1.bias", "transformer.net.layers.20.1.fn.w2.weight", "transformer.net.layers.20.1.fn.w2.bias", "transformer.net.layers.20.1.norm.weight", "transformer.net.layers.20.1.norm.bias", "transformer.net.layers.21.0.fn.proj_k", "transformer.net.layers.21.0.fn.proj_v", "transformer.net.layers.21.0.fn.to_q.weight", "transformer.net.layers.21.0.fn.to_k.weight", "transformer.net.layers.21.0.fn.to_v.weight", "transformer.net.layers.21.0.fn.to_out.weight", "transformer.net.layers.21.0.fn.to_out.bias", "transformer.net.layers.21.0.norm.weight", "transformer.net.layers.21.0.norm.bias", "transformer.net.layers.21.1.fn.w1.weight", "transformer.net.layers.21.1.fn.w1.bias", "transformer.net.layers.21.1.fn.w2.weight", "transformer.net.layers.21.1.fn.w2.bias", "transformer.net.layers.21.1.norm.weight", "transformer.net.layers.21.1.norm.bias", "transformer.net.layers.22.0.fn.proj_k", "transformer.net.layers.22.0.fn.proj_v", "transformer.net.layers.22.0.fn.to_q.weight", "transformer.net.layers.22.0.fn.to_k.weight", "transformer.net.layers.22.0.fn.to_v.weight", "transformer.net.layers.22.0.fn.to_out.weight", "transformer.net.layers.22.0.fn.to_out.bias", "transformer.net.layers.22.0.norm.weight", "transformer.net.layers.22.0.norm.bias", "transformer.net.layers.22.1.fn.w1.weight", "transformer.net.layers.22.1.fn.w1.bias", "transformer.net.layers.22.1.fn.w2.weight", "transformer.net.layers.22.1.fn.w2.bias", "transformer.net.layers.22.1.norm.weight", "transformer.net.layers.22.1.norm.bias", "transformer.net.layers.23.0.fn.proj_k", "transformer.net.layers.23.0.fn.proj_v", "transformer.net.layers.23.0.fn.to_q.weight", "transformer.net.layers.23.0.fn.to_k.weight", "transformer.net.layers.23.0.fn.to_v.weight", "transformer.net.layers.23.0.fn.to_out.weight", "transformer.net.layers.23.0.fn.to_out.bias", "transformer.net.layers.23.0.norm.weight", "transformer.net.layers.23.0.norm.bias", "transformer.net.layers.23.1.fn.w1.weight", "transformer.net.layers.23.1.fn.w1.bias", "transformer.net.layers.23.1.fn.w2.weight", "transformer.net.layers.23.1.fn.w2.bias", "transformer.net.layers.23.1.norm.weight", "transformer.net.layers.23.1.norm.bias". 
	size mismatch for pos_embedding: copying a param with shape torch.Size([1, 101, 256]) from checkpoint, the shape in current model is torch.Size([1, 101, 128]).
	size mismatch for cls_token: copying a param with shape torch.Size([1, 1, 256]) from checkpoint, the shape in current model is torch.Size([1, 1, 128]).
	size mismatch for to_patch_embedding.2.weight: copying a param with shape torch.Size([256, 1200]) from checkpoint, the shape in current model is torch.Size([128, 1200]).
	size mismatch for to_patch_embedding.2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for to_patch_embedding.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for to_patch_embedding.3.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.0.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.0.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.0.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.0.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.0.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.0.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.0.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.0.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.0.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.0.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.1.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.1.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.1.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.1.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.1.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.1.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.1.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.1.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.1.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.1.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.2.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.2.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.2.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.2.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.2.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.2.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.2.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.2.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.2.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.2.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.3.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.3.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.3.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.3.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.3.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.3.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.3.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.3.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.3.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.3.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.4.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.4.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.4.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.4.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.4.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.4.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.4.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.4.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.4.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.4.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.5.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.5.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.5.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.5.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.5.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.5.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.5.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.5.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.5.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.5.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.6.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.6.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.6.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.6.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.6.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.6.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.6.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.6.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.6.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.6.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.7.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.7.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.7.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.7.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.7.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.7.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.7.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.7.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.7.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.7.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.8.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.8.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.8.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.8.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.8.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.8.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.8.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.8.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.8.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.8.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.9.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.9.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.9.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.9.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.9.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.9.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.9.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.9.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.9.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.9.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.10.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.10.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.10.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.10.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.10.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.10.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.10.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.10.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.10.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.10.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.0.fn.proj_k: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.11.0.fn.proj_v: copying a param with shape torch.Size([101, 128]) from checkpoint, the shape in current model is torch.Size([101, 64]).
	size mismatch for transformer.net.layers.11.0.fn.to_q.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.11.0.fn.to_k.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.11.0.fn.to_v.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.11.0.fn.to_out.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([128, 128]).
	size mismatch for transformer.net.layers.11.0.fn.to_out.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.0.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.0.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.1.fn.w1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for transformer.net.layers.11.1.fn.w1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for transformer.net.layers.11.1.fn.w2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([128, 512]).
	size mismatch for transformer.net.layers.11.1.fn.w2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.1.norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for transformer.net.layers.11.1.norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mlp_head.0.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mlp_head.0.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for mlp_head.1.weight: copying a param with shape torch.Size([2, 256]) from checkpoint, the shape in current model is torch.Size([2, 128]).

In [6]:
# Function to calculate overall accuracy
def overall_accuracy(model, test_loader, criterion):
    model.eval()
    y_proba = []
    y_truth = []
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for data, label in tqdm(test_loader):
            data, label = data.to(device), label.to(device)
            output = model(data)
            test_loss += criterion(output, label.long()).item()
            for index, i in enumerate(output):
                y_proba.append(i[1].item())
                y_truth.append(label[index].item())
                if torch.argmax(i) == label[index]:
                    correct += 1
                total += 1
    accuracy = correct / total
    y_proba_out = np.array(y_proba)
    y_truth_out = np.array(y_truth)
    return test_loss, accuracy, y_proba_out, y_truth_out

# Evaluate model on test data
loss, acc, y_proba, y_truth = overall_accuracy(model, test_loader, criterion)

print(f"Test Accuracy: {acc:.4f}")


  0%|          | 0/6 [00:00<?, ?it/s]

Test Accuracy: 0.4005
