# Anime Gender Classification

This notebook implements the training and evaluation pipeline for anime gender classification using PyTorch.

In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from tqdm import tqdm

# Select the compute device: the GPU when PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## 1. Data Setup
Load the train, validation, and test data with the project's `dataloader.py` module.

In [None]:
# `get_dataloaders` comes from the project's own dataloader.py module.
from dataloader import get_dataloaders

# Build the train/val/test DataLoaders. val_split=0.2 presumably reserves 20%
# of the training data for validation -- confirm against dataloader.py.
BATCH_SIZE = 32
train_loader, val_loader, test_loader, class_names = get_dataloaders(
    batch_size=BATCH_SIZE,
    root_dir='.',
    val_split=0.2,
)

# Report the discovered classes and how many batches each split yields.
print(f"Classes: {class_names}")
for split_name, split_loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{split_name} batches: {len(split_loader)}")

## 2. Model Setup & Training
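Fine-tune a pretrained EfficientNetV2-M (the classifier head and the last two feature blocks are trainable), using early stopping on the validation loss.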

In [None]:
# Load EfficientNetV2-M with torchvision's DEFAULT pretrained weights.
weights = models.EfficientNet_V2_M_Weights.DEFAULT
model = models.efficientnet_v2_m(weights=weights)

# Fine-tuning setup: freeze the entire network first ...
model.requires_grad_(False)

# ... then re-enable gradients for the classifier head and the last two
# blocks of `model.features` (an indexable container of feature blocks).
for module in (model.classifier, model.features[-1], model.features[-2]):
    module.requires_grad_(True)

# Replace the head's final linear layer with a fresh 2-output layer.
# A newly constructed nn.Linear has requires_grad=True, so it is trainable.
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, 2)
model = model.to(device)

print("Model setup complete. Layers unfreezed: Classifier, Features[-1], Features[-2]")

# Loss and optimizer. Only parameters that still require gradients are given
# to AdamW, so the frozen part of the backbone is never updated.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=1e-3,
)

# --- Training loop ---
def _run_phase(model, dataloader, phase, criterion, optimizer, device):
    """Run one pass of ``phase`` ('train' or 'val') over ``dataloader``.

    In the 'train' phase the model is put in training mode and the optimizer
    steps once per batch. In any other phase the model is in eval mode and
    gradients are disabled.

    Returns:
        ``(mean_loss, accuracy)`` as Python floats, averaged over the samples
        actually iterated (not ``len(dataloader.dataset)``, which overcounts
        with ``drop_last`` or subset samplers).

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    is_train = phase == 'train'
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    for inputs, labels in tqdm(dataloader, desc=f"{phase}"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Build the autograd graph only when we are going to backpropagate.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        batch_size = inputs.size(0)
        # Assumes criterion returns the batch *mean* (nn.CrossEntropyLoss's
        # default reduction), so re-weight by batch size before averaging.
        running_loss += loss.item() * batch_size
        running_corrects += (preds == labels).sum().item()
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError(f"{phase} dataloader produced no samples")

    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=35, patience=5,
                device=None, checkpoint_path='best_model.pth'):
    """Train ``model`` with per-epoch validation and early stopping on validation loss.

    Each epoch runs one training pass over ``train_loader``, then one
    gradient-free pass over ``val_loader``. Whenever the validation loss
    improves, a deep copy of the weights is kept in memory. If
    ``checkpoint_path`` is set, it is also written to disk with
    ``torch.save``. Training stops once the validation loss has not improved
    for ``patience`` consecutive epochs. Before returning, the model is
    reloaded with the best-validation weights (if any epoch improved).

    Args:
        model: The ``nn.Module`` to train (updated in place).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model's trainable parameters.
        num_epochs: Maximum number of epochs to run.
        patience: Consecutive non-improving epochs allowed before stopping.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).
        checkpoint_path: File the best ``state_dict`` is saved to; ``None``
            disables saving to disk.

    Returns:
        Tuple ``(model, train_losses, val_losses)``. The two lists hold the
        mean loss of each completed epoch.

    Raises:
        ValueError: if either dataloader yields no samples.
    """
    if device is None:
        # Send batches wherever the model's parameters already live.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    best_model_wts = None
    best_loss = float('inf')
    counter = 0  # epochs since the last validation-loss improvement

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        train_loss, train_acc = _run_phase(model, train_loader, 'train', criterion, optimizer, device)
        train_losses.append(train_loss)
        print(f'train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

        val_loss, val_acc = _run_phase(model, val_loader, 'val', criterion, optimizer, device)
        val_losses.append(val_loss)
        print(f'val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        # Snapshot on improvement, otherwise advance the early-stopping counter.
        if val_loss < best_loss:
            best_loss = val_loss
            # state_dict() returns references to the live parameter tensors.
            # Without a deep copy this "snapshot" would keep changing as
            # training continues, and the final reload would restore the last
            # epoch rather than the best one.
            best_model_wts = copy.deepcopy(model.state_dict())
            if checkpoint_path is not None:
                torch.save(best_model_wts, checkpoint_path)
            counter = 0
            print("Validation loss improved. Model saved.")
        else:
            counter += 1
            print(f"EarlyStopping counter: {counter} out of {patience}")

        if counter >= patience:
            print("Early stopping triggered.")
            break

        print()

    print(f'Best val loss: {best_loss:.4f}')
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses

# Train the model. train_model returns the model plus the per-epoch mean
# train and validation losses.
model, train_losses, val_losses = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=35, patience=5)

# Plot the loss history. Epochs are numbered from 1 so the x-axis matches the
# "Epoch k/N" lines printed during training (plt.plot's default x is 0-based).
epochs = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epochs, train_losses, label='Train Loss')
plt.plot(epochs, val_losses, label='Val Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

## 3. Evaluation
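Evaluate the trained model on the validation set. Pass `test_loader` to `evaluate_model` instead to evaluate on the held-out test set.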

In [None]:
def evaluate_model(model, dataloader, device, class_names):
    """Run ``model`` over ``dataloader``, print a classification report and plot a confusion matrix.

    Predictions are the argmax over dim 1 of the model output (assumed to be a
    ``(batch, num_classes)`` logit tensor -- confirm for other models).
    Metrics are computed over every label id in ``range(len(class_names))``.
    This keeps the report and the heatmap axes aligned with ``class_names``
    even when some class never occurs among the true or predicted labels.
    Without it, sklearn raises a ValueError on the report.

    Args:
        model: Trained classifier; switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches. Labels are
            presumably integer class ids indexing ``class_names``.
        device: Device the inputs are moved to before the forward pass.
        class_names: Display names, one per class id.

    Returns:
        Tuple ``(y_true, y_pred)`` of per-sample labels, in loader order.
    """
    model.eval()
    y_true = []
    y_pred = []

    print("Starting evaluation...")
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            outputs = model(inputs.to(device))
            preds = outputs.argmax(dim=1)

            # Labels are only compared on the CPU, so they never need to
            # visit the accelerator.
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    if not y_true:
        print("No samples to evaluate.")
        return y_true, y_pred

    # Pin the label set so rows/columns always match class_names.
    label_ids = list(range(len(class_names)))

    # 1. Classification Report
    print("\n--- Classification Report ---")
    print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

    # 2. Confusion Matrix
    print("--- Confusion Matrix ---")
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    plt.figure(figsize=(10, 7))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')
    plt.show()

    return y_true, y_pred

# Evaluate on the validation set. The returned labels are bound to names
# rather than left as the cell's last expression. Otherwise Jupyter would
# echo both per-sample lists as cell output and then discard them.
print("Evaluating on Validation Set:")
val_y_true, val_y_pred = evaluate_model(model, val_loader, device, class_names)