## Semi-Supervised Learning with ResNet101 for Sewer Defect Classification

### 1. Library Imports and Setup
- Import essential PyTorch libraries and utilities
- Set up Weights & Biases (wandb) for experiment tracking
- Configure API keys and project settings

In [None]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
import pandas as pd
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.nn import BCEWithLogitsLoss
from sklearn.metrics import precision_recall_fscore_support
import wandb
from torchvision.utils import save_image

# Set the wandb API key.
# setdefault keeps a key that is already exported in the environment instead of
# clobbering it with the placeholder below. Prefer exporting WANDB_API_KEY
# outside the notebook so the secret never ends up in version control.
os.environ.setdefault('WANDB_API_KEY', 'YOUR_WANDB_API_KEY')  # Replace with your key

# Initialize wandb
wandb.login()
wandb.init(project="your_project_name")  # Replace with your project name

### 2. Dataset Preparation
- Define CustomDataset class for handling multi-label image data
- Implement data transformations for training and inference:
  - Resize images to 224x224
  - Apply data augmentation (horizontal flip, color jitter)
  - Normalize with domain-specific mean/std values (from Sewer-ML paper)
- Create separate transforms for training and inference

In [None]:
# Custom dataset class
class CustomDataset(Dataset):
    """Multi-label image dataset backed by an annotations CSV.

    The first CSV column holds the image file name (relative to ``img_dir``);
    every remaining column is read as a label and converted to float32.

    Args:
        csv_file: Path to the annotations CSV.
        img_dir: Directory containing the image files.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, csv_file, img_dir, transform=None):
        self.annotations = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, labels)`` for row ``idx``; labels is a float32 tensor."""
        img_path = os.path.join(self.img_dir, str(self.annotations.iloc[idx, 0]))
        # The context manager closes the underlying file handle; convert()
        # returns a fully loaded copy, so the image stays usable afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        labels = torch.tensor(self.annotations.iloc[idx, 1:].astype('float32').values)
        if self.transform:
            image = self.transform(image)
        return image, labels
    

# Per-channel normalisation statistics for the sewer imagery (Sewer-ML values).
# Defined once so the training and inference pipelines cannot drift apart.
SEWER_ML_MEAN = [0.523, 0.453, 0.345]
SEWER_ML_STD = [0.210, 0.199, 0.154]

# Training pipeline: resize, light augmentation (flip + colour jitter), normalise.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=SEWER_ML_MEAN, std=SEWER_ML_STD)
])


# Inference pipeline: deterministic resize + normalisation, no augmentation.
inference_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=SEWER_ML_MEAN, std=SEWER_ML_STD)
])

### 3. Data Loading
- Initialize train and validation datasets
- Configure DataLoader with:
  - Batch size of 128
  - 8 worker processes
  - Pin memory for faster GPU transfer
- Enable shuffling for training data

In [None]:
# Build the train/validation datasets; only the training split is augmented.
train_dataset = CustomDataset(
    csv_file='${TRAIN_BATCH_CSV_ROOT}',
    img_dir='${TRAIN_IMAGE_ROOT}',
    transform=transform,
)
val_dataset = CustomDataset(
    csv_file='${VAL_BATCH_CSV_ROOT}',
    img_dir='${VAL_IMAGE_ROOT}',
    transform=inference_transform,
)

# Create dataloaders: identical settings except that training data is shuffled.
train_loader = DataLoader(
    dataset=train_dataset, batch_size=128, shuffle=True, num_workers=8, pin_memory=True,
)
val_loader = DataLoader(
    dataset=val_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True,
)

### 4. SSL Model Loading
Loads the self-supervised pre-trained model:
- Loads checkpoint from SwAV pre-training
- Extracts state dictionary and handles key mapping
- Initializes ResNet101 architecture
- Strips the `_feature_blocks.` prefix from state-dict keys and keeps only keys that exist in the ResNet101 model
- Key difference: Loads SSL pre-trained weights instead of ImageNet weights

In [None]:
# Load the whole SwAV pre-training checkpoint onto the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(
    '../checkpoints/SSL_model_weights/${SSL_MODEL_WEIGHTS}',
    map_location=torch.device('cpu'),
)
print("Keys in the checkpoint:", list(checkpoint))

# Trainer state is nested under 'classy_state_dict'.
classy_state_dict = checkpoint['classy_state_dict']
print("Keys in 'classy_state_dict':", list(classy_state_dict))

# Backbone weights live at base_model -> model -> trunk.
base_model_state_dict = classy_state_dict['base_model']['model']['trunk']
print("Keys in the extracted base_model_state_dict:")
print(list(base_model_state_dict)[:10])

# Randomly initialised ResNet101 (no ImageNet weights) that will receive the SSL trunk.
model = models.resnet101(weights=None)

# Function to remove prefix
def remove_prefix(state_dict, prefix):
    """Return a new dict whose keys have ``prefix`` stripped where present.

    Keys that do not start with ``prefix`` are copied unchanged and values are
    shared, not copied. If stripping makes two keys collide, the one that comes
    later in iteration order wins.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }

# Strip the '_feature_blocks.' prefix used in the SSL checkpoint so the keys
# line up with torchvision's ResNet parameter names.
filtered_state_dict = remove_prefix(base_model_state_dict, '_feature_blocks.')

# Print filtered keys to verify
print("Keys in the filtered state_dict:")
print(list(filtered_state_dict.keys())[:10])

# Snapshot the model's state dict ONCE. Calling model.state_dict() inside the
# comprehension rebuilt the whole dict for every checkpoint key.
model_state_dict = model.state_dict()

# Report, rather than silently discard, checkpoint entries with no slot in the model.
dropped_keys = [k for k in filtered_state_dict if k not in model_state_dict]
filtered_state_dict = {k: v for k, v in filtered_state_dict.items() if k in model_state_dict}
print(f"Dropped {len(dropped_keys)} checkpoint keys not present in the model:")
print(dropped_keys)

# strict=False tolerates model parameters absent from the checkpoint (presumably
# the fc head, since only the trunk was extracted).
missing_keys, unexpected_keys = model.load_state_dict(filtered_state_dict, strict=False)

# Print missing and unexpected keys
print("Missing keys after loading state_dict:")
print(missing_keys)

# Because of the filter above this is expected to be empty; dropped_keys is the informative list.
print("Unexpected keys after loading state_dict:")
print(unexpected_keys)

### 5. Model Architecture Modification
Adapts the pre-trained model for the defect classification task:
- Replaces final FC layer with 17-class output
- Moves model to available device (GPU/CPU)
- Verifies parameter statistics

In [None]:
# Replace the classification head with one logit per defect class (17 here).
# There is deliberately no Sigmoid: BCEWithLogitsLoss (the criterion) applies
# it internally, and the "no defect" (ND) class is not a model output -- it is
# derived later from rows with no positive prediction.
NUM_CLASSES = 17
num_ftrs = model.fc.in_features
# The single-layer Sequential is kept so the head's parameters stay named
# 'fc.0.weight' / 'fc.0.bias'; unwrapping it would change the state_dict keys.
model.fc = torch.nn.Sequential(
    torch.nn.Linear(num_ftrs, NUM_CLASSES)
)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Print per-parameter mean/std so you can eyeball whether the SSL weights were loaded.
for name, param in model.named_parameters():
    print(f"{name}: mean={param.mean().item()}, std={param.std().item()}")


### 6. Loss Function and Class Weighting
Implements weighted loss for handling class imbalance:
- Calculates the per-class negative/positive sample ratio from the training labels
- Clips the ratio to [1, 10] and applies a square root to moderate the weights
- Implements BCEWithLogitsLoss with pos_weight

In [None]:
# Balanced pos_weight with moderation for better precision-recall balance.


def _dataset_labels(dataset):
    """Return the float32 label matrix of ``dataset`` without loading any images.

    Handles a CustomDataset (labels = every CSV column after the first, as read
    by ``CustomDataset.__getitem__``) and torch ``Subset`` wrappers around one.
    """
    if isinstance(dataset, Subset):
        return _dataset_labels(dataset.dataset)[list(dataset.indices)]
    return torch.tensor(dataset.annotations.iloc[:, 1:].to_numpy(dtype='float32'))


# Iterating train_loader would decode and augment every training image just to
# read the label columns; taking them from the annotations gives the same
# per-class counts at a fraction of the cost.
all_labels = _dataset_labels(train_loader.dataset)

num_pos = all_labels.sum(dim=0)
num_neg = (all_labels == 0).sum(dim=0)
raw_pos_weight = num_neg / (num_pos + 1e-8)

# Moderate the pos_weight to balance precision and recall:
# clip the neg/pos ratio to [1, 10], then take the square root to soften it.
pos_weight = torch.sqrt(torch.clamp(raw_pos_weight, min=1.0, max=10.0))
pos_weight = pos_weight.to(device)

print("Class distribution (positive samples):", num_pos.numpy())
print("Applied pos_weight:", pos_weight.cpu().numpy())

criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

### 7. Training Configuration
- Define SGD optimizer with:
  - Learning rate: 0.01
  - Momentum: 0.9
  - Weight decay: 0.0001
- Set up MultiStepLR scheduler (LR multiplied by 0.1 at epochs 10, 20 and 25)
- Implement metric calculation for precision, recall, and F1
- Define training loop with validation

In [None]:
# SGD with momentum (as in the original paper) for better generalisation on
# small data; a relatively high initial LR because there is less data to learn from.
# MultiStepLR divides the LR by 10 at epochs 10, 20 and 25.
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=1e-4,
)
scheduler = lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[10, 20, 25], gamma=0.1)


def calculate_metrics(labels, outputs, threshold=0.5):
    """Compute macro-averaged precision, recall and F1 from multi-label logits.

    A derived "no defect" (ND) column is appended to both labels and
    predictions: it is 1 for rows with no positive entry. Metrics are then
    macro-averaged over all defect columns plus ND.

    Args:
        labels: 2-D tensor (N, C) of ground-truth labels (assumed 0/1).
        outputs: 2-D tensor (N, C) of raw logits.
        threshold: Probability cut-off applied after the sigmoid.

    Returns:
        ``(precision, recall, f1)`` as floats.
    """
    # Detach and move to CPU so .numpy() works whatever device/grad state the caller uses.
    labels = labels.detach().cpu()
    preds = (torch.sigmoid(outputs.detach().cpu()) > threshold).float()

    # Derive "ND" (no defect) as an extra final class.
    labels_nd = (labels.sum(dim=1) == 0).float().unsqueeze(1)
    preds_nd = (preds.sum(dim=1) == 0).float().unsqueeze(1)

    y_true = torch.cat([labels, labels_nd], dim=1).numpy()
    y_pred = torch.cat([preds, preds_nd], dim=1).numpy()

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    return precision, recall, f1


# Training function


def _train_one_epoch(model, criterion, optimizer, loader):
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    running_loss = 0.0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def _run_validation(model, criterion, loader):
    """Evaluate ``model`` on ``loader``; return (mean loss, labels, logits), tensors on CPU."""
    model.eval()
    total_loss = 0.0
    label_batches, logit_batches = [], []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            label_batches.append(labels.cpu())
            logit_batches.append(outputs.cpu())
    return total_loss / len(loader.dataset), torch.cat(label_batches), torch.cat(logit_batches)


def train_model(model, criterion, optimizer, scheduler, num_epochs=30, save_interval=15, checkpoint_dir="${YOUR_PROJECT_ROOT}/checkpoints/fine_tuning/SSL/${FINE_TUNING_BATCH}"):
    """Fine-tune ``model`` on the global ``train_loader``, validating on ``val_loader`` each epoch.

    Per-epoch train/val loss and macro precision/recall/F1 are logged to wandb
    and printed. A state_dict checkpoint is written every ``save_interval``
    epochs and after the final epoch.

    Args:
        model: Network returning raw logits.
        criterion: Loss on (logits, labels), e.g. BCEWithLogitsLoss.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per epoch.
        num_epochs: Number of training epochs.
        save_interval: Save a checkpoint every this many epochs.
        checkpoint_dir: Directory for ``model_epoch_<n>.pth`` files; created if missing.

    Returns:
        The trained ``model`` (trained in place).
    """
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, criterion, optimizer, train_loader)
        scheduler.step()

        val_loss, val_labels, val_outputs = _run_validation(model, criterion, val_loader)
        precision, recall, f1 = calculate_metrics(val_labels, val_outputs)

        # A single wandb.log call per epoch: every call advances wandb's step
        # counter, so separate train/val calls put them on different steps.
        wandb.log({
            "Train Loss": epoch_loss,
            "Val Loss": val_loss,
            "Val Precision": precision,
            "Val Recall": recall,
            "Val F1 Score": f1,
            "epoch": epoch,
        })

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Precision: {precision:.4f}, Val Recall: {recall:.4f}, Val F1 Score: {f1:.4f}')

        # Save every save_interval epochs, and always keep the final weights.
        if (epoch + 1) % save_interval == 0 or epoch + 1 == num_epochs:
            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch + 1}.pth")
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model checkpoint saved at epoch {epoch + 1}")

    return model

### 8. Model Training
- Execute training loop for 30 epochs
- Track metrics using wandb
- Save model checkpoints periodically

In [None]:
# Fine-tune for 30 epochs; the scheduler milestones (10, 20, 25) fall inside this run.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=30)

### 9. Evaluation Function
- Define comprehensive evaluation function
- Calculate per-class metrics and overall (micro-averaged) metrics
- Handle "No Defect" (ND) class specially

In [None]:
import torch
from sklearn.metrics import precision_recall_fscore_support

def evaluate_model(model, dataloader, threshold=0.5, save_images=False):
    """Evaluate a multi-label classifier and print per-class and overall metrics.

    Predictions are ``sigmoid(logits) > threshold``. A derived "no defect" (ND)
    column (1 when a row has no positive entry) is appended to labels and
    predictions and reported as the last class. "Overall" is micro-averaged.

    Args:
        model: Network returning raw logits.
        dataloader: Yields ``(inputs, labels)`` batches.
        threshold: Probability cut-off for a positive prediction.
        save_images: Kept for backward compatibility. It used to buffer every
            input batch in memory without ever using it; it now only warns.

    Returns:
        ``(y_true, y_pred)`` binary NumPy arrays with the ND column last.
    """
    if save_images:
        import warnings
        warnings.warn("evaluate_model: save_images is not implemented; no images are saved.", stacklevel=2)

    model.eval()
    label_batches, logit_batches = [], []
    with torch.no_grad():
        for inputs, labels in dataloader:
            logit_batches.append(model(inputs.to(device)).cpu())
            # Labels are only needed on the CPU for the metric computation.
            label_batches.append(labels.cpu())

    all_labels = torch.cat(label_batches)
    all_outputs = (torch.sigmoid(torch.cat(logit_batches)) > threshold).float()

    # Derive "ND" (no defect) as an extra final class.
    labels_nd = (all_labels.sum(dim=1) == 0).float().unsqueeze(1)
    outputs_nd = (all_outputs.sum(dim=1) == 0).float().unsqueeze(1)

    y_true = torch.cat([all_labels, labels_nd], dim=1).numpy()
    y_pred = torch.cat([all_outputs, outputs_nd], dim=1).numpy()

    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
    overall_precision, overall_recall, overall_f1, _ = precision_recall_fscore_support(y_true, y_pred, average='micro', zero_division=0)

    # The ND column is always last; derive its index instead of hard-coding 17.
    nd_index = y_true.shape[1] - 1
    for i, (p, r, f) in enumerate(zip(precision, recall, f1)):
        class_label = f'Class {i} (ND)' if i == nd_index else f'Class {i}'
        print(f'{class_label} - Precision: {p:.4f}, Recall: {r:.4f}, F1 Score: {f:.4f}')

    print(f'\nOverall - Precision: {overall_precision:.4f}, Recall: {overall_recall:.4f}, F1 Score: {overall_f1:.4f}')
    return y_true, y_pred

### 10. Test Dataset Evaluation
- Load and prepare test dataset
- Run model evaluation on test set
- Calculate final performance metrics

In [None]:
# Test split uses the deterministic inference transform (no augmentation).
# Placeholders use the same ${...} syntax as the rest of the notebook, so one
# substitution pass fills them all in.
test_dataset = CustomDataset(
    csv_file='${YOUR_PROJECT_ROOT}/data/fine_tuning/annotations/test/test_labels.csv',
    img_dir='${YOUR_PROJECT_ROOT}/data/fine_tuning/images/test',
    transform=inference_transform,
)

# shuffle=False keeps predictions aligned with the CSV row order.
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True)

In [None]:
# Prints per-class/overall metrics and returns binary label/prediction arrays (ND column last).
y_true, y_pred = evaluate_model(model, test_loader, threshold=0.5)

### 11. Metrics Computation
- Apply CIW (Class Importance Weights) from Sewer-ML paper
- Calculate specialized metrics:
  - Class-weighted F2 score
  - Normal class F1 score
  - Mean Average Precision
  - Exact Match Accuracy

In [None]:
from metrics import evaluation
import numpy as np

# Class-Importance Weights (CIW) from the Sewer-ML paper, one per defect class
ciw_weights = np.array([
    1.0000,  # RB
    0.5518,  # OB
    0.2896,  # PF
    0.1622,  # DE
    0.6419,  # FS
    0.1847,  # IS
    0.3559,  # RO
    0.3131,  # IN
    0.0811,  # AF
    0.2275,  # BE
    0.2477,  # FO
    0.0901,  # GR
    0.4167,  # PH
    0.4167,  # PB
    0.9009,  # OS
    0.3829,  # OP
    0.4396   # OK
])
NUM_DEFECT_CLASSES = len(ciw_weights)

# y_true / y_pred come from evaluate_model: binary arrays with the derived ND
# column appended last, so keep only the defect columns.
y_true_defects = y_true[:, :NUM_DEFECT_CLASSES]
y_pred_defects = y_pred[:, :NUM_DEFECT_CLASSES]

# Average precision ranks samples by score, so computing mAP from hard 0/1
# predictions is degenerate. Collect sigmoid probabilities instead; test_loader
# does not shuffle, so rows stay aligned with y_true.
model.eval()
score_batches = []
with torch.no_grad():
    for inputs, _ in test_loader:
        score_batches.append(torch.sigmoid(model(inputs.to(device))).cpu())
y_score_defects = torch.cat(score_batches).numpy()[:, :NUM_DEFECT_CLASSES]

# `evaluation` takes a `threshold` argument, presumably applied to the scores
# for the threshold-based metrics (F1/F2, exact match) -- confirm against metrics.py.
new_metrics, main_metrics, aux_metrics = evaluation(y_score_defects, y_true_defects, ciw_weights, threshold=0.5)
f1_normal = new_metrics["F1_Normal"]
print("F1-score for ND (Normal):", f1_normal)
print("Main metrics:", main_metrics)
print("Class-weighted F2 (CIW-F2):", new_metrics["F2"])
print("Per-class F2:", new_metrics["F2_class"])
print("Macro F1:", main_metrics["MF1"])
print("Micro F1:", main_metrics["mF1"])
print("Mean Average Precision (mAP):", main_metrics["mAP"])
print("Exact Match Accuracy:", main_metrics["EMAcc"])