## 📦 Imports and Environment Setup

In [1]:
import os
from pathlib import Path
import timm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim import Adam
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt


## ⚙️ Configuration

In [2]:
# Adjust these paths as needed.
# `data_root` is the folder that directly contains the 'train' and 'val'
# subfolders. (Previously it pointed at "../out_data" and the split paths
# climbed back out of it with "..", which only resolves when that unrelated
# directory exists.)
data_root = Path("../out_data_split")
train_dir = data_root / "train"
val_dir = data_root / "val"

# Training hyper-parameters.
batch_size = 32
num_epochs = 10
learning_rate = 1e-4


## 🧪 Data Transforms and Loaders

In [3]:
# Per-channel mean / std used to normalise inputs for both splits.
normalize_mean = [0.485, 0.456, 0.406]
normalize_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/rotation/colour jitter, then normalise.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

# Validation pipeline: deterministic resize + centre crop, same normalisation.
val_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(normalize_mean, normalize_std),
])

train_dataset = datasets.ImageFolder(str(train_dir), transform=train_transforms)
val_dataset = datasets.ImageFolder(str(val_dir), transform=val_transforms)

# ImageFolder derives label indices from each dataset's own sorted subfolder
# names. If the two splits do not contain exactly the same class folders, the
# same index means different classes in train and val, and every validation
# metric is silently wrong. Fail loudly instead.
if val_dataset.classes != train_dataset.classes:
    missing_in_val = sorted(set(train_dataset.classes) - set(val_dataset.classes))
    missing_in_train = sorted(set(val_dataset.classes) - set(train_dataset.classes))
    raise ValueError(
        "train/val class folders differ; "
        f"only in train: {missing_in_val}, only in val: {missing_in_train}"
    )

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

class_names = train_dataset.classes
print("Detected classes:", class_names)


Detected classes: ['4011', '4015', '4088', '4196', '7020097009819', '7020097026113', '7023026089401', '7035620058776', '7037203626563', '7037206100022', '7038010009457', '7038010013966', '7038010021145', '7038010054488', '7038010068980', '7039610000318', '7040513000022', '7040513001753', '7040913336684', '7044610874661', '7048840205868', '7071688004713', '7622210410337', '90433917', '90433924', '94011']


## 🧠 Model Setup

In [4]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained EfficientNet-B0 and replace its classification head with
# a fresh linear layer that has one output per detected class.
backbone = timm.create_model("efficientnet_b0", pretrained=True)
head_in_features = backbone.classifier.in_features
backbone.classifier = nn.Linear(head_in_features, len(class_names))

# Move all parameters (including the new head) to the selected device.
model = backbone.to(device)


model.safetensors:   0%|          | 0.00/21.4M [00:00<?, ?B/s]

## 🔧 Loss Function and Optimizer

In [5]:
# Adam over every model parameter (backbone and classification head alike),
# using the learning rate from the configuration cell.
optimizer = Adam(params=model.parameters(), lr=learning_rate)

# Multi-class classification loss applied to the raw model outputs.
criterion = nn.CrossEntropyLoss()


## 🔁 Training and Validation Functions

In [6]:
def train_one_epoch(model, loader, loss_fn=None, optim=None, dev=None):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Args:
        model: The network to train; it is switched to training mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
        optim: Optimizer. Defaults to the notebook-level ``optimizer``.
        dev: Device the batches are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged over the
        samples actually seen during the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Fall back to the notebook globals only when no override is supplied, so
    # existing two-argument calls keep working unchanged.
    loss_fn = criterion if loss_fn is None else loss_fn
    optim = optimizer if optim is None else optim
    dev = device if dev is None else dev

    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(dev), y.to(dev)
        optim.zero_grad()
        out = model(x)
        loss = loss_fn(out, y)
        loss.backward()
        optim.step()

        # Weight each batch loss by its size so a smaller final batch does not
        # skew the epoch average (assumes loss_fn returns a per-batch mean).
        n_batch = y.size(0)
        loss_sum += loss.item() * n_batch
        correct += (out.argmax(1) == y).sum().item()
        seen += n_batch

    if seen == 0:
        raise ValueError("loader yielded no samples")
    # Divide by the samples seen rather than len(loader.dataset): the two
    # differ when the loader uses a sampler or drops the last batch.
    return loss_sum / seen, correct / seen

def validate(model, loader, loss_fn=None, dev=None):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: The network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        loss_fn: Loss callable. Defaults to the notebook-level ``criterion``.
        dev: Device the batches are moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(mean_loss, accuracy)`` where both values are averaged over the
        samples actually seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    # Fall back to the notebook globals only when no override is supplied, so
    # existing two-argument calls keep working unchanged.
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(dev), y.to(dev)
            out = model(x)
            loss = loss_fn(out, y)

            # Weight each batch loss by its size so a smaller final batch does
            # not skew the average (assumes loss_fn returns a per-batch mean).
            n_batch = y.size(0)
            loss_sum += loss.item() * n_batch
            correct += (out.argmax(1) == y).sum().item()
            seen += n_batch

    if seen == 0:
        raise ValueError("loader yielded no samples")
    # Divide by the samples seen rather than len(loader.dataset): the two
    # differ when the loader uses a sampler or drops the last batch.
    return loss_sum / seen, correct / seen


## 🏁 Training Loop

In [None]:
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With an initial value of 0 and a strict ">" comparison, a run
# whose validation accuracy stayed at 0 would never save, and the evaluation
# cell would then fail (or load a stale file from an earlier run).
best_val_acc = float("-inf")

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc = validate(model, val_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

    # Keep only the weights of the best-performing epoch on the validation set.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "efficientnet_best.pth")


## 📊 Evaluation

In [None]:
# Restore the best checkpoint. map_location makes a checkpoint saved on a GPU
# loadable on a CPU-only machine (and vice versa).
best_state = torch.load("efficientnet_best.pth", map_location=device)
model.load_state_dict(best_state)
model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for x, y in val_loader:
        out = model(x.to(device))
        # Store plain Python ints rather than 0-dim tensors so scikit-learn
        # receives ordinary integer label lists.
        all_preds.extend(out.argmax(1).cpu().tolist())
        all_labels.extend(y.tolist())

# Pass the full label set explicitly: without it, a class that never occurs in
# the validation labels or predictions makes the label count disagree with
# `target_names` and the report raises.
label_ids = list(range(len(class_names)))

print(classification_report(
    all_labels,
    all_preds,
    labels=label_ids,
    target_names=class_names,
    zero_division=0,
))

cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(cm, cmap='Blues')
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
# Label both axes with class names so individual cells can be identified.
ax.set_xticks(label_ids)
ax.set_xticklabels(class_names, rotation=90)
ax.set_yticks(label_ids)
ax.set_yticklabels(class_names)
fig.colorbar(im, ax=ax)
plt.show()
