# In [None]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
from PIL import Image
from tqdm import tqdm
import argparse
import collections

# ================================
# Custom Dataset Definition
# ================================
class FaceDataset(torch.utils.data.Dataset):
    """Image-classification dataset laid out as one sub-directory per person.

    Each directory directly under ``root_dir`` is one class; class indices are
    assigned in sorted directory-name order. Every ``.jpg``/``.jpeg``/``.png``
    file (case-insensitive) found anywhere below a class directory, including
    nested sub-folders, becomes a sample of that class.

    Args:
        root_dir: directory containing one sub-directory per class.
        transform: optional callable applied to the loaded RGB PIL image.

    Attributes:
        samples: list of ``(image_path, class_index)`` tuples.
        image_labels: separate list with the same tuples as ``samples``.
        class_to_idx: mapping from directory name to class index.
        idx_to_class: inverse mapping of ``class_to_idx``.
    """

    # Leading dots matter: without them a name such as "fakejpg" would match.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.class_to_idx = {}
        self.idx_to_class = {}

        # Only directories are classes. A stray file in root_dir (.DS_Store,
        # a README, ...) would otherwise claim a class index of its own and
        # shift the label of every class sorted after it.
        persons = sorted(
            entry for entry in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, entry))
        )
        for idx, person in enumerate(persons):
            self.class_to_idx[person] = idx
            self.idx_to_class[idx] = person

            person_folder = os.path.join(root_dir, person)
            for root, dirs, files in os.walk(person_folder):
                # os.walk yields entries in arbitrary filesystem order; sort
                # both levels so the sample order is reproducible.
                dirs.sort()
                for file in sorted(files):
                    if file.lower().endswith(self.IMAGE_EXTENSIONS):
                        self.samples.append((os.path.join(root, file), idx))

        # Kept as its own list (for optional distortion-type analysis) so that
        # mutating one does not affect the other.
        self.image_labels = list(self.samples)

    def __len__(self):
        """Return the number of collected image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``.

        The image is converted to RGB and passed through ``transform`` when
        one was supplied.
        """
        img_path, label = self.samples[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert(), instead of waiting for GC.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# ================================
# Evaluation Function
# ================================
def evaluate(model, loader, device, split_name):
    """Run ``model`` over ``loader`` and print classification metrics.

    The model is switched to eval mode and run without gradient tracking.
    Predictions are the arg-max of the model output along dim 1.

    Args:
        model: module called as ``model(images)``.
        loader: iterable yielding ``(images, labels)`` tensor batches.
        device: device the images are moved to before the forward pass.
        split_name: name used in the progress bar and the printed header.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``
        (the last three macro-averaged), or an empty dict when ``loader``
        produced no samples.
    """
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc=f"Evaluating on {split_name} set"):
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1)
            # Labels are only ever compared on the CPU, so they are not sent
            # to the device first.
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    if not y_true:
        # Nothing to score; avoid handing empty inputs to the metric functions.
        print(f"\n {split_name} Evaluation: no samples found, skipping metrics.")
        return {}

    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)

    print(f"\n {split_name} Evaluation:")
    print(f"Accuracy:  {acc:.4f}")
    print(f"Precision: {prec:.4f}")
    print(f"Recall:    {rec:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, zero_division=0))

    return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

# ================================
# Optional: Print Samples per Class
# ================================
def print_class_distribution(dataset):
    """Print how many samples each class index has in ``dataset``.

    Args:
        dataset: object exposing ``samples``, an iterable of
            ``(path, label)`` pairs.
    """
    per_class = collections.Counter(label for _path, label in dataset.samples)
    print("\n Class distribution (Label Index → Sample Count):")
    # Report in ascending label order.
    for class_idx in sorted(per_class):
        count = per_class[class_idx]
        print(f"Class {class_idx}: {count} images")

# ================================
# Main
# ================================
if __name__ == "__main__":
    # Evaluates a saved ResNet-50 classifier on the train and val splits found
    # under --data_path, printing class counts and metrics for each split.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='/content/drive/MyDrive/Task_B')
    parser.add_argument('--model_path', type=str, default='model_b.pth')
    parser.add_argument('--img_size', type=int, default=224)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--num_workers', type=int, default=4)
    # parse_known_args instead of parse_args: when this runs as a notebook
    # cell, sys.argv carries the kernel's own flags (e.g. "-f <connection
    # file>"), which parse_args() would reject as unrecognized arguments.
    args, _unknown = parser.parse_known_args()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Transforms
    eval_transforms = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
    ])

    # Load Datasets
    train_dataset = FaceDataset(os.path.join(args.data_path, "train"), transform=eval_transforms)
    val_dataset = FaceDataset(os.path.join(args.data_path, "val"), transform=eval_transforms)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    # The classifier head is sized from the train split.
    num_classes = len(train_dataset.class_to_idx)

    # Each split derives its label indices from its own folder listing, so a
    # val split with different folders gets labels that do not line up with
    # the indices derived from the train split.
    if val_dataset.class_to_idx != train_dataset.class_to_idx:
        print("\nWARNING: train and val class-to-index mappings differ; "
              "validation labels may not match the label indices derived from the train split.")

    # Load Model
    # No pretrained weights are requested: the checkpoint overwrites every
    # parameter. Calling resnet50() with no arguments avoids the deprecated
    # `pretrained=` keyword while behaving the same on older torchvision.
    model = models.resnet50()
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model = model.to(device)

    # Optional: show how many samples per class
    print_class_distribution(train_dataset)
    print_class_distribution(val_dataset)

    # Evaluate on both train and val sets
    evaluate(model, train_loader, device, split_name="Train")
    evaluate(model, val_loader, device, split_name="Validation")
