# In [None]:
import random
import numpy as np
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score
import seaborn as sns
import matplotlib.pyplot as plt
import os

# Set random seeds for reproducibility across Python, NumPy and PyTorch.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # seeds every visible GPU (no-op without CUDA)

# Seeding alone does not make GPU runs repeatable: the cuDNN autotuner may pick
# different convolution algorithms per run, and some of them are
# non-deterministic. Pin cuDNN to deterministic kernels (at some speed cost).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

##################################
# 1. Data Augmentations & Loaders
##################################
image_size = 384

# Training pipeline: resize to a fixed square, then with probability 0.8 invoke
# ALL four augmentations below together. RandomApply treats the list as one
# unit, so either every augmentation is invoked or none is. Each keeps its own
# internal randomness, e.g. the horizontal flip is still a coin toss.
train_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomApply([
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0))
    ], p=0.8),
    transforms.ToTensor(),
    # Standard ImageNet channel mean/std.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Evaluation pipeline: deterministic resize + the same normalisation.
test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Paths to your PlantDoc dataset (one sub-folder per class in each split).
train_dir = '../PlantDoc-Dataset/train'
test_dir = '../PlantDoc-Dataset/test'

train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
test_dataset  = datasets.ImageFolder(root=test_dir,  transform=test_transform)

# ImageFolder derives label indices from the sorted folder names of EACH root
# independently. If the test split is missing a class folder, its indices shift
# and silently disagree with the training labels. Re-map the test targets onto
# the training class indices so both splits share one label space.
if test_dataset.class_to_idx != train_dataset.class_to_idx:
    unknown_in_train = sorted(set(test_dataset.classes) - set(train_dataset.classes))
    if unknown_in_train:
        raise ValueError(
            f"Test split contains classes absent from the train split: {unknown_in_train}"
        )
    missing_in_test = sorted(set(train_dataset.classes) - set(test_dataset.classes))
    if missing_in_test:
        print("Classes with no test images:", missing_in_test)
    remapped_samples = [
        (path, train_dataset.class_to_idx[test_dataset.classes[target]])
        for path, target in test_dataset.samples
    ]
    test_dataset.samples = remapped_samples
    test_dataset.imgs = remapped_samples
    test_dataset.targets = [target for _, target in remapped_samples]
    test_dataset.classes = list(train_dataset.classes)
    test_dataset.class_to_idx = dict(train_dataset.class_to_idx)
    print("Remapped test labels onto the training class indices.")

num_classes = len(train_dataset.classes)
class_names = train_dataset.classes
print("Number of classes:", num_classes)
print("Classes:", class_names)

batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print(f"Dataset Loaded! {num_classes} classes found.")
print("Number of images in train_loader:", len(train_loader.dataset))
print("Number of images in test_loader:", len(test_loader.dataset))

##################################
# 2. Define the Model Architecture
##################################
class MLPClassifier(nn.Module):
    """Two-layer MLP head: Linear(in_features -> 512) -> ReLU -> Linear(512 -> num_classes).

    The layers are stored in ``self.classifier`` (an ``nn.Sequential``), so the
    state-dict keys are ``classifier.0.*`` and ``classifier.2.*``. Keep that
    layout: checkpoints are restored with ``load_state_dict`` by key.
    """

    def __init__(self, in_features, num_classes):
        super().__init__()
        hidden_units = 512
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of feature vectors to unnormalised class logits."""
        logits = self.classifier(x)
        return logits

def build_model():
    """Return a randomly initialised ResNet-101 whose ``fc`` is an ``MLPClassifier``.

    The head's output size comes from the module-level ``num_classes``; its
    input size is read from the backbone's original ``fc`` layer.
    """
    backbone = models.resnet101(weights=None)
    feature_dim = backbone.fc.in_features  # 2048 for ResNet-101
    backbone.fc = MLPClassifier(in_features=feature_dim, num_classes=num_classes)
    return backbone

##################################
# 3. Setup Loss, Optimizer, and Scheduler
##################################
def setup_training(model, lr=1e-5, weight_decay=5e-5, t_max=10):
    """Create the loss, optimiser and LR scheduler used for fine-tuning.

    Args:
        model: network whose parameters the optimiser updates.
        lr: base learning rate for AdamW.
        weight_decay: AdamW decoupled weight-decay coefficient.
        t_max: ``T_max`` of the cosine schedule, in ``scheduler.step()`` calls.
            PyTorch's CosineAnnealingLR keeps following the cosine after
            ``t_max`` steps, so the LR climbs back towards ``lr``. Pass the
            total number of steps for a single monotonic decay.

    Returns:
        ``(criterion, optimizer, scheduler)``: a mean-reduced
        ``CrossEntropyLoss``, an ``AdamW`` optimiser over
        ``model.parameters()``, and a ``CosineAnnealingLR`` bound to it.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)
    return criterion, optimizer, scheduler

##################################
# 4. Fine-Tuning and Evaluation
##################################
def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(mean per-sample training loss, training accuracy in [0, 1])``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Assumes a mean-reduced criterion (CrossEntropyLoss default, as built
        # by setup_training): re-weight by batch size so the epoch average is
        # per-sample even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("Training loader yielded no samples.")
    return running_loss / total, correct / total


def _evaluate(model, loader, criterion):
    """Run ``model`` over ``loader`` in eval mode without gradients.

    Returns:
        ``(mean per-sample loss, true labels, predicted labels)``. The label
        arrays are 1-D NumPy arrays in loader order.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    label_chunks = []
    pred_chunks = []
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels).item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            label_chunks.append(labels.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())

    if not label_chunks:
        raise ValueError("Evaluation loader yielded no samples.")
    all_labels = np.concatenate(label_chunks)
    all_preds = np.concatenate(pred_chunks)
    return running_loss / len(all_labels), all_labels, all_preds


def _report_metrics(all_labels, all_preds, checkpoint_path, run_idx, best_model_filename):
    """Print overall and per-class metrics and plot the confusion matrix.

    Label indices are interpreted against the module-level ``class_names``.
    """
    label_ids = list(range(len(class_names)))

    # Overall metrics. Macro scores use scikit-learn's default label set
    # (classes occurring in the labels or predictions).
    accuracy = (all_preds == all_labels).mean()
    macro_precision = precision_score(all_labels, all_preds, average='macro')
    macro_recall    = recall_score(all_labels, all_preds, average='macro')
    macro_f1        = f1_score(all_labels, all_preds, average='macro')

    print(f"=== Final Evaluation for {checkpoint_path}, Run {run_idx} ===")
    print(f"Best model path  : {best_model_filename}")
    print(f"Accuracy         : {accuracy:.4f}")
    print(f"Macro Precision  : {macro_precision:.4f}")
    print(f"Macro Recall     : {macro_recall:.4f}")
    print(f"Macro F1-score   : {macro_f1:.4f}\n")

    # Pin `labels` so entry i always corresponds to class_names[i]. Without it,
    # scikit-learn drops classes that never occur in the labels or predictions,
    # every later entry shifts, and the name lookup below breaks.
    # zero_division=0 reports such classes as 0 instead of warning.
    per_class_precision = precision_score(all_labels, all_preds, labels=label_ids,
                                          average=None, zero_division=0)
    per_class_recall    = recall_score(all_labels, all_preds, labels=label_ids,
                                       average=None, zero_division=0)
    per_class_f1        = f1_score(all_labels, all_preds, labels=label_ids,
                                   average=None, zero_division=0)

    print("=== Per-Class Metrics ===")
    for i, cls in enumerate(class_names):
        print(f"Class: {cls:15s} | Precision: {per_class_precision[i]:.4f} | "
              f"Recall: {per_class_recall[i]:.4f} | F1-score: {per_class_f1[i]:.4f}")

    # Same label pinning keeps the matrix square over all classes, matching the
    # tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f"Confusion Matrix - PlantDoc Test Set\n({checkpoint_path}, run {run_idx})")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.tight_layout()
    plt.show()


def fine_tune_and_evaluate(checkpoint_path, run_idx=1, num_finetune_epochs=60):
    """Fine-tune ``build_model()`` from ``checkpoint_path`` and report test metrics.

    Args:
        checkpoint_path: path to a state dict loadable (strictly) into
            ``build_model()``.
        run_idx: run number. It is part of the saved best-model filename, so
            repeated runs of one checkpoint do not overwrite each other.
        num_finetune_epochs: number of training epochs.

    Side effects:
        Writes ``plantdoc_best_finetuned_<checkpoint stem>_run<run_idx>.pth``
        to the working directory whenever test accuracy improves. Prints
        progress and metrics, and shows a confusion-matrix plot.
    """
    print(f"\n=== Starting Fine-Tuning for Checkpoint: {checkpoint_path}, Run {run_idx} ===")

    # Build model and load SSL pre-trained weights from the given checkpoint.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model = build_model()
    pretrained_weights = torch.load(checkpoint_path, map_location=device)
    # Strict load: the checkpoint must contain every key of build_model(),
    # including the fc (MLPClassifier) head -- TODO confirm the SSL checkpoint
    # was saved from the same architecture.
    model.load_state_dict(pretrained_weights)
    print("Loaded SSL pre-trained weights successfully.")
    model = model.to(device)

    # NOTE(review): setup_training's cosine schedule uses T_max=10. When
    # num_finetune_epochs > 10 the LR decays to ~0 at epoch 10 and then rises
    # again (period 20). Confirm this cyclic schedule is intended.
    criterion, optimizer, scheduler = setup_training(model, lr=1e-5)

    # Use only the checkpoint's file stem. The raw path would embed ".pth"
    # and any directory separators in the new filename.
    ckpt_stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    best_model_filename = f"plantdoc_best_finetuned_{ckpt_stem}_run{run_idx}.pth"

    # Start below any achievable accuracy so the first epoch always saves.
    # Otherwise the reload below could miss the file or pick up a stale one.
    best_acc = -1.0

    # NOTE(review): test_loader is used both to pick the best epoch and for the
    # final report, so the reported metrics are optimistically biased. A
    # separate validation split would avoid this.
    for epoch in range(num_finetune_epochs):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_labels, val_preds = _evaluate(model, test_loader, criterion)
        val_acc = float((val_preds == val_labels).mean())

        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_finetune_epochs}] - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_model_filename)
            print("Saved Best Model!")

    print("Fine-Tuning Complete for this run!\n")

    # =======================
    # Final Evaluation Phase
    # =======================
    if best_acc >= 0.0:
        # Evaluate the best epoch's weights, not the last epoch's.
        model.load_state_dict(torch.load(best_model_filename, map_location=device))
        reported_path = best_model_filename
    else:
        # No epochs ran, so nothing was saved this run. Evaluate the loaded
        # checkpoint as-is.
        reported_path = "(none - no fine-tuning epochs were run)"

    _, all_labels, all_preds = _evaluate(model, test_loader, criterion)
    _report_metrics(all_labels, all_preds, checkpoint_path, run_idx, reported_path)

##################################
# 5. Loop Over Checkpoints for Classification
##################################
# One fine-tuning run per SSL checkpoint, numbered from 1 in list order.
checkpoint_list = [
    "byol_mim_contrastive_epoch85.pth",
]
for run_number, checkpoint_file in enumerate(checkpoint_list, 1):
    fine_tune_and_evaluate(
        checkpoint_file,
        run_idx=run_number,
        num_finetune_epochs=45,
    )
