In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import os
from tqdm import tqdm
import timm  # ✅ Use timm for InceptionResNetV2

# ------------------ Configs ------------------ #
# Root folder expected to contain 'train', 'val' and 'test' subdirectories,
# each loaded below with torchvision's ImageFolder.
data_dir = '/kaggle/input/ip02-dataset/classification/'
batch_size = 64          # images per batch, shared by every split's DataLoader
learning_rate = 0.01     # initial SGD learning rate
num_epochs = 30
# NOTE(review): lr_step_size (40) is larger than num_epochs (30), so the StepLR
# scheduler built from these values never decays the learning rate within a
# run. Confirm whether a smaller step size (or more epochs) was intended.
lr_step_size = 40
lr_gamma = 0.1           # multiplicative LR decay applied every lr_step_size epochs
momentum = 0.9           # SGD momentum
weight_decay = 0.0005    # L2 penalty passed to SGD
dropout_rate = 0.3       # dropout probability in front of the final linear layer
input_size = 299  # ✅ InceptionResNetV2 prefers 299x299 input
num_workers = 4          # DataLoader worker processes per split

# Train on the GPU when one is visible, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ------------------ Data Transforms ------------------ #
# Every split is resized to input_size x input_size, converted to a tensor and
# normalized with mean = std = 0.5 per channel (maps [0, 1] pixels to [-1, 1],
# matching Inception-style preprocessing). Only the training split also gets a
# random horizontal flip for augmentation.
data_transforms = {
    split: transforms.Compose(
        [transforms.Resize((input_size, input_size))]
        + ([transforms.RandomHorizontalFlip()] if split == 'train' else [])
        + [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )
    for split in ('train', 'val', 'test')
}

# ------------------ Datasets & Dataloaders ------------------ #
# One ImageFolder per split. Each ImageFolder derives its integer labels from
# the sorted subfolder names of its *own* directory.
image_datasets = {
    x: datasets.ImageFolder(os.path.join(data_dir, x), transform=data_transforms[x])
    for x in ['train', 'val', 'test']
}

class_names = image_datasets['train'].classes
num_classes = len(class_names)
print("✅ num_classes set to:", num_classes)

# Because labels are assigned per split, a class folder missing from (or extra
# in) val/test would silently shift every later label index and corrupt all
# metrics. Fail fast instead of training/evaluating on misaligned labels.
_mismatched_splits = [
    s for s in ('val', 'test') if image_datasets[s].classes != class_names
]
if _mismatched_splits:
    raise ValueError(
        f"Class folders in {_mismatched_splits} do not match those in 'train' "
        f"({num_classes} classes); label indices would be misaligned."
    )

dataloaders = {
    x: DataLoader(
        image_datasets[x],
        batch_size=batch_size,
        # Only training benefits from shuffling; val/test metrics are order-independent.
        shuffle=(x == 'train'),
        num_workers=num_workers,
        # Page-locked host memory speeds up host-to-GPU copies.
        pin_memory=torch.cuda.is_available(),
    )
    for x in ['train', 'val', 'test']
}

# ------------------ Model Setup (InceptionResNetV2) ------------------ #
# timm's InceptionResNetV2 stores its head as `classif` (not `classifier`), so
# rather than patching a model-specific attribute, let timm build the head:
# `num_classes` replaces the ImageNet linear layer with a freshly initialised
# one sized for this dataset, and `drop_rate` applies dropout to the pooled
# features right before that layer.
model = timm.create_model(
    'inception_resnet_v2',
    pretrained=True,
    num_classes=num_classes,
    drop_rate=dropout_rate,
)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate,
                      momentum=momentum, weight_decay=weight_decay)
# NOTE(review): StepLR only decays after lr_step_size epochs; when that exceeds
# num_epochs the learning rate stays constant for the whole run.
scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)

# ------------------ Training Loop ------------------ #
def _run_phase(phase):
    """Run one pass over ``dataloaders[phase]`` and return ``(mean_loss, accuracy)``.

    In the ``'train'`` phase the model is put in training mode and the
    optimizer steps once per batch; any other phase runs in eval mode with
    autograd disabled.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in dataloaders[phase]:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        # criterion returns the batch mean; re-weight by batch size so the epoch
        # mean stays exact when the last batch is smaller.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = len(image_datasets[phase])
    return running_loss / n_samples, running_corrects / n_samples


def train_model():
    """Fine-tune ``model`` for ``num_epochs`` epochs, validating after each one.

    Whenever validation accuracy improves, the weights are saved to
    ``best_inceptionresnet_model.pth``. When training ends, those best weights
    are loaded back into ``model``, so any later evaluation uses the selected
    checkpoint instead of the last epoch's weights.

    Returns:
        float: the best validation accuracy observed (0.0 if no epoch exceeded it).
    """
    best_val_acc = 0.0
    best_model_path = 'best_inceptionresnet_model.pth'
    saved_checkpoint = False

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-' * 20)

        for phase in ['train', 'val']:
            epoch_loss, epoch_acc = _run_phase(phase)
            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            if phase == 'val' and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                torch.save(model.state_dict(), best_model_path)
                saved_checkpoint = True
                print(f"✅ Best model saved at epoch {epoch+1} with val acc: {epoch_acc:.4f}")

        # StepLR counts epochs, so step it once per epoch after the optimizer steps.
        scheduler.step()

    print(f"\n🏁 Training completed. Best Val Acc: {best_val_acc:.4f}")

    # Restore the best checkpoint so evaluation reflects the model actually selected.
    if saved_checkpoint:
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    return best_val_acc

# ------------------ Evaluation on Test Set ------------------ #
def evaluate_model():
    """Evaluate ``model`` on the test split and print summary metrics.

    Prints top-1 and top-5 accuracy, a per-class classification report and a
    confusion matrix. If there are fewer than 5 classes, the "top-5" figure is
    computed over all classes instead.

    Returns:
        tuple[float, float]: ``(top1_acc, top5_acc)``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    top1_correct = 0
    topk_correct = 0
    total = 0
    # Tensor.topk raises if k exceeds the number of classes.
    k = min(5, len(class_names))

    with torch.no_grad():
        for inputs, labels in tqdm(dataloaders['test'], desc="Evaluating"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            # Indices sorted by descending score per row; column 0 is the top-1 guess.
            # Slicing (not squeeze()) keeps a 1-D result even for a batch of size 1.
            topk_preds = outputs.topk(k, dim=1).indices
            top1_preds = topk_preds[:, 0]

            all_preds.extend(top1_preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

            top1_correct += (top1_preds == labels).sum().item()
            # A sample is a top-k hit if its label appears anywhere in its row.
            topk_correct += (topk_preds == labels.unsqueeze(1)).any(dim=1).sum().item()
            total += labels.size(0)

    top1_acc = top1_correct / total
    top5_acc = topk_correct / total

    print(f"\n✅ Top-1 Accuracy: {top1_acc:.4f}")
    print(f"✅ Top-5 Accuracy: {top5_acc:.4f}\n")

    # Pass the full label set explicitly. Otherwise sklearn infers the labels from
    # the data, which breaks target_names alignment (and shrinks the matrix) when
    # some class never appears in the test labels or predictions.
    label_ids = list(range(len(class_names)))

    print("📊 Classification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids, target_names=class_names))

    print("🧾 Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds, labels=label_ids))

    return top1_acc, top5_acc


# ------------------ Run ------------------ #
if __name__ == "__main__":
    # Fine-tune on train/val first, then report metrics on the held-out test split.
    for stage in (train_model, evaluate_model):
        stage()
