# 🏁 Quick Start: EfficientNet Fine-tuning on CPU / Apple Silicon (MPS)

In [58]:
import os
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from torch.optim import Adam
from sklearn.metrics import classification_report
import timm


In [59]:
# Config: every value a reader might want to tune lives in this cell.
seed = 42  # fixed so shuffling, augmentation and head initialisation repeat across runs
data_root = Path("../../../data/out_data_split")
train_dir = data_root / "train"
val_dir = data_root / "val"
batch_size = 32
num_epochs = 20  # Keep it short for first run
learning_rate = 1e-4

# Seed torch's global RNG before any dataset, loader or model is created.
# NOTE(review): this only covers torch's generator; if any transform draws from
# Python's `random` module, seed that as well.
torch.manual_seed(seed)

# Prefer the Apple-Silicon GPU backend when present, otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)


Using device: mps


In [60]:
# Standard ImageNet per-channel mean / std, shared by both pipelines.
# NOTE(review): assumed to match what the pretrained backbone expects — confirm
# against the timm model's pretrained config.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop + mild colour jitter as augmentation.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Validation pipeline: deterministic resize + centre crop, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

train_dataset = datasets.ImageFolder(str(train_dir), transform=train_transforms)
val_dataset = datasets.ImageFolder(str(val_dir), transform=val_transforms)

# Each ImageFolder builds its own folder-name -> index mapping. If the two
# splits do not contain exactly the same class folders, the integer labels
# mean different things in train and val and every validation metric is
# silently wrong, so fail loudly here instead.
if val_dataset.class_to_idx != train_dataset.class_to_idx:
    raise ValueError(
        "train and val splits disagree on class folders: "
        f"{sorted(set(train_dataset.classes) ^ set(val_dataset.classes))}"
    )

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

class_names = train_dataset.classes
print(f"Classes: {class_names}")


Classes: ['4011', '4015', '4088', '4196', '7020097009819', '7020097026113', '7023026089401', '7035620058776', '7037203626563', '7037206100022', '7038010009457', '7038010013966', '7038010021145', '7038010054488', '7038010068980', '7039610000318', '7040513000022', '7040513001753', '7040913336684', '7044610874661', '7048840205868', '7071688004713', '7622210410337', '90433917', '90433924', '94011']


In [61]:
# Start from a pretrained EfficientNet-B0.
model = timm.create_model("efficientnet_b0", pretrained=True)

# Turn off gradients for every existing parameter so the backbone stays fixed.
model.requires_grad_(False)

# Swap the original head for a fresh linear layer sized to our class list,
# and make sure it is the only part that receives gradients.
model.classifier = nn.Linear(model.classifier.in_features, len(class_names))
model.classifier.requires_grad_(True)

model = model.to(device)


In [62]:
# Cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Adam is given only the classifier head's parameters.
optimizer = Adam(params=model.classifier.parameters(), lr=learning_rate)

# Multiply the learning rate by 0.5 after every 5 scheduler steps.
scheduler = StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

def train_one_epoch(model, loader, *, loss_fn=None, optim=None, target_device=None):
    """Run one optimisation pass over ``loader`` and report mean loss and accuracy.

    Args:
        model: Network to train; it is switched to training mode for the pass.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
            Assumed to return the mean loss over the batch.
        optim: Optimizer to step. Defaults to the notebook-level ``optimizer``.
        target_device: Device the batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(mean_loss, accuracy)`` — both averaged over the samples actually
        seen, so a smaller final batch is weighted correctly and loaders that
        do not visit the whole dataset still report a true ratio.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    optim = optimizer if optim is None else optim
    target_device = device if target_device is None else target_device

    # NOTE(review): train mode also lets normalisation layers update their
    # running statistics and enables dropout, even for parameters whose
    # requires_grad is False — confirm this is intended for a frozen backbone.
    model.train()

    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(target_device), y.to(target_device)
        optim.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optim.step()

        batch_n = y.size(0)
        # Undo the per-batch mean so the epoch figure is a per-sample mean.
        loss_sum += loss.item() * batch_n
        correct += (out.argmax(1) == y).sum().item()
        seen += batch_n

    if seen == 0:
        raise ValueError("loader produced no samples")
    return loss_sum / seen, correct / seen

def validate(model, loader, *, loss_fn=None, target_device=None):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
            Assumed to return the mean loss over the batch.
        target_device: Device the batches are moved to. Defaults to the
            notebook-level ``device``.

    Returns:
        ``(mean_loss, accuracy)`` — both averaged over the samples actually
        seen, so a smaller final batch does not skew the loss.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    loss_fn = criterion if loss_fn is None else loss_fn
    target_device = device if target_device is None else target_device

    model.eval()

    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(target_device), y.to(target_device)
            out = model(x)
            loss = loss_fn(out, y)

            batch_n = y.size(0)
            # Undo the per-batch mean so the result is a per-sample mean.
            loss_sum += loss.item() * batch_n
            correct += (out.argmax(1) == y).sum().item()
            seen += batch_n

    if seen == 0:
        raise ValueError("loader produced no samples")
    return loss_sum / seen, correct / seen


In [63]:
# Highest validation accuracy seen so far; drives checkpointing below.
best_val_acc = 0

for epoch in range(num_epochs):
    # One pass over each split. The losses are kept in variables for later
    # inspection even though only the accuracies are logged.
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = validate(model, val_loader)

    # Read the rate that was in effect for this epoch (before the scheduler moves it).
    current_lr = optimizer.param_groups[0]['lr']

    print(f"Epoch {epoch+1}/{num_epochs} - Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")
    print(f"Current learning rate: {current_lr:.6f}")

    # Checkpoint only on a strict improvement in validation accuracy.
    if best_val_acc < val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "efficientnet_best.pth")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()


Epoch 1/20 - Train Acc: 0.1499, Val Acc: 0.3697
Current learning rate: 0.000100
Epoch 2/20 - Train Acc: 0.3906, Val Acc: 0.5530
Current learning rate: 0.000100
Epoch 3/20 - Train Acc: 0.5077, Val Acc: 0.6500
Current learning rate: 0.000100
Epoch 4/20 - Train Acc: 0.5850, Val Acc: 0.6985
Current learning rate: 0.000100
Epoch 5/20 - Train Acc: 0.6476, Val Acc: 0.7561
Current learning rate: 0.000100
Epoch 6/20 - Train Acc: 0.6789, Val Acc: 0.7758
Current learning rate: 0.000050
Epoch 7/20 - Train Acc: 0.7063, Val Acc: 0.7939
Current learning rate: 0.000050
Epoch 8/20 - Train Acc: 0.6963, Val Acc: 0.8030
Current learning rate: 0.000050
Epoch 9/20 - Train Acc: 0.7500, Val Acc: 0.8318
Current learning rate: 0.000050
Epoch 10/20 - Train Acc: 0.7415, Val Acc: 0.8348
Current learning rate: 0.000050
Epoch 11/20 - Train Acc: 0.7477, Val Acc: 0.8364
Current learning rate: 0.000025
Epoch 12/20 - Train Acc: 0.7519, Val Acc: 0.8439
Current learning rate: 0.000025
Epoch 13/20 - Train Acc: 0.7693, Val 

In [65]:
# Load best model and evaluate.
# map_location keeps the load working even if the checkpoint tensors were
# saved from a different device than the one in use now.
# NOTE(review): torch.load unpickles the file — only load checkpoints you
# produced yourself.
model.load_state_dict(torch.load("efficientnet_best.pth", map_location=device))
model.eval()

all_preds, all_labels = [], []

with torch.no_grad():
    for x, y in val_loader:
        x = x.to(device)
        out = model(x)
        preds = out.argmax(1).cpu()
        # Store plain Python ints, not 0-dim tensors, so scikit-learn receives
        # ordinary label sequences.
        all_preds.extend(preds.tolist())
        all_labels.extend(y.tolist())

# Passing `labels` pins row i of the report to class_names[i] and avoids a
# length-mismatch error when some class never appears in this split.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(class_names))),
    target_names=class_names,
))


               precision    recall  f1-score   support

         4011       1.00      0.71      0.83        24
         4015       0.91      0.88      0.90        49
         4088       0.87      0.94      0.91        36
         4196       0.96      1.00      0.98        48
7020097009819       0.71      1.00      0.83        37
7020097026113       1.00      0.54      0.70        13
7023026089401       0.95      1.00      0.98        20
7035620058776       1.00      0.67      0.80         6
7037203626563       1.00      0.20      0.33        10
7037206100022       0.82      1.00      0.90        33
7038010009457       0.93      1.00      0.97        14
7038010013966       0.94      0.88      0.91        33
7038010021145       1.00      0.93      0.96        14
7038010054488       0.85      0.71      0.77        24
7038010068980       0.92      1.00      0.96        33
7039610000318       0.89      1.00      0.94        24
7040513000022       0.84      0.96      0.90        28
704051300