In [1]:
import os
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, top_k_accuracy_score

# 1. Configuration
DATA_DIR = Path("/content/drive/MyDrive/Animal Classification/dataset")  # replace with your root folder
BATCH_SIZE = 32
NUM_CLASSES = 15
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
MODEL_SAVE_PATH = "best_model.pth"

# 2. Data transforms (ImageNet normalization)
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(15),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet stats
                         std=[0.229, 0.224, 0.225]),
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# 3. Dataset & loaders (assumes train/val folders or use torch.utils.data.random_split)
full_dataset = datasets.ImageFolder(DATA_DIR, transform=train_transform)
class_names = full_dataset.classes

# Simple stratified split
num_total = len(full_dataset)
num_val = int(0.1 * num_total)
num_test = int(0.1 * num_total)
num_train = num_total - num_val - num_test

train_set, val_set, test_set = torch.utils.data.random_split(
    full_dataset,
    [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(42)
)
# Replace val/test transforms
val_set.dataset.transform = val_transform
test_set.dataset.transform = val_transform

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, pin_memory=True)

# 4. Baseline model: pretrained ResNet50
model = models.resnet50(pretrained=True)
# Freeze backbone initially
for param in model.parameters():
    param.requires_grad = False

# Replace head
in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.4),
    nn.Linear(256, NUM_CLASSES)
)
model = model.to(DEVICE)

# 5. Loss, optimizer, scheduler
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5, verbose=True)

# 6. Training loop with validation
best_val_loss = float("inf")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    for images, labels in train_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    train_loss = running_loss / len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0
    correct = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader.dataset)
    val_acc = correct / len(val_loader.dataset)
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print("Saved best model.")

# 7. Evaluation on test set
model.load_state_dict(torch.load(MODEL_SAVE_PATH))
model.eval()
all_preds = []
all_probs = []
all_labels = []
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        probs = torch.softmax(outputs, dim=1)
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

print("Classification Report:")
print(classification_report(all_labels, all_preds, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix:\n", cm)

# Top-2 accuracy (optional)
top2 = top_k_accuracy_score(all_labels, np.array(all_probs), k=2, labels=range(NUM_CLASSES))
print(f"Top-2 Accuracy: {top2:.4f}")

# 8. Inference utility
from PIL import Image

def predict_image(image_path, model, class_names, transform, device):
    model.eval()
    img = Image.open(image_path).convert("RGB")
    inp = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        out = model(inp)
        probs = torch.softmax(out, dim=1)[0]
        top_prob, top_idx = torch.topk(probs, 3)
    return [(class_names[i], top_prob[i].item()) for i in top_idx]

# Example usage:
# preds = predict_image("some_image.jpg", model, class_names, val_transform, DEVICE)
# print(preds)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 146MB/s]


Epoch 1/20 | Train Loss: 1.3176 | Val Loss: 0.3980 | Val Acc: 0.9175
Saved best model.
Epoch 2/20 | Train Loss: 0.4054 | Val Loss: 0.2566 | Val Acc: 0.9330
Saved best model.
Epoch 3/20 | Train Loss: 0.2735 | Val Loss: 0.1819 | Val Acc: 0.9433
Saved best model.
Epoch 4/20 | Train Loss: 0.1898 | Val Loss: 0.1904 | Val Acc: 0.9485
Epoch 5/20 | Train Loss: 0.1596 | Val Loss: 0.1271 | Val Acc: 0.9639
Saved best model.
Epoch 6/20 | Train Loss: 0.1247 | Val Loss: 0.1269 | Val Acc: 0.9691
Saved best model.
Epoch 7/20 | Train Loss: 0.1114 | Val Loss: 0.1333 | Val Acc: 0.9639
Epoch 8/20 | Train Loss: 0.1059 | Val Loss: 0.1223 | Val Acc: 0.9588
Saved best model.
Epoch 9/20 | Train Loss: 0.0713 | Val Loss: 0.1212 | Val Acc: 0.9588
Saved best model.
Epoch 10/20 | Train Loss: 0.0976 | Val Loss: 0.0993 | Val Acc: 0.9742
Saved best model.
Epoch 11/20 | Train Loss: 0.1050 | Val Loss: 0.1027 | Val Acc: 0.9742
Epoch 12/20 | Train Loss: 0.0636 | Val Loss: 0.1034 | Val Acc: 0.9639
Epoch 13/20 | Train Loss: