In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Data pre-processing transformations
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert('RGB')),  # Convert grayscale to RGB
    transforms.Resize((128, 128)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
])

# Download Caltech101 dataset
dataset = datasets.Caltech101(root='./data', download=True, transform=transform)

# Split into train, validation, and test sets
train_size = int(0.7 * len(dataset))
val_size = int(0.2 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Initialize data loaders
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False)

Files already downloaded and verified


In [2]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.optim import lr_scheduler

# Define the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load pre-trained VGG-19 with batch normalization
model = models.efficientnet_b7(weights=None)
model.load_state_dict(torch.load('pretrained_efficientnet.pth'))
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 101)

# Move the model to the device
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
#optimizer = optim.SGD(model.parameters(), lr=0.00148, momentum=0.9395, weight_decay=0.001415, dampening=0.00289)
optimizer = optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.01)
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1, verbose=True)

# Train and validate the model
num_epochs = 20

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    # Training phase
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)  # Move data to the device
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    train_loss = running_loss / len(train_loader)
    print(f"Training Loss: {train_loss}")
    
    # Validation phase
    model.eval()
    val_running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)  # Move data to the device
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    val_loss = val_running_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracy}%")
    scheduler.step(val_loss)

Epoch 1/20


Training: 100%|█████████████████████████████████| 12/12 [00:17<00:00,  1.48s/it]


Training Loss: 4.533093849817912


Validation: 100%|█████████████████████████████████| 4/4 [00:04<00:00,  1.13s/it]


Validation Loss: 9.539034366607666, Validation Accuracy: 0.5763688760806917%
Epoch 2/20


Training: 100%|█████████████████████████████████| 12/12 [00:17<00:00,  1.44s/it]


Training Loss: 4.174160083134969


Validation: 100%|█████████████████████████████████| 4/4 [00:03<00:00,  1.17it/s]


Validation Loss: 9.67212963104248, Validation Accuracy: 0.345821325648415%
Epoch 3/20


Training:  25%|████████▌                         | 3/12 [00:04<00:12,  1.41s/it]


KeyboardInterrupt: 