In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

# Data loading and splitting
# Builds the train / validation / test DataLoaders consumed by the training
# and evaluation cells further down the notebook.

# Seed for the dataset partition: without it `random_split` draws a new
# permutation on every kernel restart, so accuracies are not comparable
# between runs.
SPLIT_SEED = 42

# ToTensor scales pixel values to [0, 1]; Normalize with mean 0.5 / std 0.5 per
# channel then rescales them to [-1, 1]. These are generic constants, not
# statistics computed from CIFAR-100.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# NOTE: all three splits are carved out of the official *training* set
# (train=True); the official CIFAR-100 test set is not loaded here.
full_dataset = CIFAR100(root="./data", train=True, download=True, transform=transform)

# Split the dataset into training (70%), validation (20%), and testing (remainder) sets
num_samples = len(full_dataset)
train_size = int(0.7 * num_samples)
val_size = int(0.2 * num_samples)
test_size = num_samples - train_size - val_size  # remainder, so the sizes always sum to num_samples

# A dedicated generator keeps the split deterministic without touching the
# global RNG state.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)

batch_size = 256

# Only the training loader is shuffled; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:03<00:00, 48358672.43it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


# First Run

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torchvision.models import efficientnet
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Fine-tunes an ImageNet-pretrained EfficientNet-B1 using `train_loader` and
# `val_loader` (both defined in the data-loading cell), printing the training
# loss and validation accuracy/loss after every epoch.

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG so the randomly initialised classifier head (and any
# other torch-level randomness) is the same after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
num_epochs = 30

# NOTE: the values below are not read by anything in this cell (EfficientNet
# takes none of them). They are kept only so that other cells referencing
# these names keep working.
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 32

# Load the same weights that the deprecated `pretrained=True` flag selected
# (IMAGENET1K_V1), via the explicit weights enum.
model = efficientnet.efficientnet_b1(
    weights=efficientnet.EfficientNet_B1_Weights.IMAGENET1K_V1
)

# The pretrained head emits 1000 ImageNet logits; swap the final linear layer
# for one sized to `num_classes` so the network predicts only the dataset's labels.
head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(head_in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the learning rate every 5 epochs

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Adjust learning rate once per epoch
    scheduler.step()

    # Print average loss per epoch (mean of the per-batch mean losses)
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, Validation Loss: {average_val_loss:.4f}")

# Save the trained model
# NOTE(review): the file name says "resnet" but the weights are EfficientNet-B1,
# and the second-run cell writes to the same path. The name is kept so anything
# that loads this path still finds it; consider renaming.
torch.save(model.state_dict(), "resnet_model.pth")

Epoch 1/30: 100%|██████████| 137/137 [00:25<00:00,  5.37it/s]


Epoch 1/30, Loss: 3.8932


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.05it/s]


Validation Accuracy: 32.72%, , Validation Loss: 2.6497


Epoch 2/30: 100%|██████████| 137/137 [00:17<00:00,  8.00it/s]


Epoch 2/30, Loss: 2.3633


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.34it/s]


Validation Accuracy: 42.32%, , Validation Loss: 2.1832


Epoch 3/30: 100%|██████████| 137/137 [00:17<00:00,  7.68it/s]


Epoch 3/30, Loss: 1.8892


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.32it/s]


Validation Accuracy: 47.62%, , Validation Loss: 1.9762


Epoch 4/30: 100%|██████████| 137/137 [00:16<00:00,  8.28it/s]


Epoch 4/30, Loss: 1.6221


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.45it/s]


Validation Accuracy: 48.24%, , Validation Loss: 1.9508


Epoch 5/30: 100%|██████████| 137/137 [00:17<00:00,  7.99it/s]


Epoch 5/30, Loss: 1.3959


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.63it/s]


Validation Accuracy: 50.79%, , Validation Loss: 1.9753


Epoch 6/30: 100%|██████████| 137/137 [00:16<00:00,  8.08it/s]


Epoch 6/30, Loss: 0.9208


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.93it/s]


Validation Accuracy: 55.61%, , Validation Loss: 1.7938


Epoch 7/30: 100%|██████████| 137/137 [00:16<00:00,  8.10it/s]


Epoch 7/30, Loss: 0.6823


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.01it/s]


Validation Accuracy: 55.13%, , Validation Loss: 1.9748


Epoch 8/30: 100%|██████████| 137/137 [00:17<00:00,  7.77it/s]


Epoch 8/30, Loss: 0.5770


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.08it/s]


Validation Accuracy: 55.50%, , Validation Loss: 2.0188


Epoch 9/30: 100%|██████████| 137/137 [00:17<00:00,  8.05it/s]


Epoch 9/30, Loss: 0.4943


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.90it/s]


Validation Accuracy: 54.35%, , Validation Loss: 2.1271


Epoch 10/30: 100%|██████████| 137/137 [00:17<00:00,  7.70it/s]


Epoch 10/30, Loss: 0.4531


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.46it/s]


Validation Accuracy: 54.68%, , Validation Loss: 2.2667


Epoch 11/30: 100%|██████████| 137/137 [00:16<00:00,  8.09it/s]


Epoch 11/30, Loss: 0.2617


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.64it/s]


Validation Accuracy: 57.24%, , Validation Loss: 2.2693


Epoch 12/30: 100%|██████████| 137/137 [00:16<00:00,  8.43it/s]


Epoch 12/30, Loss: 0.1486


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.48it/s]


Validation Accuracy: 57.10%, , Validation Loss: 2.4758


Epoch 13/30: 100%|██████████| 137/137 [00:16<00:00,  8.10it/s]


Epoch 13/30, Loss: 0.1252


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.84it/s]


Validation Accuracy: 57.18%, , Validation Loss: 2.6352


Epoch 14/30: 100%|██████████| 137/137 [00:16<00:00,  8.19it/s]


Epoch 14/30, Loss: 0.1063


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.83it/s]


Validation Accuracy: 56.64%, , Validation Loss: 2.7148


Epoch 15/30: 100%|██████████| 137/137 [00:16<00:00,  8.35it/s]


Epoch 15/30, Loss: 0.1072


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.63it/s]


Validation Accuracy: 56.46%, , Validation Loss: 2.8211


Epoch 16/30: 100%|██████████| 137/137 [00:16<00:00,  8.41it/s]


Epoch 16/30, Loss: 0.0796


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.39it/s]


Validation Accuracy: 58.04%, , Validation Loss: 2.8023


Epoch 17/30: 100%|██████████| 137/137 [00:16<00:00,  8.45it/s]


Epoch 17/30, Loss: 0.0516


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.40it/s]


Validation Accuracy: 57.71%, , Validation Loss: 2.8173


Epoch 18/30: 100%|██████████| 137/137 [00:16<00:00,  8.08it/s]


Epoch 18/30, Loss: 0.0447


Validation: 100%|██████████| 40/40 [00:03<00:00, 11.58it/s]


Validation Accuracy: 57.61%, , Validation Loss: 2.8789


Epoch 19/30: 100%|██████████| 137/137 [00:16<00:00,  8.43it/s]


Epoch 19/30, Loss: 0.0400


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.63it/s]


Validation Accuracy: 57.70%, , Validation Loss: 2.9060


Epoch 20/30: 100%|██████████| 137/137 [00:16<00:00,  8.41it/s]


Epoch 20/30, Loss: 0.0345


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.25it/s]


Validation Accuracy: 58.01%, , Validation Loss: 3.0005


Epoch 21/30: 100%|██████████| 137/137 [00:16<00:00,  8.11it/s]


Epoch 21/30, Loss: 0.0328


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.35it/s]


Validation Accuracy: 58.17%, , Validation Loss: 2.9845


Epoch 22/30: 100%|██████████| 137/137 [00:16<00:00,  8.37it/s]


Epoch 22/30, Loss: 0.0252


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.19it/s]


Validation Accuracy: 58.03%, , Validation Loss: 3.0390


Epoch 23/30: 100%|██████████| 137/137 [00:16<00:00,  8.15it/s]


Epoch 23/30, Loss: 0.0226


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.07it/s]


Validation Accuracy: 58.00%, , Validation Loss: 3.0353


Epoch 24/30: 100%|██████████| 137/137 [00:16<00:00,  8.50it/s]


Epoch 24/30, Loss: 0.0231


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.66it/s]


Validation Accuracy: 57.85%, , Validation Loss: 3.0353


Epoch 25/30: 100%|██████████| 137/137 [00:16<00:00,  8.33it/s]


Epoch 25/30, Loss: 0.0220


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.09it/s]


Validation Accuracy: 57.72%, , Validation Loss: 3.0844


Epoch 26/30: 100%|██████████| 137/137 [00:17<00:00,  7.87it/s]


Epoch 26/30, Loss: 0.0165


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.40it/s]


Validation Accuracy: 57.87%, , Validation Loss: 3.0799


Epoch 27/30: 100%|██████████| 137/137 [00:16<00:00,  8.37it/s]


Epoch 27/30, Loss: 0.0157


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.84it/s]


Validation Accuracy: 58.03%, , Validation Loss: 3.1109


Epoch 28/30: 100%|██████████| 137/137 [00:16<00:00,  8.35it/s]


Epoch 28/30, Loss: 0.0144


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.16it/s]


Validation Accuracy: 58.18%, , Validation Loss: 3.1043


Epoch 29/30: 100%|██████████| 137/137 [00:16<00:00,  8.25it/s]


Epoch 29/30, Loss: 0.0132


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.31it/s]


Validation Accuracy: 58.09%, , Validation Loss: 3.1198


Epoch 30/30: 100%|██████████| 137/137 [00:16<00:00,  8.38it/s]


Epoch 30/30, Loss: 0.0143


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.23it/s]


Validation Accuracy: 58.25%, , Validation Loss: 3.1211


In [None]:
# Score the current `model` on the batches served by `test_loader` and report
# top-1 accuracy.
model.eval()  # switch the module to evaluation mode
total = 0
correct = 0

with torch.no_grad():  # no gradients are needed for scoring
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        # Max over dim 1 yields (values, indices); the indices are the predictions.
        _, predicted = outputs.max(dim=1)
        correct += predicted.eq(labels).sum().item()
        total += labels.shape[0]

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Evaluating: 100%|██████████| 20/20 [00:02<00:00,  8.49it/s]

Test Accuracy: 58.80%





# Second Run

In [None]:
# new run
import torch
import torch.nn as nn
from torchvision import models, transforms
# Imported here so this cell does not depend on the first-run cell having been
# executed earlier in the same kernel.
from torchvision.models import efficientnet
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Second fine-tuning run: trains an ImageNet-pretrained EfficientNet-B1 using
# `train_loader` and `val_loader` (both defined in the data-loading cell),
# printing the training loss and validation accuracy/loss after every epoch.

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the global RNG so the randomly initialised classifier head (and any
# other torch-level randomness) is the same after a kernel restart.
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
num_epochs = 30

# NOTE: the values below are not read by anything in this cell (EfficientNet
# takes none of them). They are kept only so that other cells referencing
# these names keep working. In particular `image_size = 224` does NOT resize
# the inputs: this cell applies no resize transform.
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 224

# Model, loss, and optimizer
# Load the same weights that the deprecated `pretrained=True` flag selected
# (IMAGENET1K_V1), via the explicit weights enum.
model = efficientnet.efficientnet_b1(
    weights=efficientnet.EfficientNet_B1_Weights.IMAGENET1K_V1
)

# The pretrained head emits 1000 ImageNet logits; swap the final linear layer
# for one sized to `num_classes` so the network predicts only the dataset's labels.
head_in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(head_in_features, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halve the learning rate every 5 epochs

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Adjust learning rate once per epoch
    scheduler.step()

    # Print average loss per epoch (mean of the per-batch mean losses)
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss = criterion(outputs, labels)

            total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, Validation Loss: {average_val_loss:.4f}")

# Save the trained model
# NOTE(review): the file name says "resnet" but the weights are EfficientNet-B1,
# and this overwrites the checkpoint written by the first-run cell. The name is
# kept so anything that loads this path still finds it; consider renaming.
torch.save(model.state_dict(), "resnet_model.pth")

Epoch 1/30: 100%|██████████| 137/137 [00:13<00:00, 10.02it/s]


Epoch 1/30, Loss: 3.5275


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.47it/s]


Validation Accuracy: 21.97%, , Validation Loss: 3.6388


Epoch 2/30: 100%|██████████| 137/137 [00:13<00:00, 10.10it/s]


Epoch 2/30, Loss: 2.2330


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.15it/s]


Validation Accuracy: 39.51%, , Validation Loss: 2.4388


Epoch 3/30: 100%|██████████| 137/137 [00:12<00:00, 10.60it/s]


Epoch 3/30, Loss: 1.8019


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.48it/s]


Validation Accuracy: 40.42%, , Validation Loss: 2.4487


Epoch 4/30: 100%|██████████| 137/137 [00:14<00:00,  9.59it/s]


Epoch 4/30, Loss: 1.6084


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.82it/s]


Validation Accuracy: 46.05%, , Validation Loss: 2.1662


Epoch 5/30: 100%|██████████| 137/137 [00:15<00:00,  8.75it/s]


Epoch 5/30, Loss: 1.3235


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.11it/s]


Validation Accuracy: 43.37%, , Validation Loss: 2.4177


Epoch 6/30: 100%|██████████| 137/137 [00:14<00:00,  9.40it/s]


Epoch 6/30, Loss: 0.7997


Validation: 100%|██████████| 40/40 [00:02<00:00, 13.61it/s]


Validation Accuracy: 49.83%, , Validation Loss: 2.1048


Epoch 7/30: 100%|██████████| 137/137 [00:14<00:00,  9.41it/s]


Epoch 7/30, Loss: 0.8147


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.60it/s]


Validation Accuracy: 42.20%, , Validation Loss: 2.6326


Epoch 8/30: 100%|██████████| 137/137 [00:13<00:00,  9.87it/s]


Epoch 8/30, Loss: 0.7977


Validation: 100%|██████████| 40/40 [00:04<00:00, 10.00it/s]


Validation Accuracy: 47.94%, , Validation Loss: 2.3495


Epoch 9/30: 100%|██████████| 137/137 [00:13<00:00,  9.89it/s]


Epoch 9/30, Loss: 0.5334


Validation: 100%|██████████| 40/40 [00:03<00:00, 13.11it/s]


Validation Accuracy: 47.11%, , Validation Loss: 2.5900


Epoch 10/30: 100%|██████████| 137/137 [00:14<00:00,  9.40it/s]


Epoch 10/30, Loss: 0.4192


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.84it/s]


Validation Accuracy: 48.95%, , Validation Loss: 2.7743


Epoch 11/30: 100%|██████████| 137/137 [00:15<00:00,  9.04it/s]


Epoch 11/30, Loss: 0.1845


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.10it/s]


Validation Accuracy: 52.11%, , Validation Loss: 2.6096


Epoch 12/30: 100%|██████████| 137/137 [00:15<00:00,  9.05it/s]


Epoch 12/30, Loss: 0.0886


Validation: 100%|██████████| 40/40 [00:03<00:00, 10.37it/s]


Validation Accuracy: 52.37%, , Validation Loss: 2.7427


Epoch 13/30: 100%|██████████| 137/137 [00:15<00:00,  8.96it/s]


Epoch 13/30, Loss: 0.0649


Validation: 100%|██████████| 40/40 [00:04<00:00,  9.34it/s]


Validation Accuracy: 52.50%, , Validation Loss: 2.8880


Epoch 14/30: 100%|██████████| 137/137 [00:15<00:00,  8.93it/s]


Epoch 14/30, Loss: 0.0521


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.24it/s]


Validation Accuracy: 51.98%, , Validation Loss: 3.0654


Epoch 15/30: 100%|██████████| 137/137 [00:14<00:00,  9.24it/s]


Epoch 15/30, Loss: 0.1538


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.93it/s]


Validation Accuracy: 50.25%, , Validation Loss: 3.1980


Epoch 16/30: 100%|██████████| 137/137 [00:14<00:00,  9.32it/s]


Epoch 16/30, Loss: 0.0687


Validation: 100%|██████████| 40/40 [00:03<00:00, 12.83it/s]


Validation Accuracy: 52.37%, , Validation Loss: 3.0734


Epoch 17/30:  90%|████████▉ | 123/137 [00:13<00:01, 12.08it/s]