In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import os
import copy
import numpy as np

In [2]:
# ================= CONFIGURATION =================
# --- Filesystem locations ---
# Root folders are read with torchvision's ImageFolder further down, so each
# one is expected to contain one sub-directory per class.
CLEAN_DATA_DIR = "./ImageNet_images"
DISTORTED_DATA_DIR = "./Distorted_Images/Low_Light"
# Where the fine-tuned weights (state_dict) are written after training.
MODEL_SAVE_PATH = "resnet50_clean_baseline.pth"

# --- Task definition ---
# Output width of the replacement classification head.
NUM_CLASSES = 5

# --- Optimisation hyper-parameters ---
BATCH_SIZE = 32
# A handful of epochs (roughly 5-10) is usually sufficient when fine-tuning.
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# --- Hardware ---
# Use the GPU when one is visible to PyTorch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# =================================================

In [3]:
def get_transforms():
    """Build the preprocessing pipelines for training and validation.

    Returns:
        dict: ``{'train': Compose, 'val': Compose}``.

        * ``'train'`` applies a random resized crop to 224 and a random
          horizontal flip before tensor conversion and normalisation.
        * ``'val'`` is deterministic: resize to 256, centre-crop to 224,
          then tensor conversion and normalisation.
    """
    # Per-channel statistics conventionally used for ImageNet-pretrained models.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=imagenet_mean, std=imagenet_std)

    # Augmenting pipeline, used only for the training split.
    train_pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # Deterministic pipeline, used for every evaluation split.
    eval_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    return {'train': train_pipeline, 'val': eval_pipeline}

In [4]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device=None):
    """Fine-tune ``model`` and return it loaded with its best-validation weights.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. The weights of the epoch with the
    highest validation accuracy are kept and restored before returning.

    Args:
        model: ``torch.nn.Module`` to train (modified in place).
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, falling back to the
            module-level ``DEVICE`` for a model without parameters.

    Returns:
        The same ``model`` object, holding the best weights found.
    """
    if device is None:
        # Follow the model's own placement so inputs and weights always agree.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else DEVICE

    print(f"\n[Step 1] Starting Training on CLEAN data for {num_epochs} epochs...")
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_correct = 0
        train_seen = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)
            loss.backward()
            optimizer.step()

            train_correct += (preds == labels).sum().item()
            train_seen += labels.size(0)

        # Divide by the samples actually seen: this stays correct when the
        # loader skips samples (e.g. drop_last=True) and never divides by zero.
        epoch_acc = train_correct / train_seen if train_seen else 0.0

        # ---- Validation phase (clean baseline) ----
        model.eval()
        val_correct = 0
        val_seen = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                val_correct += (preds == labels).sum().item()
                val_seen += labels.size(0)

        val_acc = val_correct / val_seen if val_seen else 0.0

        print(f'Epoch {epoch+1}/{num_epochs} | Train Acc: {epoch_acc:.4f} | Val Acc (Clean): {val_acc:.4f}')

        # With no validation samples there is nothing to select on, so keep the
        # most recent weights rather than silently reverting to the initial ones.
        if val_seen == 0 or val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Training Complete. Best Clean Accuracy: {best_acc:.4f}")
    model.load_state_dict(best_model_wts)
    return model

In [5]:
def evaluate_model(model, dataloader, description, device=None):
    """Compute top-1 accuracy of ``model`` over ``dataloader`` and print it.

    Args:
        model: ``torch.nn.Module`` to evaluate; it is switched to eval mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        description: Human-readable name of the dataset, used in the log lines.
        device: Device the batches are moved to. When ``None`` (default) the
            device of the model's first parameter is used, falling back to the
            module-level ``DEVICE`` for a model without parameters.

    Returns:
        float: Fraction of correctly classified samples, or ``0.0`` when the
        dataloader yields no samples.
    """
    if device is None:
        # Follow the model's own placement so inputs and weights always agree.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else DEVICE

    model.eval()
    corrects = 0
    total = 0

    print(f"\n[Step 2] Evaluating on {description}...")

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Plain Python ints keep the counters valid even if no batch arrives.
            corrects += (preds == labels).sum().item()
            total += inputs.size(0)

    # Guard against an empty dataloader instead of dividing by zero.
    acc = corrects / total if total else 0.0
    print(f"--> Accuracy on {description}: {acc:.4f} ({corrects}/{total})")
    return acc

In [6]:
# --- 0. REPRODUCIBILITY ---
# Seed the global RNGs so the shuffling order, the random augmentations and the
# initialisation of the new fc layer repeat across "Restart & Run All".
# NOTE: some GPU kernels may still be non-deterministic, so runs on CUDA can
# differ slightly even with fixed seeds.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- 1. PREPARE DATA ---
data_transforms = get_transforms()

# A. CLEAN DATA (Used for Training and Validating Baseline)
# The same folder is loaded twice so that the training split gets augmentation
# while the validation split gets the deterministic transform.
full_clean_dataset_train = datasets.ImageFolder(CLEAN_DATA_DIR, transform=data_transforms['train'])
full_clean_dataset_val   = datasets.ImageFolder(CLEAN_DATA_DIR, transform=data_transforms['val'])

# B. DISTORTED DATA (Used ONLY for Testing the Drop)
# The entire folder is evaluated; it is never split or indexed with the clean
# dataset's indices, because the two folders hold different numbers of images.
full_distorted_dataset = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=data_transforms['val'])

print(f"Clean Images Found: {len(full_clean_dataset_train)}")
print(f"Distorted Images Found: {len(full_distorted_dataset)}")

# ImageFolder derives integer labels from the sorted sub-directory names of each
# root independently. The distorted accuracy is only meaningful if every
# distorted class maps to the same integer label as in the clean data.
clean_class_to_idx = full_clean_dataset_train.class_to_idx
mismatched_classes = sorted(
    name
    for name, idx in full_distorted_dataset.class_to_idx.items()
    if clean_class_to_idx.get(name) != idx
)
if mismatched_classes:
    raise ValueError(
        "Distorted dataset labels do not line up with the clean dataset for "
        f"classes {mismatched_classes}; check the folder names under "
        f"{DISTORTED_DATA_DIR!r} and {CLEAN_DATA_DIR!r}."
    )

# Labels >= NUM_CLASSES would fall outside the classifier head's output range.
if len(full_clean_dataset_train.classes) > NUM_CLASSES:
    raise ValueError(
        f"Found {len(full_clean_dataset_train.classes)} classes in "
        f"{CLEAN_DATA_DIR!r} but NUM_CLASSES is {NUM_CLASSES}."
    )

# --- 2. CREATE SPLITS (Clean Data Only) ---
# Stratified split of the clean data into Train (80%) and Validation (20%).
targets = full_clean_dataset_train.targets
train_idx, val_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    random_state=SEED,
    stratify=targets
)

# Create Subsets (both index into the same file list, so the splits are disjoint)
train_subset     = Subset(full_clean_dataset_train, train_idx)
clean_val_subset = Subset(full_clean_dataset_val, val_idx)

# Create Loaders
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)
clean_val_loader = DataLoader(clean_val_subset, batch_size=BATCH_SIZE, shuffle=False)
# Important: Distorted loader uses the FULL dataset, not indices
distorted_val_loader = DataLoader(full_distorted_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"--> Training set: {len(train_subset)} images")
print(f"--> Clean Validation set (R): {len(clean_val_subset)} images")
print(f"--> Distorted Test set (R'): {len(full_distorted_dataset)} images")

# --- 3. INITIALIZE MODEL ---
print("\nInitializing ResNet-50...")
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Replace the final layer so the head outputs NUM_CLASSES logits
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
model = model.to(DEVICE)

# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

# --- 4. TRAIN (Phase 1 & 2) ---
model = train_model(model, train_loader, clean_val_loader, criterion, optimizer, NUM_EPOCHS)

# Save Model
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")

# --- 5. MEASURE THE DROP ---
print("="*50)
print("FINAL RESULTS: QUANTIFYING DEGRADATION")
print("="*50)

# R: Accuracy on Clean Validation
# NOTE: this is the same split train_model used to pick the best checkpoint, so
# R may be slightly optimistic compared with a fully held-out clean test set.
acc_clean = evaluate_model(model, clean_val_loader, "Clean Validation Set (Baseline R)")

# R': Accuracy on Distorted Data
# NOTE(review): if the distorted folder contains distorted copies of images that
# are in the clean training split, R' is not measured on held-out content while
# R is -- confirm how the distorted set was produced.
acc_distorted = evaluate_model(model, distorted_val_loader, "Full Distorted Set (Degraded R')")

drop = acc_clean - acc_distorted

print("\n" + "="*50)
print(f"Baseline Accuracy (R):   {acc_clean*100:.2f}%")
print(f"Distorted Accuracy (R'): {acc_distorted*100:.2f}%")
print(f"PERFORMANCE DROP:        {drop*100:.2f}%")
print("="*50)

Clean Images Found: 610
Distorted Images Found: 680
--> Training set: 488 images
--> Clean Validation set (R): 122 images
--> Distorted Test set (R'): 680 images

Initializing ResNet-50...

[Step 1] Starting Training on CLEAN data for 10 epochs...
Epoch 1/10 | Train Acc: 0.5020 | Val Acc (Clean): 0.9262
Epoch 2/10 | Train Acc: 0.8934 | Val Acc (Clean): 0.9426
Epoch 3/10 | Train Acc: 0.9180 | Val Acc (Clean): 0.9672
Epoch 4/10 | Train Acc: 0.9385 | Val Acc (Clean): 0.9590
Epoch 5/10 | Train Acc: 0.9447 | Val Acc (Clean): 0.9754
Epoch 6/10 | Train Acc: 0.9447 | Val Acc (Clean): 0.9836
Epoch 7/10 | Train Acc: 0.9611 | Val Acc (Clean): 0.9918
Epoch 8/10 | Train Acc: 0.9672 | Val Acc (Clean): 0.9836
Epoch 9/10 | Train Acc: 0.9570 | Val Acc (Clean): 0.9836
Epoch 10/10 | Train Acc: 0.9713 | Val Acc (Clean): 0.9836
Training Complete. Best Clean Accuracy: 0.9918
Model saved to resnet50_clean_baseline.pth
FINAL RESULTS: QUANTIFYING DEGRADATION

[Step 2] Evaluating on Clean Validation Set (Baseli