**MOUNT GOOGLE DRIVE (FOR EXAMPLE, IF YOU WANT TO READ/WRITE MODEL WEIGHTS FROM/TO MyDrive):**

In [1]:
from google.colab import drive

# Directory under which Google Drive becomes visible once mounted.
drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


**Change the current working directory from `/content` to the folder that holds this notebook, `/content/drive/MyDrive/img_cls_to_vit`. This lets the relative paths that were hard-coded for the local-machine setup work unchanged in Colab:**

In [2]:
import os

# Work from the project folder on Drive so the notebook's relative paths resolve.
project_dir = '/content/drive/MyDrive/img_cls_to_vit'
os.chdir(project_dir)
# Last expression of the cell: echoes the directory now in effect.
os.getcwd()

'/content/drive/MyDrive/img_cls_to_vit'

**INSTALL ALLOWED LIBRARIES:**

In [3]:
from time import time
start = time()
!pip install torch
!pip install torchvision
!pip install pillow
!pip install tqdm
print(f'Pip installed torch, torchvision, pillow and tqdm in {round((time() - start)/60, 2)} mins')

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m64.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m66.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m96.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

On a TPU runtime, pip took 1365 seconds (about 23 minutes) to install torch, torchvision, pillow and tqdm, which is very slow.


On a T4 GPU runtime, the same pip install of torch, torchvision, pillow and tqdm took 1.27 minutes, which is much faster.

In [25]:
# FOR COLAB THIS GIVES: Ubuntu 22.04.3 LTS
# !cat /etc/*release

**IMPORT `CIFAR-10` IMAGE TRAINING DATASET,
TRANSFORM & LOAD TO DATALOADER & ITERATOR,
LOOK AT EXAMPLE OF A BATCH OF IMAGES**<br>
(The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. It is divided into 50,000 training images and 10,000 test images. The 10 classes are: plane, car, bird, cat, deer, dog, frog, horse, ship & truck.)

In [15]:
# CAN SKIP THIS CELL. IT'S JUST TO MAKE A PNG OF 20 EXAMPLE IMAGES
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
import torch
from PIL import Image

batch_size = 20
pretrained_transforms = tv_transforms.Compose([
    tv_transforms.Resize((224, 224)),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
trainset = tv_datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_transforms)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
dataiter = iter(trainloader)
images, labels = next(dataiter)  # note: older PyTorch releases used dataiter.next() instead

# Same per-channel statistics as the Normalize step above, shaped (3, 1, 1) so
# they broadcast over height and width.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Lay the batch out side by side: split into single images, concatenate along
# the width axis, then drop the leftover batch axis of size 1.
images_concat = torch.cat(images.split(1, 0), 3).squeeze(0)
# Undo the normalisation for all three channels at once, then clamp to [0, 1]
# so the conversion to 8-bit pixels cannot wrap around.
images_concat = torch.clamp(images_concat * std + mean, 0, 1)
# Channels-last uint8 array -> PIL image.
im = Image.fromarray((images_concat.permute(1, 2, 0).numpy() * 255).astype('uint8'))
example_png_path = "train_images_corrected.png"
im.save(example_png_path)
# Report the file that was actually written (the message used to name a different file).
print(f'{example_png_path} saved.')
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print('Ground truth labels:' + ' '.join('%5s' % classes[label] for label in labels.tolist()))


Files already downloaded and verified
train_images.png saved.
Ground truth labels: frog plane  frog horse   cat   dog  ship  deer  deer   cat   car horse  ship   car   cat horse  bird plane   dog truck


In [5]:
from time import time
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tv_transforms
import torchvision.datasets as tv_datasets
from PIL import Image
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using {device} device')

# # may be useful if using CPU
# import multiprocessing
# cpu_count_mp = multiprocessing.cpu_count()
# print(f'cpu count {cpu_count_mp}')
# TPU have a CPU count of 40 !

Using cuda device


### SET FLAG TO TRUE TO DO INFERENCE ONLY

In [29]:
# SET THIS FLAG TO TRUE IF YOU JUST WANT TO DO INFERENCE (AND NOT DO FINE-TUNING OF PRETRAINED VIT MODEL):
# load_finetuned_vit_for_inference_only = True
load_finetuned_vit_for_inference_only = False

# Use the preprocessing bundled with the pretrained weights in BOTH modes, so a
# model fine-tuned with these transforms is also evaluated with them.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
pretrained_transforms = pretrained_vit_weights.transforms()

if load_finetuned_vit_for_inference_only:
    print('You are loading a CIFAR-10-fine-tuned pretrained ViT for inference in testing loop only.')
    # Build the architecture without pretrained weights; the state dict below supplies them.
    pretrained_vit = torchvision.models.vit_b_16()
    pretrained_vit.heads = nn.Sequential(nn.Linear(in_features=768, out_features=10))
    # NOTE(review): the save cell writes to 'saved_models/pretrained_finetuned/sm_<n>/vit_finetuned.pt';
    # confirm that a checkpoint really exists at the path below.
    saved_model_path = 'saved_models/pretrained_finetuned/vit_finetuned.pt'
    # Map the checkpoint onto whichever device was detected instead of assuming CUDA,
    # so loading also works on CPU/MPS machines.
    # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
    pretrained_vit.load_state_dict(torch.load(saved_model_path, map_location=device))

else:  # OTHERWISE TO FINE-TUNE PRETRAINED VIT (ON CIFAR-10): FREEZE WEIGHTS AND THEN ADD TO HEAD:
    print('You are loading a pretrained ViT for fine-tuning in training loop (with inference as well in testing loop).')
    pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights)
    # FREEZE MODEL PARAMETERS TO PERFORM TRANSFER LEARNING (I.E FINE-TUNING):
    pretrained_vit.requires_grad_(False)
    # The replacement head is created after freezing, so it is the only trainable part.
    pretrained_vit.heads = nn.Sequential(nn.Linear(in_features=768, out_features=10))

# Assigning the result keeps the notebook from echoing the whole architecture.
pretrained_vit = pretrained_vit.to(device)

You are loading a pretrained ViT for fine-tuning in training loop (with inference as well in testing loop).
print here to prevent entire architecture print out from previous line


**TRAIN MODEL FOR 20 EPOCHS:**

In [None]:
# I PUT test_inference() function and MixUp class inside this cell for debugging purposes, but you can just fold them up
def test_inference(pretrained_vit, testloader, criterion):
    test_start = time()
    pretrained_vit.eval()
    test_loss_per_epoch, test_accuracy = 0, 0

    with torch.inference_mode():
        for i, data in enumerate(testloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            y_preds = pretrained_vit(inputs)
            # y_preds.shape is .  these are   logits.
            loss = criterion(y_preds, labels)

            test_pred_labels = y_preds.argmax(dim=1)
            test_accuracy += ((test_pred_labels == labels).sum().item()/len(test_pred_labels))
            test_loss_per_epoch += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, test_loss_per_epoch / 2000))
                test_loss_per_epoch = 0.0

    test_loss_per_epoch = test_loss_per_epoch / len(testloader)
    test_accuracy = test_accuracy / len(testloader)
    return test_loss_per_epoch, test_accuracy

# Evaluation pipeline: CIFAR-10 test split, batched in a fixed (unshuffled) order.
batch_size = 20
testset = tv_datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=pretrained_transforms,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)
# Module.to() returns the module itself, so creation and placement can be chained.
criterion = torch.nn.CrossEntropyLoss().to(device)

print(f'load_finetuned_vit_for_inference_only={load_finetuned_vit_for_inference_only}')

if load_finetuned_vit_for_inference_only:
    # INFERENCE (ON TEST-SET) ONLY:
    print('One-off inference only')
    test_loss, test_acc = test_inference(pretrained_vit, testloader, criterion)
    # Show the outcome; it used to be computed and then silently discarded.
    print(f'Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}')

else:
    # FINE-TUNE AND EVALUATE ON TEST SET FOR 20 EPOCHS:
    print('20 epochs of fine-tuning and evalutation on test set per epoch.')
    trainset = tv_datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_transforms)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    # Hand the optimizer only the parameters that can actually receive gradients.
    trainable_params = [p for p in pretrained_vit.parameters() if p.requires_grad]
    opt = torch.optim.Adam(trainable_params, lr=0.003)

    class MixUp(nn.Module):
        """MixUp augmentation: blends each sample (and its label) with another sample of the batch.

        Args:
            seed: seed for the private NumPy generator that draws the mixing
                coefficient. The generator is created once per object, so the
                coefficient varies from call to call while the overall sequence
                stays reproducible.
        """

        def __init__(self, seed=42):
            super().__init__()
            self._rng = np.random.default_rng(seed)

        def augment(self, device, X, y, batch_size, sampling_method, alpha=0.2, num_classes=10):
            """
            Mix a batch with a randomly permuted copy of itself.

            If sampling_method is 2: λ is sampled uniformly from [0, 1).
            For any other value (conventionally 1): λ is sampled from a
            Beta(alpha, alpha) distribution as described in Zhang et al 2018.

            "For mixup, we find that alpha ∈ [0.1, 0.4] leads to improved performance over ERM,
            whereas for large alpha, mixup leads to underfitting." Zhang et al.

            Args:
                device: device on which λ and the permutation index are placed.
                X: batch of inputs; the first dimension is the batch dimension.
                y: integer class labels, one per sample.
                batch_size: number of samples in X / y.
                sampling_method: 2 for uniform λ, anything else for beta λ.
                alpha: concentration parameter of the beta distribution.
                num_classes: width of the one-hot label encoding.

            Returns:
                (new_X, new_y): the blended inputs and the blended (soft) labels,
                new_y having shape (batch_size, num_classes).
            """
            # A fresh λ is drawn on every call. Re-seeding here would make every
            # batch use exactly the same mixing coefficient.
            if sampling_method == 2:
                lambda_ = self._rng.uniform(low=0.0, high=1.0)
            else:
                lambda_ = self._rng.beta(alpha, alpha)

            lam = torch.tensor(float(lambda_), device=device)

            # Partner sample for each position in the batch.
            random_i = torch.randperm(batch_size).to(device)
            X2 = X[random_i]
            y2 = y[random_i]

            # One-hot encode as floats so the labels can be blended like the inputs.
            y = F.one_hot(y, num_classes=num_classes) * 1.0
            y2 = F.one_hot(y2, num_classes=num_classes) * 1.0
            new_X = (lam * X) + ((1. - lam) * X2)
            new_y = (lam * y) + ((1. - lam) * y2)

            return new_X, new_y

    epochs = 20
    mix_up = MixUp()
    # MixUp.augment treats 2 as "uniform λ" and any other value as "beta λ",
    # so 1 selects the beta distribution. Set once, outside the batch loop.
    sampling_method = 1  # 1 FOR BETA
    # sampling_method = 2  # 2 FOR UNIFORM

    # Per-epoch metric buffers. Allocated once, before the epoch loop, so that the
    # values of every epoch (not just the last one) reach the CSV files below.
    train_losses = torch.zeros(epochs)
    test_losses = torch.zeros(epochs)
    train_accs = torch.zeros(epochs)
    test_accs = torch.zeros(epochs)

    start = time()
    for epoch in tqdm(range(epochs)):
        print(f'\nEpoch number {epoch}')
        epoch_start = time()
        running_loss, n_correct, n_seen = 0.0, 0, 0

        # 1. FINE-TUNE FOR ONE EPOCH:
        pretrained_vit.train()

        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            X, y_soft = mix_up.augment(device=device, X=inputs, y=labels, sampling_method=sampling_method, batch_size=inputs.shape[0])
            opt.zero_grad()
            y_preds = pretrained_vit(X)
            # Train against the blended (soft) labels: collapsing them with argmax
            # first would throw away the label mixing that MixUp is meant to provide.
            loss = criterion(y_preds, y_soft)
            loss.backward()
            opt.step()

            # Accuracy is measured against the dominant label of each mixed sample.
            y_preds_class = y_preds.argmax(dim=1)
            y_dominant = y_soft.argmax(dim=1)
            n_correct += (y_preds_class == y_dominant).sum().item()
            n_seen += len(y_preds)

            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                # Running mean over all batches seen so far in this epoch.
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1)))

        train_loss_per_epoch = running_loss / len(trainloader)
        train_accuracy = n_correct / n_seen
        print(f'train_accuracy {train_accuracy}')

        train_losses[epoch] = train_loss_per_epoch
        train_accs[epoch] = train_accuracy
        # Timed from the start of this epoch, not from the start of the whole run.
        print(f'Epoch {epoch}, training took {round(((time() - epoch_start) / 60), 4)} mins')

        # 2. EVALUATE ON TEST-SET AFTER EACH EPOCH OF FINE-TUNING:
        test_start = time()
        test_loss_per_epoch, test_accuracy = test_inference(pretrained_vit, testloader, criterion)

        test_losses[epoch] = test_loss_per_epoch
        test_accs[epoch] = test_accuracy
        print(f'Epoch {epoch}, test took {round(((time() - test_start) / 60), 4)} mins')

    # SAVE LOSSES & ACCURACIES FOR EACH OF 20 EPOCHS TO CSV FILES:
    train_losses_np = train_losses.cpu().numpy()
    test_losses_np = test_losses.cpu().numpy()
    test_accs_np = test_accs.cpu().numpy()
    train_accs_np = train_accs.cpu().numpy()

    losses_accs_dirs = f'saved_models/acc_losses/sm_{sampling_method}'
    os.makedirs(losses_accs_dirs, exist_ok=True)
    vit_train_losses_path = os.path.join(losses_accs_dirs, 'train_losses_np.csv')
    vit_test_losses_path = os.path.join(losses_accs_dirs, 'test_losses_np.csv')
    vit_train_accs_path = os.path.join(losses_accs_dirs, 'train_accs_np.csv')
    vit_test_accs_path = os.path.join(losses_accs_dirs, 'test_accs_np.csv')

    np.savetxt(vit_test_losses_path, test_losses_np, delimiter=',')
    np.savetxt(vit_test_accs_path, test_accs_np, delimiter=',')
    np.savetxt(vit_train_losses_path, train_losses_np, delimiter=',')
    np.savetxt(vit_train_accs_path, train_accs_np, delimiter=',')

    # Print the accuracies here (the losses used to be printed under this label).
    print(f'Classification accuracy per epoch for test set= {test_accs}')
    print(f'END - Fine-tuning model for {epochs} epochs took {round(((time() - start) / 60), 4)} mins')

# SAVE THE NEWLY FINE-TUNED WEIGHTS (WARNING: OVER-WRITES PATH-FILENAME).
# Runs whenever the model was fine-tuned in this session, i.e. whenever
# load_finetuned_vit_for_inference_only is False.
# NOTE(review): save_fine_tuned_model is assigned here but never read in this cell,
# so it does not currently control saving - confirm whether it should gate the save.
save_fine_tuned_model = False
if not load_finetuned_vit_for_inference_only:
    tuned_model_dirs = f'saved_models/pretrained_finetuned/sm_{sampling_method}'
    # exist_ok avoids the separate check-then-create step.
    os.makedirs(tuned_model_dirs, exist_ok=True)
    fine_tuned_path = os.path.join(tuned_model_dirs, 'vit_finetuned.pt')
    torch.save(pretrained_vit.state_dict(), fine_tuned_path)
    print('Trained model saved.')

Files already downloaded and verified
load_finetuned_vit_for_inference_only=False
20 epochs of fine-tuning and evalutation on test set per epoch.
Files already downloaded and verified


  0%|          | 0/20 [00:00<?, ?it/s]


Epoch number 0


In [None]:
# I PUT test_inference() function and MixUp class inside this cell for debugging purposes, but you can just fold them up
def test_inference(pretrained_vit, testloader, criterion):
    test_start = time()
    pretrained_vit.eval()
    test_loss_per_epoch, test_accuracy = 0, 0

    with torch.inference_mode():
        for i, data in enumerate(testloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            y_preds = pretrained_vit(inputs)
            # y_preds.shape is .  these are   logits.
            loss = criterion(y_preds, labels)

            test_pred_labels = y_preds.argmax(dim=1)
            test_accuracy += ((test_pred_labels == labels).sum().item()/len(test_pred_labels))
            test_loss_per_epoch += loss.item()
            if i % 2000 == 1999:
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, test_loss_per_epoch / 2000))
                test_loss_per_epoch = 0.0

    test_loss_per_epoch = test_loss_per_epoch / len(testloader)
    test_accuracy = test_accuracy / len(testloader)
    return test_loss_per_epoch, test_accuracy

# Evaluation pipeline: CIFAR-10 test split, batched in a fixed (unshuffled) order.
batch_size = 20
testset = tv_datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=pretrained_transforms,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)
# Module.to() returns the module itself, so creation and placement can be chained.
criterion = torch.nn.CrossEntropyLoss().to(device)

# This cell forces inference-only mode, whatever the flag was set to earlier.
load_finetuned_vit_for_inference_only = True
print(f'load_finetuned_vit_for_inference_only={load_finetuned_vit_for_inference_only}')

if load_finetuned_vit_for_inference_only:
    # INFERENCE (ON TEST-SET) ONLY:
    print('One-off inference only')
    test_loss, test_acc = test_inference(pretrained_vit, testloader, criterion)
    # Show the outcome; it used to be computed and then silently discarded.
    print(f'Test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}')

else:
    # FINE-TUNE AND EVALUATE ON TEST SET FOR 20 EPOCHS:
    print('20 epochs of fine-tuning and evalutation on test set per epoch.')
    trainset = tv_datasets.CIFAR10(root='./data', train=True, download=True, transform=pretrained_transforms)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    # Hand the optimizer only the parameters that can actually receive gradients.
    trainable_params = [p for p in pretrained_vit.parameters() if p.requires_grad]
    opt = torch.optim.Adam(trainable_params, lr=0.003)

    class MixUp(nn.Module):
        """MixUp augmentation: blends each sample (and its label) with another sample of the batch.

        Args:
            seed: seed for the private NumPy generator that draws the mixing
                coefficient. The generator is created once per object, so the
                coefficient varies from call to call while the overall sequence
                stays reproducible.
        """

        def __init__(self, seed=42):
            super().__init__()
            self._rng = np.random.default_rng(seed)

        def augment(self, device, X, y, batch_size, sampling_method, alpha=0.2, num_classes=10):
            """
            Mix a batch with a randomly permuted copy of itself.

            If sampling_method is 2: λ is sampled uniformly from [0, 1).
            For any other value (conventionally 1): λ is sampled from a
            Beta(alpha, alpha) distribution as described in Zhang et al 2018.

            "For mixup, we find that alpha ∈ [0.1, 0.4] leads to improved performance over ERM,
            whereas for large alpha, mixup leads to underfitting." Zhang et al.

            Args:
                device: device on which λ and the permutation index are placed.
                X: batch of inputs; the first dimension is the batch dimension.
                y: integer class labels, one per sample.
                batch_size: number of samples in X / y.
                sampling_method: 2 for uniform λ, anything else for beta λ.
                alpha: concentration parameter of the beta distribution.
                num_classes: width of the one-hot label encoding.

            Returns:
                (new_X, new_y): the blended inputs and the blended (soft) labels,
                new_y having shape (batch_size, num_classes).
            """
            # A fresh λ is drawn on every call. Re-seeding here would make every
            # batch use exactly the same mixing coefficient.
            if sampling_method == 2:
                lambda_ = self._rng.uniform(low=0.0, high=1.0)
            else:
                lambda_ = self._rng.beta(alpha, alpha)

            lam = torch.tensor(float(lambda_), device=device)

            # Partner sample for each position in the batch.
            random_i = torch.randperm(batch_size).to(device)
            X2 = X[random_i]
            y2 = y[random_i]

            # One-hot encode as floats so the labels can be blended like the inputs.
            y = F.one_hot(y, num_classes=num_classes) * 1.0
            y2 = F.one_hot(y2, num_classes=num_classes) * 1.0
            new_X = (lam * X) + ((1. - lam) * X2)
            new_y = (lam * y) + ((1. - lam) * y2)

            return new_X, new_y

    epochs = 20
    mix_up = MixUp()
    # MixUp.augment treats 2 as "uniform λ" and any other value as "beta λ",
    # so 2 selects the uniform distribution. Set once, outside the batch loop.
    # sampling_method = 1  # 1 FOR BETA
    sampling_method = 2  # 2 FOR UNIFORM

    # Per-epoch metric buffers. Allocated once, before the epoch loop, so that the
    # values of every epoch (not just the last one) reach the CSV files below.
    train_losses = torch.zeros(epochs)
    test_losses = torch.zeros(epochs)
    train_accs = torch.zeros(epochs)
    test_accs = torch.zeros(epochs)

    start = time()
    for epoch in tqdm(range(epochs)):
        print(f'\nEpoch number {epoch}')
        epoch_start = time()
        running_loss, n_correct, n_seen = 0.0, 0, 0

        # 1. FINE-TUNE FOR ONE EPOCH:
        pretrained_vit.train()

        for i, data in enumerate(trainloader):
            inputs, labels = data[0].to(device), data[1].to(device)
            X, y_soft = mix_up.augment(device=device, X=inputs, y=labels, sampling_method=sampling_method, batch_size=inputs.shape[0])
            opt.zero_grad()
            y_preds = pretrained_vit(X)
            # Train against the blended (soft) labels: collapsing them with argmax
            # first would throw away the label mixing that MixUp is meant to provide.
            loss = criterion(y_preds, y_soft)
            loss.backward()
            opt.step()

            # Accuracy is measured against the dominant label of each mixed sample.
            y_preds_class = y_preds.argmax(dim=1)
            y_dominant = y_soft.argmax(dim=1)
            n_correct += (y_preds_class == y_dominant).sum().item()
            n_seen += len(y_preds)

            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                # Running mean over all batches seen so far in this epoch.
                print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / (i + 1)))

        train_loss_per_epoch = running_loss / len(trainloader)
        train_accuracy = n_correct / n_seen
        print(f'train_accuracy {train_accuracy}')

        train_losses[epoch] = train_loss_per_epoch
        train_accs[epoch] = train_accuracy
        # Timed from the start of this epoch, not from the start of the whole run.
        print(f'Epoch {epoch}, training took {round(((time() - epoch_start) / 60), 4)} mins')

        # 2. EVALUATE ON TEST-SET AFTER EACH EPOCH OF FINE-TUNING:
        test_start = time()
        test_loss_per_epoch, test_accuracy = test_inference(pretrained_vit, testloader, criterion)

        test_losses[epoch] = test_loss_per_epoch
        test_accs[epoch] = test_accuracy
        print(f'Epoch {epoch}, test took {round(((time() - test_start) / 60), 4)} mins')

    # SAVE LOSSES & ACCURACIES FOR EACH OF 20 EPOCHS TO CSV FILES:
    train_losses_np = train_losses.cpu().numpy()
    test_losses_np = test_losses.cpu().numpy()
    test_accs_np = test_accs.cpu().numpy()
    train_accs_np = train_accs.cpu().numpy()

    losses_accs_dirs = f'saved_models/acc_losses/sm_{sampling_method}'
    os.makedirs(losses_accs_dirs, exist_ok=True)
    vit_train_losses_path = os.path.join(losses_accs_dirs, 'train_losses_np.csv')
    vit_test_losses_path = os.path.join(losses_accs_dirs, 'test_losses_np.csv')
    vit_train_accs_path = os.path.join(losses_accs_dirs, 'train_accs_np.csv')
    vit_test_accs_path = os.path.join(losses_accs_dirs, 'test_accs_np.csv')

    np.savetxt(vit_test_losses_path, test_losses_np, delimiter=',')
    np.savetxt(vit_test_accs_path, test_accs_np, delimiter=',')
    np.savetxt(vit_train_losses_path, train_losses_np, delimiter=',')
    np.savetxt(vit_train_accs_path, train_accs_np, delimiter=',')

    # Print the accuracies here (the losses used to be printed under this label).
    print(f'Classification accuracy per epoch for test set= {test_accs}')
    print(f'END - Fine-tuning model for {epochs} epochs took {round(((time() - start) / 60), 4)} mins')

# SAVE THE NEWLY FINE-TUNED WEIGHTS (WARNING: OVER-WRITES PATH-FILENAME).
# Runs whenever the model was fine-tuned in this session, i.e. whenever
# load_finetuned_vit_for_inference_only is False.
# NOTE(review): save_fine_tuned_model is assigned here but never read in this cell,
# so it does not currently control saving - confirm whether it should gate the save.
save_fine_tuned_model = False
if not load_finetuned_vit_for_inference_only:
    tuned_model_dirs = f'saved_models/pretrained_finetuned/sm_{sampling_method}'
    # exist_ok avoids the separate check-then-create step.
    os.makedirs(tuned_model_dirs, exist_ok=True)
    fine_tuned_path = os.path.join(tuned_model_dirs, 'vit_finetuned.pt')
    torch.save(pretrained_vit.state_dict(), fine_tuned_path)
    print('Trained model saved.')