In [None]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

# Data loading and splitting
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for CIFAR-100
])

full_dataset = CIFAR100(root="./data", train=True, download=True, transform=transform)

# Split the dataset into training, validation, and testing sets
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

batch_size = 256

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:02<00:00, 77143703.26it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


# First Run
Note: Forgot to save trained model as "mobilenet_v3_large.pth" instead of "resnet_model.pth"

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
# from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 32
# max_len = 100 # All sequences must be less than 1000 including class token
# channels = 3

num_epochs = 30

# Model, loss, and optimizer
# weights = ResNet50_Weights.DEFAULT
weights = MobileNet_V3_Large_Weights.DEFAULT

model = mobilenet_v3_large(weights=weights).to(device)
# model = resnet50(weights=weights).to(device)
#model = resnet50(num_classes=num_classes, image_size=image_size, patch_size=patch_size, num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Optional: learning rate scheduler

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optional: Adjust learning rate
    scheduler.step()

    # Print average loss per epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, , Validation Loss: {average_val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "resnet_model.pth")

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 102MB/s]
Epoch 1/30: 100%|██████████| 137/137 [00:14<00:00,  9.36it/s]


Epoch 1/30, Loss: 3.7098


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.65it/s]


Validation Accuracy: 21.77%, , Validation Loss: 3.3256


Epoch 2/30: 100%|██████████| 137/137 [00:14<00:00,  9.73it/s]


Epoch 2/30, Loss: 2.3914


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.62it/s]


Validation Accuracy: 37.11%, , Validation Loss: 2.4927


Epoch 3/30: 100%|██████████| 137/137 [00:13<00:00,  9.96it/s]


Epoch 3/30, Loss: 1.9046


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.63it/s]


Validation Accuracy: 42.19%, , Validation Loss: 2.3009


Epoch 4/30: 100%|██████████| 137/137 [00:13<00:00, 10.04it/s]


Epoch 4/30, Loss: 1.6317


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.65it/s]


Validation Accuracy: 41.78%, , Validation Loss: 2.4138


Epoch 5/30: 100%|██████████| 137/137 [00:13<00:00,  9.89it/s]


Epoch 5/30, Loss: 1.4099


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.39it/s]


Validation Accuracy: 41.95%, , Validation Loss: 2.4647


Epoch 6/30: 100%|██████████| 137/137 [00:13<00:00, 10.29it/s]


Epoch 6/30, Loss: 0.8691


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.63it/s]


Validation Accuracy: 50.16%, , Validation Loss: 2.0983


Epoch 7/30: 100%|██████████| 137/137 [00:13<00:00, 10.01it/s]


Epoch 7/30, Loss: 0.5841


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.83it/s]


Validation Accuracy: 48.47%, , Validation Loss: 2.4198


Epoch 8/30: 100%|██████████| 137/137 [00:14<00:00,  9.69it/s]


Epoch 8/30, Loss: 0.4701


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.51it/s]


Validation Accuracy: 48.77%, , Validation Loss: 2.5707


Epoch 9/30: 100%|██████████| 137/137 [00:13<00:00,  9.91it/s]


Epoch 9/30, Loss: 0.4282


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.40it/s]


Validation Accuracy: 47.17%, , Validation Loss: 2.8017


Epoch 10/30: 100%|██████████| 137/137 [00:13<00:00,  9.97it/s]


Epoch 10/30, Loss: 0.4147


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.85it/s]


Validation Accuracy: 47.06%, , Validation Loss: 2.9722


Epoch 11/30: 100%|██████████| 137/137 [00:13<00:00,  9.96it/s]


Epoch 11/30, Loss: 0.2048


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.07it/s]


Validation Accuracy: 51.18%, , Validation Loss: 2.7362


Epoch 12/30: 100%|██████████| 137/137 [00:13<00:00, 10.25it/s]


Epoch 12/30, Loss: 0.0946


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.47it/s]


Validation Accuracy: 51.13%, , Validation Loss: 2.8607


Epoch 13/30: 100%|██████████| 137/137 [00:13<00:00, 10.08it/s]


Epoch 13/30, Loss: 0.0990


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.66it/s]


Validation Accuracy: 47.74%, , Validation Loss: 3.2626


Epoch 14/30: 100%|██████████| 137/137 [00:13<00:00,  9.83it/s]


Epoch 14/30, Loss: 0.1125


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.57it/s]


Validation Accuracy: 50.62%, , Validation Loss: 3.1077


Epoch 15/30: 100%|██████████| 137/137 [00:13<00:00,  9.92it/s]


Epoch 15/30, Loss: 0.0929


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.43it/s]


Validation Accuracy: 51.00%, , Validation Loss: 3.2069


Epoch 16/30: 100%|██████████| 137/137 [00:13<00:00,  9.88it/s]


Epoch 16/30, Loss: 0.0559


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.46it/s]


Validation Accuracy: 52.21%, , Validation Loss: 3.1382


Epoch 17/30: 100%|██████████| 137/137 [00:13<00:00,  9.85it/s]


Epoch 17/30, Loss: 0.0407


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.45it/s]


Validation Accuracy: 52.29%, , Validation Loss: 3.1908


Epoch 18/30: 100%|██████████| 137/137 [00:13<00:00,  9.85it/s]


Epoch 18/30, Loss: 0.0352


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.23it/s]


Validation Accuracy: 51.87%, , Validation Loss: 3.2387


Epoch 19/30: 100%|██████████| 137/137 [00:13<00:00, 10.01it/s]


Epoch 19/30, Loss: 0.0356


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.32it/s]


Validation Accuracy: 52.28%, , Validation Loss: 3.3054


Epoch 20/30: 100%|██████████| 137/137 [00:13<00:00, 10.37it/s]


Epoch 20/30, Loss: 0.0382


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.69it/s]


Validation Accuracy: 51.82%, , Validation Loss: 3.3784


Epoch 21/30: 100%|██████████| 137/137 [00:13<00:00,  9.94it/s]


Epoch 21/30, Loss: 0.0332


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.51it/s]


Validation Accuracy: 52.13%, , Validation Loss: 3.3650


Epoch 22/30: 100%|██████████| 137/137 [00:13<00:00, 10.04it/s]


Epoch 22/30, Loss: 0.0294


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.63it/s]


Validation Accuracy: 52.61%, , Validation Loss: 3.4048


Epoch 23/30: 100%|██████████| 137/137 [00:13<00:00,  9.83it/s]


Epoch 23/30, Loss: 0.0295


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.39it/s]


Validation Accuracy: 52.55%, , Validation Loss: 3.4077


Epoch 24/30: 100%|██████████| 137/137 [00:13<00:00,  9.81it/s]


Epoch 24/30, Loss: 0.0292


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.53it/s]


Validation Accuracy: 52.39%, , Validation Loss: 3.4464


Epoch 25/30: 100%|██████████| 137/137 [00:13<00:00, 10.06it/s]


Epoch 25/30, Loss: 0.0296


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.44it/s]


Validation Accuracy: 51.94%, , Validation Loss: 3.4650


Epoch 26/30: 100%|██████████| 137/137 [00:13<00:00, 10.25it/s]


Epoch 26/30, Loss: 0.0277


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.56it/s]


Validation Accuracy: 52.31%, , Validation Loss: 3.4642


Epoch 27/30: 100%|██████████| 137/137 [00:13<00:00, 10.53it/s]


Epoch 27/30, Loss: 0.0250


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.13it/s]


Validation Accuracy: 52.61%, , Validation Loss: 3.4753


Epoch 28/30: 100%|██████████| 137/137 [00:13<00:00,  9.98it/s]


Epoch 28/30, Loss: 0.0235


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.26it/s]


Validation Accuracy: 52.49%, , Validation Loss: 3.4818


Epoch 29/30: 100%|██████████| 137/137 [00:13<00:00,  9.94it/s]


Epoch 29/30, Loss: 0.0246


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.68it/s]


Validation Accuracy: 52.47%, , Validation Loss: 3.4960


Epoch 30/30: 100%|██████████| 137/137 [00:14<00:00,  9.78it/s]


Epoch 30/30, Loss: 0.0233


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.65it/s]


Validation Accuracy: 52.60%, , Validation Loss: 3.5170


In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Evaluating: 100%|██████████| 20/20 [00:01<00:00, 12.36it/s]

Test Accuracy: 53.60%





# Second Run
Notes:
1.   Forgot to save trained model as "mobilenet_v3_large.pth" instead of "resnet_model.pth"
2.   Changed image size from 32 to 224

In [None]:
# new run
import torch
import torch.nn as nn
from torchvision import models, transforms
# from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 224
# max_len = 100 # All sequences must be less than 1000 including class token
# channels = 3

num_epochs = 30

# Model, loss, and optimizer
# weights = ResNet50_Weights.DEFAULT
weights = MobileNet_V3_Large_Weights.DEFAULT

model = mobilenet_v3_large(weights=weights).to(device)
# model = resnet50(weights=weights).to(device)
#model = resnet50(num_classes=num_classes, image_size=image_size, patch_size=patch_size, num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Optional: learning rate scheduler

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optional: Adjust learning rate
    scheduler.step()

    # Print average loss per epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, , Validation Loss: {average_val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "resnet_model.pth")

Epoch 1/30: 100%|██████████| 137/137 [00:13<00:00, 10.02it/s]


Epoch 1/30, Loss: 3.5275


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.47it/s]


Validation Accuracy: 21.97%, , Validation Loss: 3.6388


Epoch 2/30: 100%|██████████| 137/137 [00:13<00:00, 10.10it/s]


Epoch 2/30, Loss: 2.2330


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.15it/s]


Validation Accuracy: 39.51%, , Validation Loss: 2.4388


Epoch 3/30: 100%|██████████| 137/137 [00:12<00:00, 10.60it/s]


Epoch 3/30, Loss: 1.8019


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.48it/s]


Validation Accuracy: 40.42%, , Validation Loss: 2.4487


Epoch 4/30: 100%|██████████| 137/137 [00:14<00:00,  9.59it/s]


Epoch 4/30, Loss: 1.6084


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.82it/s]


Validation Accuracy: 46.05%, , Validation Loss: 2.1662


Epoch 5/30: 100%|██████████| 137/137 [00:15<00:00,  8.75it/s]


Epoch 5/30, Loss: 1.3235


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.11it/s]


Validation Accuracy: 43.37%, , Validation Loss: 2.4177


Epoch 6/30: 100%|██████████| 137/137 [00:14<00:00,  9.40it/s]


Epoch 6/30, Loss: 0.7997


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.61it/s]


Validation Accuracy: 49.83%, , Validation Loss: 2.1048


Epoch 7/30: 100%|██████████| 137/137 [00:14<00:00,  9.41it/s]


Epoch 7/30, Loss: 0.8147


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.60it/s]


Validation Accuracy: 42.20%, , Validation Loss: 2.6326


Epoch 8/30: 100%|██████████| 137/137 [00:13<00:00,  9.87it/s]


Epoch 8/30, Loss: 0.7977


Validation: 100%|██████████| 40/40 [00:04<00:00, 10.00it/s]


Validation Accuracy: 47.94%, , Validation Loss: 2.3495


Epoch 9/30: 100%|██████████| 137/137 [00:13<00:00,  9.89it/s]


Epoch 9/30, Loss: 0.5334


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.11it/s]


Validation Accuracy: 47.11%, , Validation Loss: 2.5900


Epoch 10/30: 100%|██████████| 137/137 [00:14<00:00,  9.40it/s]


Epoch 10/30, Loss: 0.4192


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.84it/s]


Validation Accuracy: 48.95%, , Validation Loss: 2.7743


Epoch 11/30: 100%|██████████| 137/137 [00:15<00:00,  9.04it/s]


Epoch 11/30, Loss: 0.1845


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.10it/s]


Validation Accuracy: 52.11%, , Validation Loss: 2.6096


Epoch 12/30: 100%|██████████| 137/137 [00:15<00:00,  9.05it/s]


Epoch 12/30, Loss: 0.0886


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.37it/s]


Validation Accuracy: 52.37%, , Validation Loss: 2.7427


Epoch 13/30: 100%|██████████| 137/137 [00:15<00:00,  8.96it/s]


Epoch 13/30, Loss: 0.0649


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.34it/s]


Validation Accuracy: 52.50%, , Validation Loss: 2.8880


Epoch 14/30: 100%|██████████| 137/137 [00:15<00:00,  8.93it/s]


Epoch 14/30, Loss: 0.0521


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.24it/s]


Validation Accuracy: 51.98%, , Validation Loss: 3.0654


Epoch 15/30: 100%|██████████| 137/137 [00:14<00:00,  9.24it/s]


Epoch 15/30, Loss: 0.1538


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.93it/s]


Validation Accuracy: 50.25%, , Validation Loss: 3.1980


Epoch 16/30: 100%|██████████| 137/137 [00:14<00:00,  9.32it/s]


Epoch 16/30, Loss: 0.0687


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.83it/s]


Validation Accuracy: 52.37%, , Validation Loss: 3.0734


Epoch 17/30: 100%|██████████| 137/137 [00:14<00:00,  9.44it/s]


Epoch 17/30, Loss: 0.0420


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.30it/s]


Validation Accuracy: 52.76%, , Validation Loss: 3.1082


Epoch 18/30: 100%|██████████| 137/137 [00:14<00:00,  9.50it/s]


Epoch 18/30, Loss: 0.0389


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.72it/s]


Validation Accuracy: 52.21%, , Validation Loss: 3.2143


Epoch 19/30: 100%|██████████| 137/137 [00:14<00:00,  9.71it/s]


Epoch 19/30, Loss: 0.0352


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.26it/s]


Validation Accuracy: 52.81%, , Validation Loss: 3.2339


Epoch 20/30: 100%|██████████| 137/137 [00:14<00:00,  9.63it/s]


Epoch 20/30, Loss: 0.0327


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.42it/s]


Validation Accuracy: 53.05%, , Validation Loss: 3.3019


Epoch 21/30: 100%|██████████| 137/137 [00:14<00:00,  9.78it/s]


Epoch 21/30, Loss: 0.0268


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.19it/s]


Validation Accuracy: 53.09%, , Validation Loss: 3.3177


Epoch 22/30: 100%|██████████| 137/137 [00:14<00:00,  9.69it/s]


Epoch 22/30, Loss: 0.0254


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.37it/s]


Validation Accuracy: 53.30%, , Validation Loss: 3.3635


Epoch 23/30: 100%|██████████| 137/137 [00:13<00:00,  9.81it/s]


Epoch 23/30, Loss: 0.0229


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.68it/s]


Validation Accuracy: 53.25%, , Validation Loss: 3.3665


Epoch 24/30: 100%|██████████| 137/137 [00:13<00:00, 10.12it/s]


Epoch 24/30, Loss: 0.0271


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.17it/s]


Validation Accuracy: 52.78%, , Validation Loss: 3.4171


Epoch 25/30: 100%|██████████| 137/137 [00:13<00:00,  9.82it/s]


Epoch 25/30, Loss: 0.0251


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.61it/s]


Validation Accuracy: 52.97%, , Validation Loss: 3.4577


Epoch 26/30: 100%|██████████| 137/137 [00:13<00:00,  9.85it/s]


Epoch 26/30, Loss: 0.0227


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.21it/s]


Validation Accuracy: 53.13%, , Validation Loss: 3.4523


Epoch 27/30: 100%|██████████| 137/137 [00:14<00:00,  9.53it/s]


Epoch 27/30, Loss: 0.0208


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.34it/s]


Validation Accuracy: 52.94%, , Validation Loss: 3.4642


Epoch 28/30: 100%|██████████| 137/137 [00:14<00:00,  9.68it/s]


Epoch 28/30, Loss: 0.0201


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.19it/s]


Validation Accuracy: 53.06%, , Validation Loss: 3.4610


Epoch 29/30: 100%|██████████| 137/137 [00:14<00:00,  9.44it/s]


Epoch 29/30, Loss: 0.0201


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.21it/s]


Validation Accuracy: 53.01%, , Validation Loss: 3.4938


Epoch 30/30: 100%|██████████| 137/137 [00:13<00:00,  9.97it/s]


Epoch 30/30, Loss: 0.0193


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.01it/s]


Validation Accuracy: 53.20%, , Validation Loss: 3.5055


In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Evaluating: 100%|██████████| 20/20 [00:02<00:00,  7.93it/s]

Test Accuracy: 52.56%





# Third Run
Notes:
1.   Changed image size back to 32
2.   Changed learning rate from 0.003 to 0.001

In [None]:
# new run
import torch
import torch.nn as nn
from torchvision import models, transforms
# from torchvision.models import resnet50, ResNet50_Weights
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.001 #1e-4
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 32
# max_len = 100 # All sequences must be less than 1000 including class token
# channels = 3

num_epochs = 30

# Model, loss, and optimizer
# weights = ResNet50_Weights.DEFAULT
weights = MobileNet_V3_Large_Weights.DEFAULT

model = mobilenet_v3_large(weights=weights).to(device)
# model = resnet50(weights=weights).to(device)
#model = resnet50(num_classes=num_classes, image_size=image_size, patch_size=patch_size, num_layers=num_layers, num_heads=num_heads, hidden_dim=hidden_dim, mlp_dim=mlp_dim).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Optional: learning rate scheduler

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optional: Adjust learning rate
    scheduler.step()

    # Print average loss per epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, , Validation Loss: {average_val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "mobilenet_v3_large.pth")

Epoch 1/30: 100%|██████████| 137/137 [00:14<00:00,  9.66it/s]


Epoch 1/30, Loss: 4.1013


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.01it/s]


Validation Accuracy: 22.97%, , Validation Loss: 3.1913


Epoch 2/30: 100%|██████████| 137/137 [00:15<00:00,  9.12it/s]


Epoch 2/30, Loss: 2.4232


Validation: 100%|██████████| 40/40 [00:04<00:00,  8.94it/s]


Validation Accuracy: 36.99%, , Validation Loss: 2.4605


Epoch 3/30: 100%|██████████| 137/137 [00:16<00:00,  8.54it/s]


Epoch 3/30, Loss: 1.8801


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.63it/s]


Validation Accuracy: 42.21%, , Validation Loss: 2.2674


Epoch 4/30: 100%|██████████| 137/137 [00:15<00:00,  8.88it/s]


Epoch 4/30, Loss: 1.4567


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.34it/s]


Validation Accuracy: 44.90%, , Validation Loss: 2.2470


Epoch 5/30: 100%|██████████| 137/137 [00:15<00:00,  8.71it/s]


Epoch 5/30, Loss: 1.1283


Validation: 100%|██████████| 40/40 [00:04<00:00,  8.33it/s]


Validation Accuracy: 44.78%, , Validation Loss: 2.3492


Epoch 6/30: 100%|██████████| 137/137 [00:14<00:00,  9.34it/s]


Epoch 6/30, Loss: 0.6485


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.79it/s]


Validation Accuracy: 47.54%, , Validation Loss: 2.2832


Epoch 7/30: 100%|██████████| 137/137 [00:15<00:00,  8.86it/s]


Epoch 7/30, Loss: 0.4125


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.84it/s]


Validation Accuracy: 47.42%, , Validation Loss: 2.4485


Epoch 8/30: 100%|██████████| 137/137 [00:16<00:00,  8.12it/s]


Epoch 8/30, Loss: 0.2926


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.03it/s]


Validation Accuracy: 46.92%, , Validation Loss: 2.7213


Epoch 9/30: 100%|██████████| 137/137 [00:15<00:00,  8.58it/s]


Epoch 9/30, Loss: 0.2315


Validation: 100%|██████████| 40/40 [00:04<00:00,  8.75it/s]


Validation Accuracy: 46.72%, , Validation Loss: 2.8548


Epoch 10/30: 100%|██████████| 137/137 [00:16<00:00,  8.47it/s]


Epoch 10/30, Loss: 0.2351


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.69it/s]


Validation Accuracy: 45.24%, , Validation Loss: 3.0821


Epoch 11/30: 100%|██████████| 137/137 [00:14<00:00,  9.17it/s]


Epoch 11/30, Loss: 0.1577


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.48it/s]


Validation Accuracy: 47.79%, , Validation Loss: 2.9541


Epoch 12/30: 100%|██████████| 137/137 [00:14<00:00,  9.18it/s]


Epoch 12/30, Loss: 0.0929


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.06it/s]


Validation Accuracy: 48.22%, , Validation Loss: 3.0175


Epoch 13/30: 100%|██████████| 137/137 [00:13<00:00, 10.13it/s]


Epoch 13/30, Loss: 0.0726


Validation: 100%|██████████| 40/40 [00:05<00:00,  7.31it/s]


Validation Accuracy: 48.16%, , Validation Loss: 3.0680


Epoch 14/30: 100%|██████████| 137/137 [00:16<00:00,  8.36it/s]


Epoch 14/30, Loss: 0.0626


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.63it/s]


Validation Accuracy: 48.09%, , Validation Loss: 3.1922


Epoch 15/30: 100%|██████████| 137/137 [00:14<00:00,  9.43it/s]


Epoch 15/30, Loss: 0.0578


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.38it/s]


Validation Accuracy: 47.75%, , Validation Loss: 3.2361


Epoch 16/30: 100%|██████████| 137/137 [00:14<00:00,  9.22it/s]


Epoch 16/30, Loss: 0.0483


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.91it/s]


Validation Accuracy: 48.58%, , Validation Loss: 3.2163


Epoch 17/30: 100%|██████████| 137/137 [00:13<00:00,  9.85it/s]


Epoch 17/30, Loss: 0.0415


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.52it/s]


Validation Accuracy: 48.76%, , Validation Loss: 3.2493


Epoch 18/30: 100%|██████████| 137/137 [00:13<00:00,  9.81it/s]


Epoch 18/30, Loss: 0.0386


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.43it/s]


Validation Accuracy: 48.66%, , Validation Loss: 3.2619


Epoch 19/30: 100%|██████████| 137/137 [00:14<00:00,  9.64it/s]


Epoch 19/30, Loss: 0.0397


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.05it/s]


Validation Accuracy: 48.63%, , Validation Loss: 3.3248


Epoch 20/30: 100%|██████████| 137/137 [00:14<00:00,  9.54it/s]


Epoch 20/30, Loss: 0.0384


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.17it/s]


Validation Accuracy: 48.64%, , Validation Loss: 3.3382


Epoch 21/30: 100%|██████████| 137/137 [00:14<00:00,  9.73it/s]


Epoch 21/30, Loss: 0.0344


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.40it/s]


Validation Accuracy: 48.89%, , Validation Loss: 3.3315


Epoch 22/30: 100%|██████████| 137/137 [00:14<00:00,  9.71it/s]


Epoch 22/30, Loss: 0.0330


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.15it/s]


Validation Accuracy: 49.15%, , Validation Loss: 3.3577


Epoch 23/30: 100%|██████████| 137/137 [00:13<00:00, 10.18it/s]


Epoch 23/30, Loss: 0.0314


Validation: 100%|██████████| 40/40 [00:04<00:00,  8.95it/s]


Validation Accuracy: 49.08%, , Validation Loss: 3.3595


Epoch 24/30: 100%|██████████| 137/137 [00:13<00:00,  9.88it/s]


Epoch 24/30, Loss: 0.0317


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.16it/s]


Validation Accuracy: 49.06%, , Validation Loss: 3.3905


Epoch 25/30: 100%|██████████| 137/137 [00:14<00:00,  9.71it/s]


Epoch 25/30, Loss: 0.0306


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.27it/s]


Validation Accuracy: 49.07%, , Validation Loss: 3.3962


Epoch 26/30: 100%|██████████| 137/137 [00:14<00:00,  9.62it/s]


Epoch 26/30, Loss: 0.0311


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.52it/s]


Validation Accuracy: 49.15%, , Validation Loss: 3.3989


Epoch 27/30: 100%|██████████| 137/137 [00:13<00:00,  9.79it/s]


Epoch 27/30, Loss: 0.0301


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.63it/s]


Validation Accuracy: 48.99%, , Validation Loss: 3.4073


Epoch 28/30: 100%|██████████| 137/137 [00:14<00:00,  9.77it/s]


Epoch 28/30, Loss: 0.0292


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.98it/s]


Validation Accuracy: 48.99%, , Validation Loss: 3.4174


Epoch 29/30: 100%|██████████| 137/137 [00:13<00:00, 10.20it/s]


Epoch 29/30, Loss: 0.0295


Validation: 100%|██████████| 40/40 [00:04<00:00,  8.76it/s]


Validation Accuracy: 49.16%, , Validation Loss: 3.4122


Epoch 30/30: 100%|██████████| 137/137 [00:14<00:00,  9.37it/s]


Epoch 30/30, Loss: 0.0288


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.26it/s]


Validation Accuracy: 49.26%, , Validation Loss: 3.4240


In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Evaluating: 100%|██████████| 20/20 [00:01<00:00, 12.25it/s]

Test Accuracy: 48.68%



