In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import os
import copy
import numpy as np

In [2]:
# ------------------------- CONFIGURATION -------------------------
# Every path and tunable value the notebook reads lives in this cell.

# --- Data locations ---
# Clean images, loaded with ImageFolder for the baseline accuracy (R).
CLEAN_DATA_DIR = "./ImageNet_images"
# Water-occlusion images, loaded with ImageFolder for the degraded accuracy (R').
DISTORTED_DATA_DIR = "./Distorted_Images/Water_Occlusion"

# --- Model checkpoint ---
# State dict of the baseline model trained earlier on clean data.
MODEL_LOAD_PATH = "./resnet50_clean_baseline.pth"

# --- Evaluation settings ---
BATCH_SIZE = 32    # batch size handed to both DataLoaders
NUM_CLASSES = 5    # output width of the replaced ResNet-50 fc layer

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# -----------------------------------------------------------------

In [3]:
def get_transforms():
    """Build the image preprocessing pipelines used by this notebook.

    Returns:
        dict: a mapping with the single key ``'val'`` whose value is a
        ``transforms.Compose`` that resizes to 256, center-crops to 224,
        converts to a tensor and normalizes with the standard ImageNet
        channel statistics.
    """
    # Standard ImageNet per-channel mean / std.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]

    return {'val': transforms.Compose(eval_steps)}

In [4]:
def evaluate_model(model, dataloader, description, device=None):
    """Compute top-1 accuracy of ``model`` over every batch of ``dataloader``.

    The model is switched to eval mode (and left in eval mode), gradients are
    disabled, and the predicted class of each sample is the argmax of the
    model output along dimension 1. Progress and the final accuracy are
    printed to stdout.

    Args:
        model: Module called as ``model(inputs)``; its output is reduced with
            an argmax over dim 1, so it must be at least 2-D.
        dataloader: Iterable yielding ``(inputs, labels)`` tensor pairs.
        description: Human-readable name of the dataset, used in the printed
            messages only.
        device: Device each batch is moved to before the forward pass.
            Defaults to the module-level ``DEVICE`` when omitted, which keeps
            existing three-argument calls working unchanged.

    Returns:
        float: fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: if ``dataloader`` yields no samples, since accuracy is
            undefined for an empty set.
    """
    # Resolve the target device lazily so callers that pass ``device``
    # explicitly do not need the global DEVICE to exist.
    target_device = DEVICE if device is None else device

    model.eval()
    corrects = 0
    total = 0

    print(f"\nEvaluating on {description}...")

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(target_device)
            labels = labels.to(target_device)

            preds = model(inputs).argmax(dim=1)

            # Keep the running count as a plain Python int so the final
            # division and the printed summary never depend on tensor types.
            corrects += int((preds == labels).sum().item())
            total += inputs.size(0)

    if total == 0:
        # Previously this path crashed with an opaque AttributeError
        # ('int' object has no attribute 'double').
        raise ValueError(
            f"Cannot evaluate on {description}: the dataloader yielded no samples."
        )

    acc = corrects / total
    print(f"--> Accuracy on {description}: {acc:.4f} ({corrects}/{total})")
    return acc

In [5]:
# --- 1. PREPARE DATA ---
# This cell measures how much the clean-trained ResNet-50 degrades on the
# water-occlusion images: R is accuracy on a clean validation split, R' is
# accuracy on the distorted dataset, and the reported drop is R - R'.
data_transforms = get_transforms()

if not os.path.exists(DISTORTED_DATA_DIR):
    print(f"ERROR: Distorted directory {DISTORTED_DATA_DIR} not found.")
    print("Please run 'create_water_occlusion_with_masks.py' first.")
else:
    # A. CLEAN DATA (For Baseline R reference)
    full_clean_dataset_val = datasets.ImageFolder(CLEAN_DATA_DIR, transform=data_transforms['val'])

    # Stratified 80/20 split with a fixed seed.
    # NOTE(review): this is meant to reproduce the validation split used when
    # the baseline was trained; that only holds if the training notebook used
    # the same test_size / random_state / stratify settings -- confirm.
    targets = full_clean_dataset_val.targets
    _, val_idx = train_test_split(
        np.arange(len(targets)),
        test_size=0.2,
        random_state=42,
        stratify=targets
    )
    clean_val_subset = Subset(full_clean_dataset_val, val_idx)
    clean_val_loader = DataLoader(clean_val_subset, batch_size=BATCH_SIZE, shuffle=False)

    # B. DISTORTED DATA (Water Occlusion)
    full_distorted_dataset = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=data_transforms['val'])
    distorted_val_loader = DataLoader(full_distorted_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # ImageFolder derives label indices from the sub-folder names of each
    # root. If the two roots do not contain the same class folders, the same
    # integer label means different classes and R' is not comparable to R.
    if full_distorted_dataset.class_to_idx != full_clean_dataset_val.class_to_idx:
        print("WARNING: class folders differ between the clean and distorted datasets; "
              "label indices may not line up and the distorted accuracy may be meaningless.")

    # NOTE(review): R is computed on the 20% hold-out only, while R' uses every
    # distorted image. If the distorted set was generated from all clean
    # images it presumably includes counterparts of training images -- confirm
    # this comparison is intended.
    print(f"Clean Validation Images (R): {len(clean_val_subset)}")
    print(f"Water Distorted Images (R'): {len(full_distorted_dataset)}")

    # --- 2. LOAD MODEL ---
    print("\nInitializing ResNet-50...")
    checkpoint_found = os.path.exists(MODEL_LOAD_PATH)
    # The strict load_state_dict below replaces every parameter and buffer, so
    # fetching the ImageNet weights first would be wasted work (and a needless
    # download on a fresh machine). They are only used as a fallback when the
    # checkpoint is missing.
    model = models.resnet50(
        weights=None if checkpoint_found else models.ResNet50_Weights.IMAGENET1K_V1
    )
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    if checkpoint_found:
        print(f"Loading pre-trained clean weights from {MODEL_LOAD_PATH}...")
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine; without it torch.load tries to restore CUDA tensors and fails.
        model.load_state_dict(torch.load(MODEL_LOAD_PATH, map_location=DEVICE))
    else:
        # Best-effort fallback: evaluation continues with an untrained
        # classification head, so the numbers below are not a real baseline.
        print(f"WARNING: {MODEL_LOAD_PATH} not found! You need to run 'Baseline.ipynb' first to train the clean model.")

    model = model.to(DEVICE)

    # --- 3. MEASURE THE DROP ---
    print("="*50)
    print("RESULTS: WATER OCCLUSION DEGRADATION")
    print("="*50)

    # R: Accuracy on Clean Validation
    acc_clean = evaluate_model(model, clean_val_loader, "Clean Validation Set (Baseline R)")

    # R': Accuracy on Water Distorted Data
    acc_distorted = evaluate_model(model, distorted_val_loader, "Water Distorted Set (Degraded R')")

    # Absolute difference in accuracy (percentage points once scaled by 100).
    drop = acc_clean - acc_distorted

    print("\n" + "="*50)
    print(f"Baseline Accuracy (R):   {acc_clean*100:.2f}%")
    print(f"Water Accuracy (R'):     {acc_distorted*100:.2f}%")
    print(f"PERFORMANCE DROP:        {drop*100:.2f}%")
    print("="*50)

Clean Validation Images (R): 106
Water Distorted Images (R'): 544

Initializing ResNet-50...
Loading pre-trained clean weights from ./resnet50_clean_baseline.pth...
RESULTS: WATER OCCLUSION DEGRADATION

Evaluating on Clean Validation Set (Baseline R)...
--> Accuracy on Clean Validation Set (Baseline R): 0.9717 (103/106)

Evaluating on Water Distorted Set (Degraded R')...
--> Accuracy on Water Distorted Set (Degraded R'): 0.5037 (274/544)

Baseline Accuracy (R):   97.17%
Water Accuracy (R'):     50.37%
PERFORMANCE DROP:        46.80%
