# Baseline ResNet-50
Setting up a baseline ResNet-50 model for person-outfit binary classification ("good" vs. "bad").

In [3]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torch.utils.data import WeightedRandomSampler, DataLoader
from torchvision import transforms
from PIL import Image

from src import config

In [4]:
print("Using device:", config.DEVICE)

Using device: mps


In [5]:
os.makedirs(config.CHECKPOINT_PATH, exist_ok=True)

In [6]:
# Load pre-trained ResNet-50
model = torchvision.models.resnet50(pretrained=True)

# model_name = "arize-ai/resnet-50-fashion-mnist-quality-drift"
# model = AutoModelForImageClassification.from_pretrained(model_name)
# processor = AutoImageProcessor.from_pretrained(model_name)



In [7]:
# Freeze all layers (base layers) except the last fully connected layer
# for param in model.parameters():
#     param.requires_grad = False

# Unfreeze last few layers
for name, param in model.named_parameters():
    if "layer4" in name or "fc" in name: # "layer3" in name or 
        param.requires_grad = True
    else:
        param.requires_grad = False

In [8]:
# Modify for binary classification
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.5),  # Regularization
    nn.Linear(256, 2)
)

In [9]:
model = model.to(config.DEVICE)

In [10]:
or_pos_dir = config.ORIGINAL_POS_OUTFITS_DIR
or_neg_dir = config.ORIGINAL_NEG_OUTFITS_DIR
seg_pos_dir = config.SEGMENTED_POS_OUTFITS_DIR
seg_neg_dir = config.SEGMENTED_NEG_OUTFITS_DIR

# Collect image paths and labels
image_paths = []
labels = []
for class_idx, folder in enumerate([seg_neg_dir, seg_pos_dir]):  # 0=negative, 1=positive
    for img_name in os.listdir(folder):
        if img_name.lower().endswith(config.IMAGE_FILE_EXTENSIONS):  # Filter images
            img_path = os.path.join(folder, img_name)
            image_paths.append(img_path)
            labels.append(class_idx)

for class_idx, folder in enumerate([or_neg_dir, or_pos_dir]):  # 0=negative, 1=positive
    for img_name in os.listdir(folder):
        if img_name.lower().endswith(config.IMAGE_FILE_EXTENSIONS):  # Filter images
            img_path = os.path.join(folder, img_name)
            image_paths.append(img_path)
            labels.append(class_idx)

import random

# Combine image_paths and labels
combined = list(zip(image_paths, labels))
random.seed(42)
random.shuffle(combined)

# Unzip back to separate lists
image_paths[:], labels[:] = zip(*combined)

# Split with stratification (80% train, 20% validation)
train_paths, val_paths, train_labels, val_labels = train_test_split(
    image_paths, 
    labels, 
    test_size=0.2, 
    stratify=labels,
    random_state=42
)

In [11]:
print(f"Train: {sum(train_labels)} positives, {len(train_labels)-sum(train_labels)} negatives")
print(f"Val: {sum(val_labels)} positives, {len(val_labels)-sum(val_labels)} negatives")

Train: 1306 positives, 7249 negatives
Val: 326 positives, 1813 negatives


In [12]:
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')  # Ensure RGB format
        label = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        return image, label

In [13]:
# Image transformations
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [14]:
# Load datasets
train_dataset = CustomImageDataset(
    image_paths=train_paths,
    labels=train_labels,
    transform=train_transform
)

val_dataset = CustomImageDataset(
    image_paths=val_paths,
    labels=val_labels,
    transform=val_transform
)

class_sample_count = np.array([np.sum(labels == t) for t in np.unique(labels)])
weights = 1. / class_sample_count
samples_weight = np.array([weights[t] for t in labels])
samples_weight = torch.from_numpy(samples_weight).double()

sampler = WeightedRandomSampler(samples_weight, len(samples_weight), replacement=True)
# train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [15]:
def get_class_counts(dataset):
    counts = [0, 0]  # [count_class0, count_class1]
    for _, label in dataset:
        # Handle both tensor and integer labels
        if isinstance(label, torch.Tensor):
            label = label.item()
        counts[label] += 1
    return counts

In [16]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

class_counts = get_class_counts(train_dataset)
weights = torch.tensor([
    1.0 / max(class_counts[0], 1),  # Prevent division by zero
    1.0 / max(class_counts[1], 1)
], dtype=torch.float).to(config.DEVICE)

# Loss function (binary classification)
criterion = nn.CrossEntropyLoss(weight=weights)
#criterion = nn.CrossEntropyLoss()

print(f"Class 0: {class_counts[0]} samples")
print(f"Class 1: {class_counts[1]} samples")
print(f"Applied weights: {weights.cpu().numpy()}")

Class 0: 7249 samples
Class 1: 1306 samples
Applied weights: [0.00013795 0.0007657 ]


In [17]:
# Optimizer with weight decay
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-4)

# Learning rate scheduler
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=6,
    min_lr=1e-6
)

# Early stopping setup
best_val_loss = float('inf')
patience, patience_counter = 10, 0

# Checkpointing
best_accuracy = 0.0
best_f1 = 0
start_epoch = 0

In [18]:
# Load checkpoint if resuming
if config.RESUME_CHECKPOINT and os.path.exists(os.path.join(config.CHECKPOINT_PATH, f"best_model.pth")):
    checkpoint = torch.load(os.path.join(config.CHECKPOINT_PATH, f"best_model.pth"))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming training from epoch {start_epoch}")

Resuming training from epoch 48


In [19]:
from sklearn.metrics import roc_curve, auc, confusion_matrix
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np

previous_lr = optimizer.param_groups[0]['lr']
train_losses = []    # Track training loss per epoch
val_losses = []      # Track validation loss per epoch
all_epoch_probs = [] # Store probabilities per epoch for ROC

for epoch in range(start_epoch, config.EPOCHS):
    model.train()
    train_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels) 
        train_loss += loss.item() * inputs.size(0) # summing up all the elements
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Calculate average training loss
    train_loss /= len(train_loader) # dividing by the total number of elements
    train_losses.append(train_loss)
    
    # Validation
    model.eval()
    val_loss = 0.0
    total_samples = 0
    all_preds, all_labels, all_probs = [], [], []
    
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            total_samples += inputs.size(0)
            
            probs = torch.softmax(outputs, dim=1)[:, 1]
            
            # Store predictions/labels for metrics
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate average validation loss
    val_loss /= total_samples
    val_losses.append(val_loss)
    
    # Calculate metrics
    accuracy = (np.array(all_preds) == np.array(all_labels)).mean()
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    roc_auc = auc(fpr, tpr)
    all_epoch_probs.append(all_probs)  # Store for later analysis
    
    if accuracy > best_accuracy:
        best_accuracy = accuracy
    
    print(f"Epoch {epoch+1}: "
          f"Val Loss: {val_loss:.4f} | "
          f"Accuracy: {accuracy:.2f} | "
          f"Precision: {precision:.2f} | "
          f"Recall: {recall:.2f} | "
          f"F1: {f1:.2f} | "
          f"AUC: {roc_auc:.2f}")
    
    # LR scheduler step
    scheduler.step(val_loss)
    
    current_lr = optimizer.param_groups[0]['lr']
    if current_lr < previous_lr:
        print(f"Epoch {epoch+1}: LR reduced to {current_lr}")
    previous_lr = current_lr
    
    # Save checkpoint
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss.item()
    }

    # Early stopping check
    #if val_loss < best_val_loss:
        # best_val_loss = val_loss
    if f1 > best_f1:
        best_f1 = f1
        patience_counter = 0
        torch.save(checkpoint, os.path.join(config.CHECKPOINT_PATH, f"best_model_{epoch+1}.pth"))
        torch.save(checkpoint, os.path.join(config.CHECKPOINT_PATH, f"best_model.pth"))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

Epoch 49: Val Loss: 0.2883 | Accuracy: 0.93 | Precision: 0.71 | Recall: 0.88 | F1: 0.78 | AUC: 0.97
Epoch 50: Val Loss: 0.3142 | Accuracy: 0.94 | Precision: 0.78 | Recall: 0.85 | F1: 0.81 | AUC: 0.97
Epoch 51: Val Loss: 0.3098 | Accuracy: 0.94 | Precision: 0.77 | Recall: 0.87 | F1: 0.81 | AUC: 0.97
Epoch 52: Val Loss: 0.3018 | Accuracy: 0.94 | Precision: 0.78 | Recall: 0.87 | F1: 0.82 | AUC: 0.97
Epoch 53: Val Loss: 0.3012 | Accuracy: 0.94 | Precision: 0.77 | Recall: 0.87 | F1: 0.82 | AUC: 0.97
Epoch 54: Val Loss: 0.3121 | Accuracy: 0.94 | Precision: 0.78 | Recall: 0.87 | F1: 0.82 | AUC: 0.97
Epoch 55: Val Loss: 0.3195 | Accuracy: 0.94 | Precision: 0.76 | Recall: 0.88 | F1: 0.82 | AUC: 0.97
Epoch 56: Val Loss: 0.3112 | Accuracy: 0.94 | Precision: 0.76 | Recall: 0.87 | F1: 0.81 | AUC: 0.97
Epoch 56: LR reduced to 1.5625e-06
Epoch 57: Val Loss: 0.3120 | Accuracy: 0.94 | Precision: 0.77 | Recall: 0.87 | F1: 0.82 | AUC: 0.97
Epoch 58: Val Loss: 0.3221 | Accuracy: 0.94 | Precision: 0.77 | R

KeyboardInterrupt: 

In [20]:
import matplotlib.pyplot as plt
from sklearn.calibration import calibration_curve
from sklearn.metrics import precision_recall_curve

# 1. Plot loss curves
plt.figure(figsize=(8, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.title('Training/Validation Loss')
plt.savefig('training_validation_loss.png')
plt.close()

# 2. Precision-Recall Curve (from last epoch)
plt.figure(figsize=(8, 6))
precision, recall, thresholds = precision_recall_curve(all_labels, all_preds)
pr_auc = auc(recall, precision)
f1s = 2 * (precision * recall) / (precision + recall + 1e-8)
best_idx = np.argmax(f1s)
best_threshold = thresholds[best_idx]
print(f"Best threshold: {best_threshold:.3f}, Best F1: {f1s[best_idx]:.3f}")
plt.figure()
plt.plot(recall, precision, color='blue', lw=2, label=f'PR curve (area = {pr_auc:0.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.legend()
plt.title('Precision-Recall Curve (last epoch)')
plt.savefig('precision_recall_curve.png')
plt.close()

# 3. Plot ROC Curve (from last epoch)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2, 
         label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (last epoch)')
plt.legend()
plt.savefig('roc_curve.png')
plt.close()

# 3. Confusion Matrix (from last epoch)
cm = confusion_matrix(all_labels, all_preds)
plt.figure(figsize=(6, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix (last epoch)')
plt.colorbar()
tick_marks = np.arange(2)
plt.xticks(tick_marks, ['Negative', 'Positive'])
plt.yticks(tick_marks, ['Negative', 'Positive'])
plt.ylabel('True label')
plt.xlabel('Predicted label')
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 ha="center", va="center",
                 color="white" if cm[i, j] > thresh else "black")
plt.savefig('confusion_matrix.png')
plt.close()

# 4. Probability Distribution (MSE equivalent)
plt.figure(figsize=(8, 6))
plt.hist(np.array(all_probs)[np.array(all_labels) == 0], 
         bins=30, alpha=0.5, label='Negative Class')
plt.hist(np.array(all_probs)[np.array(all_labels) == 1], 
         bins=30, alpha=0.5, label='Positive Class')
plt.xlabel('Predicted Probability')
plt.ylabel('Frequency')
plt.title('Predicted Probability Distribution')
plt.legend()
plt.savefig('probability_distribution.png')
plt.close()

# 5. Calibration
prob_true, prob_pred = calibration_curve(all_labels, all_probs, n_bins=10)
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.xlabel('Mean predicted probability')
plt.ylabel('Fraction of positives')
plt.title('Calibration curve')
plt.savefig('calibration_curve.png')
plt.close()

# For actual calibration, fit a calibrator:
from sklearn.isotonic import IsotonicRegression

iso_reg = IsotonicRegression(out_of_bounds='clip')
calibrated_probs = iso_reg.fit_transform(all_probs, all_labels)

Best threshold: 1.000, Best F1: 0.829


### üìä **Visual Summary**
| Metric      | Focus                  | Best Value | Worst Value | Ideal When...                |
|-------------|------------------------|------------|-------------|------------------------------|
| **Val Loss**| Prediction confidence  | 0.0        | ‚àû           | Monitoring training progress |
| **Accuracy**| Overall correctness    | 1.0        | 0.0         | Balanced classes             |
| **Precision**| False positives       | 1.0        | 0.0         | Costly false alarms          |
| **Recall**  | False negatives        | 1.0        | 0.0         | Missing positives is bad     |
| **F1**      | Precision-Recall tradeoff | 1.0     | 0.0         | Class imbalance exists       |

### üîç **Interpretation**
Validation shows:
```
Epoch 28: Val Loss: 0.3879 | Accuracy: 0.84 | Precision: 0.83 | Recall: 0.82 | F1: 0.83
Epoch 5: Val Loss: 0.3822 | Accuracy: 0.84 | Precision: 0.87 | Recall: 0.79 | F1: 0.83
Epoch 6: Val Loss: 0.3011 | Accuracy: 0.86 | Precision: 0.89 | Recall: 0.83 | F1: 0.86
Epoch 38: Val Loss: 0.2807 | Accuracy: 0.91 | Precision: 0.92 | Recall: 0.90 | F1: 0.91
```
This means:
1. Model is confident in correct predictions (low loss=39%)
2. Shows strong overall performance (very good accuracy=84%)
3. Is correct 83% of time, but has false positives (precision=83% means 17% of "positive" predictions are wrong)
4. Catches most positives (high recall=82% means misses only 18% of true positives)
5. F1=0.83 shows good balance

**Actionable insight**: To improve precision, increase classification threshold.

### üí° **When to Prioritize Which Metric**
| Scenario                | Priority Metric  | Why                              |
|-------------------------|------------------|----------------------------------|
| Balanced classes        | Accuracy         | Simple overall measure           |
| Costly false positives  | Precision        | Avoid false alarms               |
| Costly false negatives  | Recall           | Don't miss critical cases        |
| Class imbalance         | F1 Score         | Balances both error types        |
| Training decisions      | Val Loss         | Most sensitive to model changes  |

In [21]:
import torch
from PIL import Image

# Load checkpoint
checkpoint = torch.load(os.path.join(config.CHECKPOINT_PATH, f"best_model.pth"))
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(config.DEVICE)
model.eval()  # Set to evaluation mode

def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    return val_transform(image).unsqueeze(0)  # Add batch dimension

In [22]:
test_image_files = [
    f for f in os.listdir(config.TEST_DIR)
    if f.lower().endswith(config.IMAGE_FILE_EXTENSIONS)
]

test_image_files.sort()

correct_count = 0
faulty_count = 0

for i, image_file in enumerate(test_image_files):
    image_path = os.path.join(config.TEST_DIR, image_file)
    image_tensor = preprocess_image(image_path).to(config.DEVICE)

    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, 1).item()

    class_names = ['bad', 'good']  # 0=negative, 1=positive
    print(f"Image {i+1}: {image_file}")
    print(f"Predicted: {class_names[predicted_class]}")
    print(f"Confidence: {probabilities[0][predicted_class].item():.2%}")
    
    # Condition checks
    if (class_names[predicted_class] == 'good' and 'good' in image_file) or \
       (class_names[predicted_class] == 'bad' and 'bad' in image_file):
        print("classification correct\n")
        correct_count += 1
    else:
        print("classification faulty\n")
        faulty_count += 1

not_rated = 0

# Final summary
print(f"Total correct classifications: {correct_count} out of {correct_count + faulty_count - not_rated} ({correct_count / (correct_count + faulty_count - not_rated) * 100:.2f}%)")
print(f"Total faulty classifications: {faulty_count - not_rated} out of {correct_count + faulty_count - not_rated} ({(faulty_count - not_rated) / (correct_count + faulty_count - not_rated) * 100:.2f}%)")

Image 1: 10_good_o.jpg
Predicted: good
Confidence: 99.62%
classification correct

Image 2: 11_good_o.jpg
Predicted: good
Confidence: 99.12%
classification correct

Image 3: 12_good_o.jpg
Predicted: good
Confidence: 80.69%
classification correct

Image 4: 13_bad_o.jpg
Predicted: bad
Confidence: 100.00%
classification correct

Image 5: 14_bad_o.jpg
Predicted: bad
Confidence: 99.98%
classification correct

Image 6: 15_good_o.jpg
Predicted: bad
Confidence: 99.80%
classification faulty

Image 7: 16_good_o.jpg
Predicted: bad
Confidence: 99.51%
classification faulty

Image 8: 17_bad_o.jpg
Predicted: good
Confidence: 94.74%
classification faulty

Image 9: 18_bad_o.jpg
Predicted: bad
Confidence: 99.92%
classification correct

Image 10: 19_good_o.jpg
Predicted: good
Confidence: 99.88%
classification correct

Image 11: 1_good_o.jpg
Predicted: good
Confidence: 67.17%
classification correct

Image 12: 1_good_s.jpg
Predicted: good
Confidence: 53.38%
classification correct

Image 13: 20_good_o.jpg
Pr