In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from sklearn.model_selection import train_test_split
import numpy as np
import os
import time
import copy

In [2]:
# ================= CONFIGURATION =================
# Everything a reader might want to tune lives in this cell.

# Paths
CLEAN_DATA_DIR     = "../ImageNet_images"
# Distorted counterpart: the water-occlusion image set
DISTORTED_DATA_DIR = "../Distorted_Images/Water_Occlusion"
MODEL_SAVE_PATH    = "resnet50_finetuned_water.pth"

# Training hyperparameters
BATCH_SIZE = 32
NUM_EPOCHS = 10
LEARNING_RATE = 0.001
NUM_CLASSES = 5

# Reproducibility: seed the global NumPy and PyTorch generators once, up front,
# so that a Restart & Run All draws the same random numbers for layer
# initialisation, shuffling and augmentation. (Seeding alone does not force
# deterministic GPU kernels, so CUDA runs may still differ slightly.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Device: {DEVICE}")

Using Device: cuda


In [3]:
# ================= HELPERS =================
def get_transforms():
    """Build the preprocessing pipelines for training and validation.

    Returns:
        tuple: ``(train_transform, val_transform)``. The training pipeline
        applies a random resized crop to 224 px and a random horizontal flip;
        the validation pipeline resizes to 256 px and center-crops to 224 px.
        Both end with tensor conversion and the same per-channel normalization.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    def _pipeline(*spatial_ops):
        # Every pipeline shares the same tail; only the spatial steps differ.
        return transforms.Compose([
            *spatial_ops,
            transforms.ToTensor(),
            transforms.Normalize(list(channel_mean), list(channel_std)),
        ])

    # Random crop + flip: augmentation to limit overfitting during training.
    augmenting = _pipeline(
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    )
    # Fixed resize + center crop: no augmentation at validation time.
    deterministic = _pipeline(
        transforms.Resize(256),
        transforms.CenterCrop(224),
    )
    return augmenting, deterministic

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Fine-tune ``model`` and return it with its best-validation weights loaded.

    Every epoch runs one optimisation pass over ``train_loader`` followed by a
    gradient-free accuracy pass over ``val_loader``. A copy of the weights is
    kept from the first epoch that reaches the highest validation accuracy and
    is loaded back into the model before returning.

    Args:
        model: Network to train; modified in place.
        train_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``.dataset`` with a length, used for the optimisation pass.
        val_loader: Same kind of loader, used only for measuring accuracy.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer already bound to the model's parameters.
        num_epochs: Number of train/validate cycles to run.

    Returns:
        The same ``model`` object. If the validation set is empty there is
        nothing to select on, so the weights from the final epoch are kept.
    """
    # Move batches to wherever the model already lives instead of relying on a
    # notebook-level global, so inputs and weights can never be on different devices.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    n_train = len(train_loader.dataset)
    n_val = len(val_loader.dataset)

    since = time.time()
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    print(f"Starting Fine-Tuning for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        # --- Training Phase ---
        model.train()
        train_corrects = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Count hits as plain Python ints so an empty loader cannot leave
            # us holding an int where a tensor method is expected.
            preds = outputs.argmax(dim=1)
            train_corrects += (preds == labels).sum().item()

        epoch_acc = train_corrects / n_train if n_train else 0.0

        # --- Validation Phase ---
        model.eval()
        val_corrects = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                preds = model(inputs).argmax(dim=1)
                val_corrects += (preds == labels).sum().item()

        val_acc = val_corrects / n_val if n_val else 0.0

        print(f'Epoch {epoch+1}/{num_epochs} | Train Acc: {epoch_acc:.4f} | Mixed Val Acc: {val_acc:.4f}')

        # Strict '>' keeps the earliest epoch when several tie for best.
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best Mixed Val Acc: {best_acc:.4f}')

    if n_val:
        model.load_state_dict(best_model_wts)
    else:
        # Without validation data the "best" snapshot is just the untrained
        # starting point; restoring it would silently discard the training.
        print("Warning: validation set is empty; keeping final-epoch weights.")
    return model

In [4]:
# ================= EXECUTION =================
# Stage 1: fine-tune an ImageNet-pretrained ResNet-50 on a mix of clean and
# synthetically distorted images, then save the best-validation weights.

# 1. SETUP TRANSFORMS
train_tf, val_tf = get_transforms()

# 2. LOAD DATASETS
if not os.path.exists(DISTORTED_DATA_DIR):
    print(f"Error: {DISTORTED_DATA_DIR} not found. Please run water generation script.")
else:
    # Each folder is indexed once per transform. The *_ds objects carry the
    # augmenting train transform and are reused for the training subsets; the
    # *_eval_ds objects carry the deterministic validation transform.
    clean_ds = datasets.ImageFolder(CLEAN_DATA_DIR, transform=train_tf)
    clean_eval_ds = datasets.ImageFolder(CLEAN_DATA_DIR, transform=val_tf)
    distorted_ds = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=train_tf)
    distorted_eval_ds = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=val_tf)

    print(f"Clean Images: {len(clean_ds)}")
    print(f"Distorted Images: {len(distorted_ds)}")

    # 3. CREATE INDEPENDENT SPLITS
    # NOTE(review): if the distorted images were generated from the clean ones,
    # splitting the two folders independently can place a distorted copy of a
    # training image in the validation set -- confirm, and split by source
    # image if that is the case.

    # Clean data: stratified 80/20
    targets_clean = clean_ds.targets
    clean_train_idx, clean_val_idx = train_test_split(
        np.arange(len(targets_clean)),
        test_size=0.2,
        random_state=42,
        stratify=targets_clean
    )

    # Distorted data: stratified 80/20 (independently of the clean split)
    targets_distorted = distorted_ds.targets
    dist_train_idx, dist_val_idx = train_test_split(
        np.arange(len(targets_distorted)),
        test_size=0.2,
        random_state=42,
        stratify=targets_distorted
    )

    # Subsets index into the dataset built from the SAME folder, so the
    # indices produced above line up with the files they were drawn from.
    # -- Clean --
    clean_train_subset = Subset(clean_ds, clean_train_idx)
    clean_val_subset   = Subset(clean_eval_ds, clean_val_idx)

    # -- Distorted --
    distorted_train_subset = Subset(distorted_ds, dist_train_idx)
    distorted_val_subset   = Subset(distorted_eval_ds, dist_val_idx)

    # ====================================================================

    # 4. COMBINE DATASETS (The "Mix")
    # Train on 80% Clean + 80% Distorted
    mixed_train_dataset = ConcatDataset([clean_train_subset, distorted_train_subset])

    # Validate on 20% Clean + 20% Distorted (Held out)
    mixed_val_dataset   = ConcatDataset([clean_val_subset, distorted_val_subset])

    # 5. DATALOADERS
    train_loader = DataLoader(mixed_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(mixed_val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"Total Training Images: {len(mixed_train_dataset)} (Clean+Distorted)")
    print(f"Total Validation Images: {len(mixed_val_dataset)} (Clean+Distorted)")

    # 6. INITIALIZE & TRAIN
    # Start from ImageNet weights and swap the classifier head for one sized
    # to this task; every layer stays trainable.
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
    model = model.to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

    # Start Fine-Tuning
    tuned_model = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS)

    # 7. SAVE
    torch.save(tuned_model.state_dict(), MODEL_SAVE_PATH)
    print(f"Fine-Tuned Model saved to {MODEL_SAVE_PATH}")

Clean Images: 610
Distorted Images: 530
Total Training Images: 912 (Clean+Distorted)
Total Validation Images: 228 (Clean+Distorted)
Starting Fine-Tuning for 10 epochs...
Epoch 1/10 | Train Acc: 0.5559 | Mixed Val Acc: 0.8246
Epoch 2/10 | Train Acc: 0.8235 | Mixed Val Acc: 0.8553
Epoch 3/10 | Train Acc: 0.8607 | Mixed Val Acc: 0.8904
Epoch 4/10 | Train Acc: 0.8662 | Mixed Val Acc: 0.9167
Epoch 5/10 | Train Acc: 0.8991 | Mixed Val Acc: 0.9167
Epoch 6/10 | Train Acc: 0.9123 | Mixed Val Acc: 0.9298
Epoch 7/10 | Train Acc: 0.9430 | Mixed Val Acc: 0.9342
Epoch 8/10 | Train Acc: 0.9309 | Mixed Val Acc: 0.9430
Epoch 9/10 | Train Acc: 0.9375 | Mixed Val Acc: 0.9386
Epoch 10/10 | Train Acc: 0.9441 | Mixed Val Acc: 0.9123
Training complete in 6m 4s
Best Mixed Val Acc: 0.9430
Fine-Tuned Model saved to resnet50_finetuned_water.pth


In [6]:
# ================= EXECUTION (V3: With Preprocessed Real Data) =================
# Stage 2: resume from the stage-1 checkpoint and continue fine-tuning on
# clean + synthetic-distorted + (when available) real distorted images.

# 1. SETUP TRANSFORMS
train_tf, val_tf = get_transforms()

# 2. PATHS
# Use the PREPROCESSED folder for real data
REAL_DATA_DIR      = "../Manual_Images/Water_Occlusion_Preproccessed/Water_Occlusion"
MODEL_LOAD_PATH = "resnet50_finetuned_water.pth" # Load the previous model
# NOTE: this rebinds the notebook-wide MODEL_SAVE_PATH. Re-running the first
# training cell after this one would therefore write to the v2 file; re-run
# the configuration cell first if you need to repeat stage 1.
MODEL_SAVE_PATH = "resnet50_finetuned_water_v2.pth" # Save with a new name

# 3. LOAD DATASETS
# Each folder is indexed once per transform: *_ds use the augmenting train
# transform (and back the training subsets), *_eval_ds the validation one.
clean_ds = datasets.ImageFolder(CLEAN_DATA_DIR, transform=train_tf)
clean_eval_ds = datasets.ImageFolder(CLEAN_DATA_DIR, transform=val_tf)
syn_distorted_ds = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=train_tf)
syn_distorted_eval_ds = datasets.ImageFolder(DISTORTED_DATA_DIR, transform=val_tf)

if os.path.exists(REAL_DATA_DIR):
    print(f"Found Preprocessed Real-World Data at {REAL_DATA_DIR}")
    real_distorted_ds = datasets.ImageFolder(REAL_DATA_DIR, transform=train_tf)
else:
    print("Warning: Real Data folder not found. Using only synthetic.")
    real_distorted_ds = None

print(f"Clean Images: {len(clean_ds)}")
print(f"Synthetic Distorted: {len(syn_distorted_ds)}")
if real_distorted_ds is not None:
    print(f"Real Distorted: {len(real_distorted_ds)}")

# 4. CREATE SPLITS
# A. Clean Splits
targets_clean = clean_ds.targets
clean_train_idx, clean_val_idx = train_test_split(
    np.arange(len(targets_clean)), test_size=0.2, random_state=42, stratify=targets_clean
)
clean_train_sub = Subset(clean_ds, clean_train_idx)
clean_val_sub   = Subset(clean_eval_ds, clean_val_idx)

# B. Synthetic Splits
targets_syn = syn_distorted_ds.targets
syn_train_idx, syn_val_idx = train_test_split(
    np.arange(len(targets_syn)), test_size=0.2, random_state=42, stratify=targets_syn
)
syn_train_sub = Subset(syn_distorted_ds, syn_train_idx)
syn_val_sub   = Subset(syn_distorted_eval_ds, syn_val_idx)

# C. Real Splits (If available)
train_datasets = [clean_train_sub, syn_train_sub]
val_datasets   = [clean_val_sub, syn_val_sub]

if real_distorted_ds is not None:
    targets_real = real_distorted_ds.targets

    # A stratified 80/20 split needs at least two images in every class and
    # at least one slot per class on each side of the split; otherwise
    # train_test_split raises. The original "> 5" size floor is kept as well.
    _, real_class_counts = np.unique(targets_real, return_counts=True)
    n_real = len(targets_real)
    n_real_val = int(np.ceil(0.2 * n_real))
    n_real_train = n_real - n_real_val
    n_real_classes = len(real_class_counts)
    can_split_real = (
        n_real > 5
        and real_class_counts.min() >= 2
        and n_real_val >= n_real_classes
        and n_real_train >= n_real_classes
    )

    if can_split_real:
        real_train_idx, real_val_idx = train_test_split(
            np.arange(n_real), test_size=0.2, random_state=42, stratify=targets_real
        )
        real_train_sub = Subset(real_distorted_ds, real_train_idx)
        real_val_sub   = Subset(datasets.ImageFolder(REAL_DATA_DIR, transform=val_tf), real_val_idx)

        train_datasets.append(real_train_sub)
        val_datasets.append(real_val_sub)
    else:
        print(f"Warning: only {n_real} real images across {n_real_classes} classes; "
              "too few for a stratified 80/20 split. Skipping real data.")

# 5. COMBINE EVERYTHING
mixed_train_dataset = ConcatDataset(train_datasets)
mixed_val_dataset   = ConcatDataset(val_datasets)

# 6. DATALOADERS
train_loader = DataLoader(mixed_train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(mixed_val_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total Training: {len(mixed_train_dataset)} images")
print(f"Total Validation: {len(mixed_val_dataset)} images")

# 7. LOAD PREVIOUS MODEL & RETRAIN
print(f"Loading weights from {MODEL_LOAD_PATH}...")
# No pretrained weights are requested here: every parameter and buffer is
# overwritten by the checkpoint loaded just below.
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)

# Load the saved state dict. map_location lets a checkpoint that was saved
# from a GPU run be restored on a machine without one.
model.load_state_dict(torch.load(MODEL_LOAD_PATH, map_location=DEVICE))
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
# Use a 10x smaller learning rate since we are fine-tuning an already tuned model
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE * 0.1, momentum=0.9)

# Start Fine-Tuning
tuned_model = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS)

# 8. SAVE
torch.save(tuned_model.state_dict(), MODEL_SAVE_PATH)
print(f"Retrained Super-Model saved to {MODEL_SAVE_PATH}")

Found Preprocessed Real-World Data at ../Manual_Images/Water_Occlusion_Preproccessed/Water_Occlusion
Clean Images: 610
Synthetic Distorted: 530
Real Distorted: 107
Total Training: 997 images
Total Validation: 250 images
Loading weights from resnet50_finetuned_water.pth...
Starting Fine-Tuning for 10 epochs...
Epoch 1/10 | Train Acc: 0.9348 | Mixed Val Acc: 0.9320
Epoch 2/10 | Train Acc: 0.9338 | Mixed Val Acc: 0.9280
Epoch 3/10 | Train Acc: 0.9318 | Mixed Val Acc: 0.9280
Epoch 4/10 | Train Acc: 0.9438 | Mixed Val Acc: 0.9280
Epoch 5/10 | Train Acc: 0.9318 | Mixed Val Acc: 0.9440
Epoch 6/10 | Train Acc: 0.9418 | Mixed Val Acc: 0.9400
Epoch 7/10 | Train Acc: 0.9488 | Mixed Val Acc: 0.9280
Epoch 8/10 | Train Acc: 0.9468 | Mixed Val Acc: 0.9280
Epoch 9/10 | Train Acc: 0.9388 | Mixed Val Acc: 0.9440
Epoch 10/10 | Train Acc: 0.9418 | Mixed Val Acc: 0.9320
Training complete in 6m 34s
Best Mixed Val Acc: 0.9440
Retrained Super-Model saved to resnet50_finetuned_water_v2.pth
