In [1]:
import os
import sys

# Save current working directory
cwd = os.getcwd()

# Change to parent directory
parent_dir = os.path.abspath(os.path.join(cwd, '..'))
os.chdir(parent_dir)

# Temporarily add parent directory to sys.path
sys.path.insert(0, parent_dir)
import framework
sys.path.pop(0)

# Return to original directory
os.chdir(cwd)

In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.models import vit_b_32, ViT_B_32_Weights
from torch import nn, optim
from tqdm import tqdm


weights = ViT_B_32_Weights.DEFAULT
# Transform matching ImageNet-trained model
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    weights.transforms()
])


# CIFAR-10 test data
dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=False)


Files already downloaded and verified


In [3]:
from torch import nn

# Load pretrained model
model = vit_b_32(weights=weights)
model.eval()

model.heads = nn.Linear(in_features = 768, out_features = 100, bias = True)

device = "cuda"

In [None]:
# Datasets & loaders
train_data = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.to(device)


# Fine-tuning loop
for epoch in range(1):  # adjust epochs as needed
    model.train()
    running_loss = 0.0
    for batch_idx, (inputs, labels) in enumerate(train_loader, desc=f"Epoch {epoch+1}"):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        print(f"Epoch {epoch+1}, Batch {batch_idx+1}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
    print(f"Epoch {epoch+1}, Average Loss: {running_loss/len(train_loader):.4f}")



Files already downloaded and verified
Files already downloaded and verified


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1:   0%|          | 1/782 [00:00<08:07,  1.60it/s]

Epoch 1, Batch 1/782, Loss: 4.8034


Epoch 1:   0%|          | 2/782 [00:01<06:49,  1.91it/s]

Epoch 1, Batch 2/782, Loss: 4.8527


Epoch 1:   0%|          | 3/782 [00:01<06:18,  2.06it/s]

Epoch 1, Batch 3/782, Loss: 4.7932


Epoch 1:   1%|          | 4/782 [00:01<05:59,  2.16it/s]

Epoch 1, Batch 4/782, Loss: 4.8101


Epoch 1:   1%|          | 5/782 [00:02<05:49,  2.22it/s]

Epoch 1, Batch 5/782, Loss: 4.7321


Epoch 1:   1%|          | 6/782 [00:02<05:52,  2.20it/s]

Epoch 1, Batch 6/782, Loss: 4.8611


Epoch 1:   1%|          | 7/782 [00:03<05:50,  2.21it/s]

Epoch 1, Batch 7/782, Loss: 4.8361


Epoch 1:   1%|          | 8/782 [00:03<05:48,  2.22it/s]

Epoch 1, Batch 8/782, Loss: 5.0251


Epoch 1:   1%|          | 9/782 [00:04<05:46,  2.23it/s]

Epoch 1, Batch 9/782, Loss: 5.0129


Epoch 1:   1%|▏         | 10/782 [00:04<05:44,  2.24it/s]

Epoch 1, Batch 10/782, Loss: 4.8517


Epoch 1:   1%|▏         | 11/782 [00:05<05:41,  2.25it/s]

Epoch 1, Batch 11/782, Loss: 4.7415


Epoch 1:   2%|▏         | 12/782 [00:05<05:43,  2.24it/s]

Epoch 1, Batch 12/782, Loss: 4.8010


Epoch 1:   2%|▏         | 13/782 [00:05<05:42,  2.25it/s]

Epoch 1, Batch 13/782, Loss: 4.8131


Epoch 1:   2%|▏         | 14/782 [00:06<05:40,  2.26it/s]

Epoch 1, Batch 14/782, Loss: 4.7482


Epoch 1:   2%|▏         | 15/782 [00:06<05:40,  2.25it/s]

Epoch 1, Batch 15/782, Loss: 4.6909


Epoch 1:   2%|▏         | 16/782 [00:07<05:41,  2.24it/s]

Epoch 1, Batch 16/782, Loss: 4.7511


Epoch 1:   2%|▏         | 17/782 [00:07<05:42,  2.24it/s]

Epoch 1, Batch 17/782, Loss: 4.7847


Epoch 1:   2%|▏         | 18/782 [00:08<05:40,  2.24it/s]

Epoch 1, Batch 18/782, Loss: 4.8090


Epoch 1:   2%|▏         | 19/782 [00:08<05:37,  2.26it/s]

Epoch 1, Batch 19/782, Loss: 4.7668


Epoch 1:   3%|▎         | 20/782 [00:09<05:36,  2.27it/s]

Epoch 1, Batch 20/782, Loss: 4.8267


Epoch 1:   3%|▎         | 21/782 [00:09<05:39,  2.24it/s]

Epoch 1, Batch 21/782, Loss: 4.7734


Epoch 1:   3%|▎         | 22/782 [00:09<05:39,  2.24it/s]

Epoch 1, Batch 22/782, Loss: 4.7607


Epoch 1:   3%|▎         | 23/782 [00:10<05:37,  2.25it/s]

Epoch 1, Batch 23/782, Loss: 4.7100


Epoch 1:   3%|▎         | 24/782 [00:10<05:45,  2.19it/s]

Epoch 1, Batch 24/782, Loss: 4.6413


Epoch 1:   3%|▎         | 25/782 [00:11<05:42,  2.21it/s]

Epoch 1, Batch 25/782, Loss: 4.7368


Epoch 1:   3%|▎         | 26/782 [00:11<05:43,  2.20it/s]

Epoch 1, Batch 26/782, Loss: 4.7372


Epoch 1:   3%|▎         | 27/782 [00:12<05:41,  2.21it/s]

Epoch 1, Batch 27/782, Loss: 4.6388


Epoch 1:   4%|▎         | 28/782 [00:12<05:41,  2.21it/s]

Epoch 1, Batch 28/782, Loss: 4.6863


Epoch 1:   4%|▎         | 29/782 [00:13<05:41,  2.20it/s]

Epoch 1, Batch 29/782, Loss: 4.8721


Epoch 1:   4%|▍         | 30/782 [00:13<05:41,  2.20it/s]

Epoch 1, Batch 30/782, Loss: 4.6796


Epoch 1:   4%|▍         | 31/782 [00:14<05:40,  2.20it/s]

Epoch 1, Batch 31/782, Loss: 4.7363


Epoch 1:   4%|▍         | 32/782 [00:14<05:40,  2.21it/s]

Epoch 1, Batch 32/782, Loss: 4.7278


Epoch 1:   4%|▍         | 33/782 [00:14<05:39,  2.21it/s]

Epoch 1, Batch 33/782, Loss: 4.5616


Epoch 1:   4%|▍         | 34/782 [00:15<05:38,  2.21it/s]

Epoch 1, Batch 34/782, Loss: 4.5765


Epoch 1:   4%|▍         | 35/782 [00:15<05:38,  2.21it/s]

Epoch 1, Batch 35/782, Loss: 4.6049


Epoch 1:   5%|▍         | 36/782 [00:16<05:36,  2.22it/s]

Epoch 1, Batch 36/782, Loss: 4.6363


Epoch 1:   5%|▍         | 37/782 [00:16<05:35,  2.22it/s]

Epoch 1, Batch 37/782, Loss: 4.5705


Epoch 1:   5%|▍         | 38/782 [00:17<05:38,  2.20it/s]

Epoch 1, Batch 38/782, Loss: 4.6317


Epoch 1:   5%|▍         | 39/782 [00:17<05:36,  2.21it/s]

Epoch 1, Batch 39/782, Loss: 4.6763


Epoch 1:   5%|▌         | 40/782 [00:18<05:34,  2.22it/s]

Epoch 1, Batch 40/782, Loss: 4.5799


Epoch 1:   5%|▌         | 41/782 [00:18<05:40,  2.18it/s]

Epoch 1, Batch 41/782, Loss: 4.5922


Epoch 1:   5%|▌         | 42/782 [00:19<05:37,  2.19it/s]

Epoch 1, Batch 42/782, Loss: 4.5683


Epoch 1:   5%|▌         | 43/782 [00:19<05:35,  2.20it/s]

Epoch 1, Batch 43/782, Loss: 4.4529


Epoch 1:   6%|▌         | 44/782 [00:19<05:33,  2.21it/s]

Epoch 1, Batch 44/782, Loss: 4.6087


Epoch 1:   6%|▌         | 45/782 [00:20<05:31,  2.22it/s]

Epoch 1, Batch 45/782, Loss: 4.5418


Epoch 1:   6%|▌         | 46/782 [00:20<05:32,  2.22it/s]

Epoch 1, Batch 46/782, Loss: 4.6073


Epoch 1:   6%|▌         | 47/782 [00:21<05:31,  2.22it/s]

Epoch 1, Batch 47/782, Loss: 4.5309


Epoch 1:   6%|▌         | 48/782 [00:21<05:31,  2.21it/s]

Epoch 1, Batch 48/782, Loss: 4.7009


Epoch 1:   6%|▋         | 49/782 [00:22<05:29,  2.22it/s]

Epoch 1, Batch 49/782, Loss: 4.5617


Epoch 1:   6%|▋         | 50/782 [00:22<05:34,  2.19it/s]

Epoch 1, Batch 50/782, Loss: 4.5815


Epoch 1:   7%|▋         | 51/782 [00:23<05:29,  2.22it/s]

Epoch 1, Batch 51/782, Loss: 4.6556


Epoch 1:   7%|▋         | 52/782 [00:23<05:27,  2.23it/s]

Epoch 1, Batch 52/782, Loss: 4.6917


Epoch 1:   7%|▋         | 53/782 [00:23<05:26,  2.24it/s]

Epoch 1, Batch 53/782, Loss: 4.4310


Epoch 1:   7%|▋         | 54/782 [00:24<05:26,  2.23it/s]

Epoch 1, Batch 54/782, Loss: 4.5651


Epoch 1:   7%|▋         | 55/782 [00:24<05:28,  2.21it/s]

Epoch 1, Batch 55/782, Loss: 4.5139


Epoch 1:   7%|▋         | 56/782 [00:25<05:29,  2.20it/s]

Epoch 1, Batch 56/782, Loss: 4.6098


Epoch 1:   7%|▋         | 57/782 [00:25<05:30,  2.19it/s]

Epoch 1, Batch 57/782, Loss: 4.3281


Epoch 1:   7%|▋         | 58/782 [00:26<05:31,  2.18it/s]

Epoch 1, Batch 58/782, Loss: 4.4920


Epoch 1:   8%|▊         | 59/782 [00:26<05:31,  2.18it/s]

Epoch 1, Batch 59/782, Loss: 4.2986


Epoch 1:   8%|▊         | 60/782 [00:27<05:32,  2.17it/s]

Epoch 1, Batch 60/782, Loss: 4.4280


Epoch 1:   8%|▊         | 61/782 [00:27<05:30,  2.18it/s]

Epoch 1, Batch 61/782, Loss: 4.4074


Epoch 1:   8%|▊         | 62/782 [00:28<05:27,  2.20it/s]

Epoch 1, Batch 62/782, Loss: 4.4388


Epoch 1:   8%|▊         | 63/782 [00:28<05:25,  2.21it/s]

Epoch 1, Batch 63/782, Loss: 4.1828


Epoch 1:   8%|▊         | 64/782 [00:28<05:24,  2.22it/s]

Epoch 1, Batch 64/782, Loss: 4.6191


Epoch 1:   8%|▊         | 65/782 [00:29<05:24,  2.21it/s]

Epoch 1, Batch 65/782, Loss: 4.5762


Epoch 1:   8%|▊         | 66/782 [00:29<05:24,  2.21it/s]

Epoch 1, Batch 66/782, Loss: 4.5005


Epoch 1:   9%|▊         | 67/782 [00:30<05:30,  2.16it/s]

Epoch 1, Batch 67/782, Loss: 4.3518


Epoch 1:   9%|▊         | 68/782 [00:30<05:25,  2.19it/s]

Epoch 1, Batch 68/782, Loss: 4.3356


Epoch 1:   9%|▉         | 69/782 [00:31<05:24,  2.20it/s]

Epoch 1, Batch 69/782, Loss: 4.4043


Epoch 1:   9%|▉         | 70/782 [00:31<05:23,  2.20it/s]

Epoch 1, Batch 70/782, Loss: 4.5733


Epoch 1:   9%|▉         | 71/782 [00:32<05:22,  2.21it/s]

Epoch 1, Batch 71/782, Loss: 4.3691


Epoch 1:   9%|▉         | 72/782 [00:32<05:18,  2.23it/s]

Epoch 1, Batch 72/782, Loss: 4.2784


Epoch 1:   9%|▉         | 73/782 [00:33<05:15,  2.25it/s]

Epoch 1, Batch 73/782, Loss: 4.4908


Epoch 1:   9%|▉         | 74/782 [00:33<05:16,  2.24it/s]

Epoch 1, Batch 74/782, Loss: 4.4911


Epoch 1:  10%|▉         | 75/782 [00:33<05:15,  2.24it/s]

Epoch 1, Batch 75/782, Loss: 4.5226


Epoch 1:  10%|▉         | 76/782 [00:34<05:19,  2.21it/s]

Epoch 1, Batch 76/782, Loss: 4.4302


Epoch 1:  10%|▉         | 77/782 [00:34<05:20,  2.20it/s]

Epoch 1, Batch 77/782, Loss: 4.2353


Epoch 1:  10%|▉         | 78/782 [00:35<05:20,  2.20it/s]

Epoch 1, Batch 78/782, Loss: 4.4867


Epoch 1:  10%|█         | 79/782 [00:35<05:16,  2.22it/s]

Epoch 1, Batch 79/782, Loss: 4.4924


Epoch 1:  10%|█         | 80/782 [00:36<05:11,  2.25it/s]

Epoch 1, Batch 80/782, Loss: 4.4139


Epoch 1:  10%|█         | 81/782 [00:36<05:11,  2.25it/s]

Epoch 1, Batch 81/782, Loss: 4.4290


Epoch 1:  10%|█         | 82/782 [00:37<05:09,  2.26it/s]

Epoch 1, Batch 82/782, Loss: 4.4299


Epoch 1:  11%|█         | 83/782 [00:37<05:07,  2.27it/s]

Epoch 1, Batch 83/782, Loss: 4.3477


Epoch 1:  11%|█         | 84/782 [00:37<05:09,  2.25it/s]

Epoch 1, Batch 84/782, Loss: 4.5567


Epoch 1:  11%|█         | 85/782 [00:38<05:07,  2.27it/s]

Epoch 1, Batch 85/782, Loss: 4.3080


Epoch 1:  11%|█         | 86/782 [00:38<05:12,  2.23it/s]

Epoch 1, Batch 86/782, Loss: 4.3376


Epoch 1:  11%|█         | 87/782 [00:39<05:12,  2.22it/s]

Epoch 1, Batch 87/782, Loss: 4.4719


Epoch 1:  11%|█▏        | 88/782 [00:39<05:12,  2.22it/s]

Epoch 1, Batch 88/782, Loss: 4.4734


Epoch 1:  11%|█▏        | 89/782 [00:40<05:10,  2.23it/s]

Epoch 1, Batch 89/782, Loss: 4.3036


Epoch 1:  11%|█▏        | 89/782 [00:40<05:16,  2.19it/s]


KeyboardInterrupt: 

In [5]:
# Evaluation
model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

print(f"Top-1 Accuracy on CIFAR-100 (fine-tuned): {correct / total:.4f}")

100%|██████████| 157/157 [00:30<00:00,  5.07it/s]

Top-1 Accuracy on CIFAR-100 (fine-tuned): 0.0348





In [6]:
# Save only the model weights (recommended)
torch.save(model.state_dict(), 'model_weights.pth')