In [1]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

# Data loading and splitting
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for CIFAR-100
])

full_dataset = CIFAR100(root="./data", train=True, download=True, transform=transform)

# Split the dataset into training, validation, and testing sets
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

batch_size = 256

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:02<00:00, 65040491.21it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


In [2]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 32
# max_len = 100 # All sequences must be less than 1000 including class token
# channels = 3

num_epochs = 30

# Model, loss, and optimizer
weights = ResNet50_Weights.DEFAULT
model = resnet50(weights=weights).to(device)

# Freeze all weights except the last layer
for name, param in model.named_parameters():
    if not name.startswith('fc'):  # assuming the last layer is named 'fc'
        param.requires_grad = False

# Model, loss, and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Optional: learning rate scheduler


# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optional: Adjust learning rate
    scheduler.step()

    # Print average loss per epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, , Validation Loss: {average_val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "resnet_model.pth")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 85.1MB/s]
Epoch 1/30: 100%|██████████| 137/137 [00:17<00:00,  7.93it/s]


Epoch 1/30, Loss: 6.2877


Validation: 100%|██████████| 40/40 [00:02<00:00, 14.74it/s]


Validation Accuracy: 20.70%, , Validation Loss: 4.2344


Epoch 2/30: 100%|██████████| 137/137 [00:10<00:00, 13.02it/s]


Epoch 2/30, Loss: 3.3751


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.29it/s]


Validation Accuracy: 24.85%, , Validation Loss: 3.7697


Epoch 3/30: 100%|██████████| 137/137 [00:10<00:00, 13.08it/s]


Epoch 3/30, Loss: 2.7183


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.30it/s]


Validation Accuracy: 26.03%, , Validation Loss: 3.6986


Epoch 4/30: 100%|██████████| 137/137 [00:10<00:00, 13.21it/s]


Epoch 4/30, Loss: 2.4009


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.15it/s]


Validation Accuracy: 26.62%, , Validation Loss: 3.7407


Epoch 5/30: 100%|██████████| 137/137 [00:09<00:00, 13.74it/s]


Epoch 5/30, Loss: 2.2091


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.87it/s]


Validation Accuracy: 26.15%, , Validation Loss: 3.8669


Epoch 6/30: 100%|██████████| 137/137 [00:09<00:00, 15.05it/s]


Epoch 6/30, Loss: 2.0047


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.40it/s]


Validation Accuracy: 26.01%, , Validation Loss: 3.8555


Epoch 7/30: 100%|██████████| 137/137 [00:08<00:00, 15.45it/s]


Epoch 7/30, Loss: 1.9364


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.44it/s]


Validation Accuracy: 26.70%, , Validation Loss: 3.7908


Epoch 8/30: 100%|██████████| 137/137 [00:09<00:00, 14.27it/s]


Epoch 8/30, Loss: 1.8766


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.35it/s]


Validation Accuracy: 26.25%, , Validation Loss: 3.8613


Epoch 9/30: 100%|██████████| 137/137 [00:10<00:00, 13.53it/s]


Epoch 9/30, Loss: 1.8429


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.12it/s]


Validation Accuracy: 26.64%, , Validation Loss: 3.8506


Epoch 10/30: 100%|██████████| 137/137 [00:10<00:00, 12.98it/s]


Epoch 10/30, Loss: 1.7858


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.13it/s]


Validation Accuracy: 26.45%, , Validation Loss: 3.8566


Epoch 11/30: 100%|██████████| 137/137 [00:10<00:00, 12.86it/s]


Epoch 11/30, Loss: 1.7115


Validation: 100%|██████████| 40/40 [00:02<00:00, 14.92it/s]


Validation Accuracy: 26.85%, , Validation Loss: 3.8469


Epoch 12/30: 100%|██████████| 137/137 [00:10<00:00, 13.03it/s]


Epoch 12/30, Loss: 1.7242


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.13it/s]


Validation Accuracy: 26.53%, , Validation Loss: 3.8534


Epoch 13/30: 100%|██████████| 137/137 [00:10<00:00, 13.16it/s]


Epoch 13/30, Loss: 1.6982


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.46it/s]


Validation Accuracy: 27.03%, , Validation Loss: 3.9138


Epoch 14/30: 100%|██████████| 137/137 [00:10<00:00, 13.52it/s]


Epoch 14/30, Loss: 1.6779


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.33it/s]


Validation Accuracy: 27.07%, , Validation Loss: 3.8509


Epoch 15/30: 100%|██████████| 137/137 [00:09<00:00, 14.62it/s]


Epoch 15/30, Loss: 1.6758


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.43it/s]


Validation Accuracy: 26.82%, , Validation Loss: 3.8914


Epoch 16/30: 100%|██████████| 137/137 [00:08<00:00, 15.44it/s]


Epoch 16/30, Loss: 1.6189


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.43it/s]


Validation Accuracy: 26.92%, , Validation Loss: 3.8644


Epoch 17/30: 100%|██████████| 137/137 [00:09<00:00, 14.46it/s]


Epoch 17/30, Loss: 1.6237


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.03it/s]


Validation Accuracy: 26.69%, , Validation Loss: 3.9892


Epoch 18/30: 100%|██████████| 137/137 [00:10<00:00, 13.68it/s]


Epoch 18/30, Loss: 1.6124


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.07it/s]


Validation Accuracy: 26.87%, , Validation Loss: 3.9522


Epoch 19/30: 100%|██████████| 137/137 [00:10<00:00, 13.09it/s]


Epoch 19/30, Loss: 1.6062


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.48it/s]


Validation Accuracy: 26.99%, , Validation Loss: 3.8916


Epoch 20/30: 100%|██████████| 137/137 [00:10<00:00, 13.14it/s]


Epoch 20/30, Loss: 1.6018


Validation: 100%|██████████| 40/40 [00:02<00:00, 14.82it/s]


Validation Accuracy: 26.73%, , Validation Loss: 3.8872


Epoch 21/30: 100%|██████████| 137/137 [00:10<00:00, 13.23it/s]


Epoch 21/30, Loss: 1.5737


Validation: 100%|██████████| 40/40 [00:02<00:00, 14.65it/s]


Validation Accuracy: 27.05%, , Validation Loss: 3.9885


Epoch 22/30: 100%|██████████| 137/137 [00:10<00:00, 12.96it/s]


Epoch 22/30, Loss: 1.5704


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.21it/s]


Validation Accuracy: 26.81%, , Validation Loss: 4.0025


Epoch 23/30: 100%|██████████| 137/137 [00:10<00:00, 13.54it/s]


Epoch 23/30, Loss: 1.5829


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.06it/s]


Validation Accuracy: 27.01%, , Validation Loss: 3.8921


Epoch 24/30: 100%|██████████| 137/137 [00:09<00:00, 15.08it/s]


Epoch 24/30, Loss: 1.5755


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.32it/s]


Validation Accuracy: 26.98%, , Validation Loss: 3.9180


Epoch 25/30: 100%|██████████| 137/137 [00:08<00:00, 15.27it/s]


Epoch 25/30, Loss: 1.5801


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.01it/s]


Validation Accuracy: 26.89%, , Validation Loss: 3.9798


Epoch 26/30: 100%|██████████| 137/137 [00:09<00:00, 14.50it/s]


Epoch 26/30, Loss: 1.5543


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.35it/s]


Validation Accuracy: 26.42%, , Validation Loss: 3.9347


Epoch 27/30: 100%|██████████| 137/137 [00:09<00:00, 13.82it/s]


Epoch 27/30, Loss: 1.5788


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.19it/s]


Validation Accuracy: 26.92%, , Validation Loss: 4.0882


Epoch 28/30: 100%|██████████| 137/137 [00:10<00:00, 13.26it/s]


Epoch 28/30, Loss: 1.5519


Validation: 100%|██████████| 40/40 [00:02<00:00, 14.93it/s]


Validation Accuracy: 26.98%, , Validation Loss: 3.9767


Epoch 29/30: 100%|██████████| 137/137 [00:10<00:00, 13.20it/s]


Epoch 29/30, Loss: 1.5426


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.68it/s]


Validation Accuracy: 27.04%, , Validation Loss: 4.0092


Epoch 30/30: 100%|██████████| 137/137 [00:10<00:00, 13.16it/s]


Epoch 30/30, Loss: 1.5364


Validation: 100%|██████████| 40/40 [00:02<00:00, 15.32it/s]


Validation Accuracy: 26.62%, , Validation Loss: 3.9392


In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Evaluating: 100%|██████████| 20/20 [00:02<00:00,  8.20it/s]

Test Accuracy: 58.76%



