# Fine-Tuning ResNet-50 for Meeting Context Classification

This notebook guides you through the process of fine-tuning a pre-trained ResNet-50 model to distinguish between:
1.  **Slides/Screen Share** (High informational content)
2.  **People/Camera** (Social/Interactional content)

**Environment**: Kaggle (LFW dataset present in `../input/lfw-dataset/`)

In [None]:
import copy
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms

# Pick the compute device: CUDA when PyTorch reports it as available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 1. Setup LFW Dataset (People)
Using the pre-loaded LFW dataset from Kaggle input.

In [None]:
# Locations inside the Kaggle input mount: image folders and the split CSVs.
lfw_root = "../input/lfw-dataset/lfw-deepfunneled/lfw-deepfunneled/"
csv_root = "../input/lfw-dataset/"

# Read the train / test people lists (one DataFrame per split).
peopleDevTrain, peopleDevTest = (
    pd.read_csv(os.path.join(csv_root, csv_name))
    for csv_name in ("peopleDevTrain.csv", "peopleDevTest.csv")
)

# Report how many rows (people) each split contains.
print(f"Train People: {len(peopleDevTrain)}")
print(f"Test People: {len(peopleDevTest)}")

# Helper to get all image paths for a person
def get_person_image_paths(person_name, image_count, root=None):
    """Build the expected image paths for one person.

    Paths follow the ``<root>/<name>/<name>_0001.jpg`` naming scheme and are
    generated from the count alone; nothing is checked against the file system.

    Args:
        person_name: Folder name and file-name prefix of the person.
        image_count: Number of images to enumerate. Coerced with ``int()`` so
            integer-like values read from a DataFrame are accepted.
        root: Directory holding the per-person folders. Defaults to the
            module-level ``lfw_root`` when omitted.

    Returns:
        A list of ``image_count`` paths ordered by image number (1-based).
    """
    if root is None:
        root = lfw_root
    person_dir = os.path.join(root, person_name)
    # Filename format: name_0001.jpg
    return [
        os.path.join(person_dir, f"{person_name}_{i:04d}.jpg")
        for i in range(1, int(image_count) + 1)
    ]

# Expand every (name, image count) row of each split into concrete file paths.
train_person_paths = [
    path
    for name, count in zip(peopleDevTrain['name'], peopleDevTrain['images'])
    for path in get_person_image_paths(name, count)
]

test_person_paths = [
    path
    for name, count in zip(peopleDevTest['name'], peopleDevTest['images'])
    for path in get_person_image_paths(name, count)
]

# Summary of how many person images each split contributes.
print(f"Total Train Person Images: {len(train_person_paths)}")
print(f"Total Test Person Images: {len(test_person_paths)}")

## 2. Setup SlideAudit Dataset (Slides)
Cloning the SlideAudit dataset from GitHub.

In [None]:
import subprocess
import shutil
from sklearn.model_selection import train_test_split

slide_root = "./slide_dataset" # Writable directory
slide_images_dir = os.path.join(slide_root, "SlideAudit", "data", "images")
slide_repo = "https://github.com/zhuohaouw/SlideAudit.git"

# Clone only when the image folder is missing. A leftover (possibly partial)
# checkout is removed first so `git clone` gets a fresh target directory;
# check=True makes a failed clone raise instead of continuing silently.
if not os.path.exists(slide_images_dir):
    print("Cloning SlideAudit dataset...")
    if os.path.exists(slide_root):
        shutil.rmtree(slide_root)
    os.makedirs(slide_root, exist_ok=True)
    subprocess.run(["git", "clone", slide_repo, os.path.join(slide_root, "SlideAudit")], check=True)
    print("SlideAudit setup complete.")
else:
    print("SlideAudit dataset already exists.")

# Collect all slide image paths.
# os.listdir() returns entries in arbitrary order, so the listing is sorted:
# otherwise the seeded split below could pick different files on each machine.
all_slide_paths = [
    os.path.join(slide_images_dir, img)
    for img in sorted(os.listdir(slide_images_dir))
    if img.lower().endswith(('.jpg', '.jpeg', '.png'))
]

# Split Slides into Train/Test (70/30 split)
train_slide_paths, test_slide_paths = train_test_split(all_slide_paths, test_size=0.3, random_state=42)

print(f"Total Train Slide Images: {len(train_slide_paths)}")
print(f"Total Test Slide Images: {len(test_slide_paths)}")

## 3. Handling "Picture-in-Picture" (Mixed Contexts)

**Problem**: Real meetings often show a slide WITH a small webcam overlay (Picture-in-Picture). We want the model to classify these as **Slides** because the slide is the dominant content.

**Solution**: We will synthetically generate "PiP" images by randomly pasting small person images onto slide images and adding them to the training set labeled as 'Slide'.

In [None]:
def generate_pip_samples(slide_paths, person_paths, output_dir, num_samples=500):
    """Create synthetic Picture-in-Picture images and save them as JPEG files.

    For each sample a random slide and a random person image are chosen, the
    person image is shrunk to 20-30% of the slide width (aspect ratio kept)
    and pasted 10 px away from a randomly chosen corner of the slide.

    Args:
        slide_paths: Sequence of slide image paths used as backgrounds.
        person_paths: Sequence of person image paths used as overlays.
        output_dir: Directory for the results; created if it does not exist.
        num_samples: Number of composites to attempt.

    Returns:
        Paths of the files that were written. A sample whose processing raises
        is reported on stdout and skipped, so the list can be shorter than
        ``num_samples``.
    """
    os.makedirs(output_dir, exist_ok=True)
    saved = []

    print(f"Generating {num_samples} synthetic Picture-in-Picture samples...")

    margin = 10  # gap in pixels between the overlay and the slide border

    for index in range(num_samples):
        # Source images are drawn before the try block, exactly one slide
        # and then one person per sample.
        background_path = random.choice(slide_paths)
        overlay_path = random.choice(person_paths)

        try:
            canvas = Image.open(background_path).convert('RGB')
            overlay = Image.open(overlay_path).convert('RGB')

            # Overlay width is a random 20-30% of the slide width; the
            # height follows from the overlay's own aspect ratio.
            overlay_w = int(canvas.width * random.uniform(0.2, 0.3))
            overlay_h = int(overlay_w * (overlay.height / overlay.width))
            overlay = overlay.resize((overlay_w, overlay_h))

            # Candidate anchors in the order TR, BR, TL, BL.
            right = canvas.width - overlay_w - margin
            bottom = canvas.height - overlay_h - margin
            anchor = random.choice([
                (right, margin),
                (right, bottom),
                (margin, margin),
                (margin, bottom),
            ])
            canvas.paste(overlay, anchor)

            target = os.path.join(output_dir, f"pip_{index}.jpg")
            canvas.save(target)
            saved.append(target)

        except Exception as e:
            print(f"Error generating PiP sample: {e}")

    return saved

# Synthetic PiP composites are built from the training splits only.
pip_train_dir = "./pip_dataset/train"
train_pip_paths = generate_pip_samples(
    slide_paths=train_slide_paths,
    person_paths=train_person_paths,
    output_dir=pip_train_dir,
    num_samples=1000,
)

# The paths stay in their own list; they receive the Slide label (1) once
# they are handed to the dataset class.
print(f"Added {len(train_pip_paths)} PiP images to training set.")

## 4. Custom Dataset Class
The dataset labels person images as 0 and both plain slide images and the synthetic PiP images as 1.

In [None]:
class MeetingContextDataset(Dataset):
    """Two-class image dataset: label 0 for person images, label 1 for slides.

    Picture-in-Picture composites are treated as slides (label 1). Samples are
    stored in ``self.data`` as ``(path, label)`` tuples in the order persons,
    slides, PiP images.

    Args:
        person_paths: Image paths labelled 0.
        slide_paths: Image paths labelled 1.
        pip_paths: Optional extra image paths labelled 1. ``None`` (the
            default) means no PiP images.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, person_paths, slide_paths, pip_paths=None, transform=None):
        # ``None`` replaces a literal ``[]`` default: a list default is created
        # once and would be shared by every instance built without pip_paths.
        if pip_paths is None:
            pip_paths = []

        self.person_paths = person_paths
        self.slide_paths = slide_paths
        self.pip_paths = pip_paths
        self.transform = transform

        # 0 = Person, 1 = Slide (PiP counts as Slide)
        self.data = (
            [(path, 0) for path in person_paths]
            + [(path, 1) for path in slide_paths]
            + [(path, 1) for path in pip_paths]
        )

    def __len__(self):
        """Return the total number of (path, label) samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened, converted to RGB and passed through
        ``self.transform`` when one is set. If anything in that process
        raises, the error is printed and an all-zero ``3x224x224`` tensor is
        returned with the sample's original label, so one unreadable file
        does not abort iteration.
        """
        img_path, label = self.data[idx]
        try:
            image = Image.open(img_path).convert('RGB')
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            # Deliberate best-effort fallback: keep the batch shape intact.
            print(f"Error loading {img_path}: {e}")
            return torch.zeros((3, 224, 224)), label

# Transforms
# Per-channel mean / std handed to Normalize in both pipelines.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training: random crop + horizontal flip as augmentation.
train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
# Validation: deterministic resize followed by a centre crop.
val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
]
data_transforms = {
    'train': transforms.Compose(train_steps),
    'val': transforms.Compose(val_steps),
}

# Create Datasets (PiP images are part of the training set only)
train_dataset = MeetingContextDataset(
    train_person_paths,
    train_slide_paths,
    train_pip_paths,
    data_transforms['train'],
)
val_dataset = MeetingContextDataset(
    test_person_paths,
    test_slide_paths,
    [],
    data_transforms['val'],
)

# Only the training loader shuffles; both use batches of 32 and 2 workers.
dataloaders = {
    split: DataLoader(split_dataset, batch_size=32, shuffle=(split == 'train'), num_workers=2)
    for split, split_dataset in (('train', train_dataset), ('val', val_dataset))
}

dataset_sizes = {split: len(loader.dataset) for split, loader in dataloaders.items()}
print(f"Dataset Sizes: {dataset_sizes}")

## 5. Train ResNet-50 with Class Balancing
We calculate class weights to handle the imbalance between Person (many) and Slide (fewer) images.

In [None]:
# Per-class sample counts of the training split (PiP images count as slides).
count_person = len(train_person_paths)
count_slide = len(train_slide_paths) + len(train_pip_paths)
total_samples = count_person + count_slide

print(f"Training Samples - Person: {count_person}, Slide: {count_slide}")

# Inverse-frequency weights, rescaled so that the two weights average to 1.
inv_person = 1.0 / count_person
inv_slide = 1.0 / count_slide
norm_factor = (inv_person + inv_slide) / 2.0
weight_person = inv_person / norm_factor
weight_slide = inv_slide / norm_factor

class_weights = torch.tensor([weight_person, weight_slide], device=device)
print(f"Class Weights: Person={weight_person:.4f}, Slide={weight_slide:.4f}")

# Pre-trained ResNet-50 backbone.
model = models.resnet50(pretrained=True)

# Freeze every existing parameter; the replacement head created below is new
# and therefore trainable.
model.requires_grad_(False)

# Swap the final fully connected layer for a 2-way classifier (Person, Slide).
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
model = model.to(device)

# Class-weighted cross entropy; the optimizer only receives the new head.
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

# Training Loop
def train_model(model, criterion, optimizer, num_epochs=5):
    """Train ``model`` and return it loaded with its best-validation weights.

    Every epoch runs a 'train' pass followed by a 'val' pass over the
    module-level ``dataloaders``. Batches are moved to the module-level
    ``device`` and per-epoch metrics are averaged with ``dataset_sizes``.
    The state dict with the highest validation accuracy is remembered and
    restored before returning.

    Args:
        model: Network to train (modified in place).
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Number of train+val epochs; ``0`` returns the model with
            its initial weights.

    Returns:
        The same ``model`` object, holding the best validation weights.
    """
    since = time.time()
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0  # plain int, so an empty loader still works

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Gradients are tracked only during the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight it by batch size so the epoch
                # average is taken over samples, not batches.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deepcopy: state_dict() tensors alias the live parameters.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model

# Fine-tune for 5 epochs; the returned model carries the best validation weights.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, num_epochs=5)

# Persist only the weights (state dict) to the working directory.
save_path = "resnet50_meeting_context.pth"
final_state = model.state_dict()
torch.save(final_state, save_path)
print(f"Model saved to {save_path}")

## 6. Evaluation Metrics
We will evaluate the model on the validation set to calculate:
- Confusion Matrix (TP, FP, TN, FN)
- Precision, Recall, F1-Score

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

def evaluate_model(model, dataloader):
    """Print and plot classification metrics for ``model`` on ``dataloader``.

    Runs inference without gradients on the module-level ``device``, then
    prints the confusion-matrix counts (Slide = positive class), shows the
    matrix as a heatmap and prints a precision/recall/F1 report.

    Args:
        model: Classifier producing one score per class (0 = Person, 1 = Slide).
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        None. All results are printed or plotted.
    """
    model.eval()
    all_preds = []
    all_labels = []

    print("Running evaluation on validation set...")
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Confusion Matrix
    # Labels: 0 = Person, 1 = Slide
    # labels=[0, 1] pins the matrix to 2x2 even when only one class shows up
    # in the data; otherwise the 4-way unpack below would fail.
    class_ids = [0, 1]
    class_names = ['Person', 'Slide']
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    tn, fp, fn, tp = cm.ravel()

    print("\n--- Confusion Matrix ---")
    print(f"True Negatives (Person correctly identified): {tn}")
    print(f"False Positives (Person identified as Slide): {fp}")
    print(f"False Negatives (Slide identified as Person): {fn}")
    print(f"True Positives (Slide correctly identified): {tp}")

    # Plotting
    plt.figure(figsize=(6,5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    plt.show()

    # Classification Report
    # Same explicit label list so target_names always lines up with the classes.
    print("\n--- Classification Report ---")
    print(classification_report(all_labels, all_preds, labels=class_ids, target_names=class_names))

# Compute and display the metrics on the validation loader.
evaluate_model(model=model, dataloader=dataloaders['val'])

## 7. Visual Inference (Sanity Check)
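Let's see the model in action by visualizing predictions on the first few validation images. Note that the validation loader is not shuffled and person images come first in the dataset, so these samples will typically all be Person images.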

In [None]:
def visualize_predictions(model, dataloader, num_images=6):
    """Plot up to ``num_images`` samples with their true and predicted class.

    Images are taken in the order the dataloader yields them, un-normalized
    for display and arranged in a 3-column grid. The title is green when the
    prediction matches the label and red otherwise. Inference runs on the
    module-level ``device``.

    Args:
        model: Classifier producing one score per class (0 = Person, 1 = Slide).
        dataloader: Iterable yielding ``(inputs, labels)`` batches of
            normalized image tensors.
        num_images: Maximum number of images to draw; values <= 0 draw nothing.

    Returns:
        None.
    """
    if num_images <= 0:
        return

    model.eval()
    class_names = ['Person', 'Slide']

    # Normalization constants, built once rather than once per image.
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # Ceiling division: exactly as many rows as needed for num_images.
    n_cols = 3
    n_rows = -(-num_images // n_cols)

    images_so_far = 0
    plt.figure(figsize=(15, 10))

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            batch_images = inputs.cpu().numpy()
            for j in range(inputs.size(0)):
                images_so_far += 1
                ax = plt.subplot(n_rows, n_cols, images_so_far)
                ax.axis('off')

                # Un-normalize image for display (CHW -> HWC, then x*std+mean)
                img = batch_images[j].transpose((1, 2, 0))
                img = np.clip(std * img + mean, 0, 1)

                true_label = class_names[int(labels[j])]
                pred_label = class_names[int(preds[j])]

                color = 'green' if true_label == pred_label else 'red'
                ax.set_title(f'True: {true_label} | Pred: {pred_label}', color=color)
                ax.imshow(img)

                if images_so_far == num_images:
                    plt.show()
                    return

    # Fewer than num_images samples were available: show what was drawn.
    plt.show()

# Draw predictions for the first images yielded by the validation loader.
visualize_predictions(model=model, dataloader=dataloaders['val'])

## 8. Download the Model

**How to download the `.pth` file from Kaggle:**

1.  Look at the **Output** section on the right sidebar of the Kaggle notebook editor.
2.  You should see `resnet50_meeting_context.pth` listed under `/kaggle/working`.
3.  Click the **three dots (...)** next to the file name.
4.  Select **Download**.

Once downloaded, place this file in your local project directory (e.g., inside `ml/models/`) so your RoME app can load it.