In [1]:
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import make_grid

import matplotlib.pyplot as plt

from dataset import Dataset
from tools import getDataset, print_class_distribution

import numpy as np
from sklearn.model_selection import train_test_split

import random

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found.")
print('Device:', device)

Device: mps


In [3]:
def make_dataset_loaders(num_classes, batch_size, num_workers, image_size):
    root_dir = os.path.join(os.getcwd(), 'datasets/EuroSAT_RGB')
    dataset, label_mapping = getDataset(path=root_dir, num_classes=num_classes, num_images_per_class=20, shuffle_images=False)

    class_images = {i: [] for i in range(num_classes)}

    # Group images by class
    for image_path, class_label in dataset:
        class_images[class_label].append((image_path, class_label))

    # Initialize training and test sets
    train_set = []
    test_set = []

    # Select 5 images for training and 15 for testing from each class
    for class_label, images in class_images.items():
        train_set.extend(random.sample(images, k=5))
        test_set.extend(list(set(images) - set(train_set)))


    train_transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.RandomCrop(image_size, padding=8),
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
                transforms.GaussianBlur(3, sigma=(0.1, 2.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(degrees=5),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )

    test_transforms = transforms.Compose(
            [
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            ]
        )
        
    train_dataset = Dataset(dataset=train_set, path=root_dir, phase='train', transform=train_transforms)
    test_dataset = Dataset(dataset=test_set, path=root_dir, phase='test', transform=test_transforms)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print_class_distribution(train_dataset, "Training", label_mapping)
    print_class_distribution(test_dataset, "Testing", label_mapping)

    return train_loader, test_loader

In [4]:
def eval(net, data_loader, criterion=nn.CrossEntropyLoss()):
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()
    net.eval()
    correct = 0.0
    num_images = 0.0
    loss = 0.0
    for i_batch, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        outs = net(images)
        loss += criterion(outs, labels).item()
        _, predicted = torch.max(outs.data, 1)
        correct += (predicted == labels).sum().item()
        num_images += len(labels)
        print('testing -> batch: %d correct: %d num of images: %d' % (i_batch, correct, num_images) + '\r', end='')
    acc = correct / num_images
    loss /= len(data_loader)
    return acc, loss


# training function
def train(net, train_loader, test_loader, num_epochs, learning_rate):

    criterion = nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(params= net.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.0001, betas=(0.5, 0.999))
    scheduler = StepLR(optimizer, step_size=7, gamma=0.01)

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net = net.cuda()

    training_losses = []
    val_losses = []
    for epoch in range(num_epochs):
        net.train()
        correct = 0.0  # used to accumulate number of correctly recognized images
        num_images = 0.0  # used to accumulate number of images
        total_loss = 0.0

        for i_batch, (images, labels) in enumerate(train_loader):
            images, labels = images.to(device), labels.to(device)
            output_train = net(images)
            loss = criterion(output_train, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            predicts = output_train.argmax(dim=1)
            correct += predicts.eq(labels).sum().item()
            num_images += len(labels)
            total_loss += loss.item()

            print('training -> epoch: %d, batch: %d, loss: %f' % (epoch, i_batch, loss.item()) + '\r', end='')

        print()
        acc = correct / num_images
        acc_test, test_loss = eval(net, test_loader)
        average_loss = total_loss / len(train_loader)
        training_losses.append(average_loss)
        val_losses.append(test_loss)
        print('\nepoch: %d, lr: %f, accuracy: %f, avg. loss: %f, test accuracy: %f test loss: %f\n' % (epoch, optimizer.param_groups[0]['lr'], acc, average_loss, acc_test, test_loss))

        scheduler.step()

    return net, training_losses, val_losses

In [6]:
from models.ViT import ViT

batch_size = 5
num_workers = 4
learning_rate = 0.0001 # 0.00005 - the best one so far
num_epochs = 3
num_classes = 5
image_size = 224

print(f"Hyperparameters:")
print(f"Batch Size: {batch_size}")
print(f"Learning Rate: {learning_rate}")
print(f"Number of Epochs: {num_epochs}")
print(f"Number of Workers: {num_workers}")
print(f"Number of Classes: {num_classes}\n")


loaded_model = torch.load(os.path.join(os.getcwd(), 'pretrained/vit_model_best_90.pth'), map_location=device)
model = ViT(num_classes=50).to(device)

# Load the state dictionary into your model
model.load_state_dict(loaded_model.state_dict())

in_features = model.vit.head.in_features
model.vit.head = nn.Linear(in_features, num_classes).to(device)  

for param in model.parameters():
    param.requires_grad = True

for iter in range(4):
    print('--------------------------------------------------------------')
    print(f"Training Iteration: {iter}")
    train_loader, test_loader = make_dataset_loaders(num_classes=num_classes, batch_size=batch_size, num_workers=num_workers, image_size=image_size)
    model, training_losses, val_losses = train(net=model, train_loader=train_loader, test_loader=test_loader, num_epochs=num_epochs, learning_rate=learning_rate)

    acc_test, test_loss = eval(model, test_loader)
    print('\naccuracy on testing data: %f' % acc_test)
    print('loss on testing data: %f' % test_loss)
    print('--------------------------------------------------------------')
    
plt.plot(training_losses, label='Training Loss')
plt.xlabel('Epoch')
plt.plot(test_loss, label='Validation Loss')
plt.ylabel('Loss')
plt.legend()
plt.show()

Hyperparameters:
Batch Size: 5
Learning Rate: 0.0001
Number of Epochs: 3
Number of Workers: 4
Number of Classes: 5



Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth" to /Users/nipunwaas/.cache/torch/hub/checkpoints/vit_small_p16_224-15ec54c9.pth


RuntimeError: Error(s) in loading state_dict for ViT:
	Unexpected key(s) in state_dict: "vit.blocks.8.norm1.weight", "vit.blocks.8.norm1.bias", "vit.blocks.8.attn.qkv.weight", "vit.blocks.8.attn.qkv.bias", "vit.blocks.8.attn.proj.weight", "vit.blocks.8.attn.proj.bias", "vit.blocks.8.norm2.weight", "vit.blocks.8.norm2.bias", "vit.blocks.8.mlp.fc1.weight", "vit.blocks.8.mlp.fc1.bias", "vit.blocks.8.mlp.fc2.weight", "vit.blocks.8.mlp.fc2.bias", "vit.blocks.9.norm1.weight", "vit.blocks.9.norm1.bias", "vit.blocks.9.attn.qkv.weight", "vit.blocks.9.attn.qkv.bias", "vit.blocks.9.attn.proj.weight", "vit.blocks.9.attn.proj.bias", "vit.blocks.9.norm2.weight", "vit.blocks.9.norm2.bias", "vit.blocks.9.mlp.fc1.weight", "vit.blocks.9.mlp.fc1.bias", "vit.blocks.9.mlp.fc2.weight", "vit.blocks.9.mlp.fc2.bias", "vit.blocks.10.norm1.weight", "vit.blocks.10.norm1.bias", "vit.blocks.10.attn.qkv.weight", "vit.blocks.10.attn.qkv.bias", "vit.blocks.10.attn.proj.weight", "vit.blocks.10.attn.proj.bias", "vit.blocks.10.norm2.weight", "vit.blocks.10.norm2.bias", "vit.blocks.10.mlp.fc1.weight", "vit.blocks.10.mlp.fc1.bias", "vit.blocks.10.mlp.fc2.weight", "vit.blocks.10.mlp.fc2.bias", "vit.blocks.11.norm1.weight", "vit.blocks.11.norm1.bias", "vit.blocks.11.attn.qkv.weight", "vit.blocks.11.attn.qkv.bias", "vit.blocks.11.attn.proj.weight", "vit.blocks.11.attn.proj.bias", "vit.blocks.11.norm2.weight", "vit.blocks.11.norm2.bias", "vit.blocks.11.mlp.fc1.weight", "vit.blocks.11.mlp.fc1.bias", "vit.blocks.11.mlp.fc2.weight", "vit.blocks.11.mlp.fc2.bias", "vit.blocks.0.attn.qkv.bias", "vit.blocks.1.attn.qkv.bias", "vit.blocks.2.attn.qkv.bias", "vit.blocks.3.attn.qkv.bias", "vit.blocks.4.attn.qkv.bias", "vit.blocks.5.attn.qkv.bias", "vit.blocks.6.attn.qkv.bias", "vit.blocks.7.attn.qkv.bias". 
	size mismatch for vit.cls_token: copying a param with shape torch.Size([1, 1, 384]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
	size mismatch for vit.pos_embed: copying a param with shape torch.Size([1, 197, 384]) from checkpoint, the shape in current model is torch.Size([1, 197, 768]).
	size mismatch for vit.patch_embed.proj.weight: copying a param with shape torch.Size([384, 3, 16, 16]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16]).
	size mismatch for vit.patch_embed.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.0.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.0.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.1.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.1.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.1.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.1.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.1.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.1.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.2.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.2.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.2.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.2.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.2.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.2.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.3.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.3.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.3.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.3.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.3.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.3.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.4.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.4.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.4.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.4.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.4.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.4.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.5.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.5.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.5.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.5.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.5.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.5.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.6.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.6.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.6.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.6.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.6.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.6.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.norm1.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.norm1.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.attn.qkv.weight: copying a param with shape torch.Size([1152, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.7.attn.proj.weight: copying a param with shape torch.Size([384, 384]) from checkpoint, the shape in current model is torch.Size([768, 768]).
	size mismatch for vit.blocks.7.attn.proj.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.norm2.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.norm2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.blocks.7.mlp.fc1.weight: copying a param with shape torch.Size([1536, 384]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
	size mismatch for vit.blocks.7.mlp.fc1.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([2304]).
	size mismatch for vit.blocks.7.mlp.fc2.weight: copying a param with shape torch.Size([384, 1536]) from checkpoint, the shape in current model is torch.Size([768, 2304]).
	size mismatch for vit.blocks.7.mlp.fc2.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.norm.weight: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.norm.bias: copying a param with shape torch.Size([384]) from checkpoint, the shape in current model is torch.Size([768]).
	size mismatch for vit.head.weight: copying a param with shape torch.Size([50, 384]) from checkpoint, the shape in current model is torch.Size([50, 768]).

In [None]:
model.eval()  

for i_batch, (images, labels) in enumerate(test_loader):
    images, labels = images.to(device), labels.to(device)
    outs = model(images)
    _, predicted = torch.max(outs.data, 1)
    print(predicted)
    print(labels)
    _, predicted = torch.max(outs.data, 1)
    correct = (predicted == labels).sum().item()
    print(correct, len(labels))
    break

In [None]:
torch.save(model, os.path.join(os.getcwd(), 'pretrained/vit_trained.pth'))