Notebook for surgical instrument classification with automated hyperparameter sweeps using sweep.yml.

In [None]:
# Install the third-party packages the notebook depends on.
!pip install wandb torch torchvision matplotlib scikit-learn pyyaml

In [None]:
!curl -L -o surgical-instrument-classification.zip\
  https://www.kaggle.com/api/v1/datasets/download/debeshjha1/surgical-instrument-classification
!unzip surgical-instrument-classification.zip

In [None]:
import wandb
import yaml
# Log in to Weights & Biases; the later cells create a sweep and runs through wandb.
wandb.login()

In [None]:
# Load the sweep configuration from YAML.
# The encoding is passed explicitly so the file is decoded the same way on
# every platform rather than with the locale-dependent default.
with open('src/sweep.yml', 'r', encoding='utf-8') as file:
    sweep_config = yaml.safe_load(file)

# Echo the parsed configuration in block style as a quick visual check.
print("Sweep configuration:")
print(yaml.dump(sweep_config, default_flow_style=False))

In [None]:
# Register a sweep in the "surginet" project from the YAML-derived config,
# then run it with a local agent.
from src.train import train_model

sweep_id = wandb.sweep(sweep=sweep_config, project="surginet")
print(f"Sweep ID: {sweep_id}")

# The agent invokes train_model for each trial and stops after 20 runs.
wandb.agent(sweep_id, function=train_model, count=20)

In [None]:
# Alternative to the sweep: run a single experiment with a fixed configuration.
# train_model is imported in this cell too, so it still works when the sweep
# cell was skipped (re-importing an already imported name is harmless).
# sweep_config must already have been loaded from src/sweep.yml.
from src.train import train_model

# The model and optimisation settings are fixed by hand; epochs, val_split and
# patience are read from the 'value' entries of the loaded sweep configuration.
config = {
    'model_name': 'resnet18',
    'batch_size': 32,
    'learning_rate': 0.001,
    'weight_decay': 1e-4,
    'dropout': 0.5,
    'pretrained': True,
    'epochs': sweep_config['parameters']['epochs']['value'],
    'val_split': sweep_config['parameters']['val_split']['value'],
    'patience': sweep_config['parameters']['patience']['value']
}

# NOTE(review): the result is formatted as a float below — presumably
# train_model returns the final accuracy; confirm in src/train.
final_acc = train_model(config)
print(f"Final accuracy: {final_acc:.4f}")

In [None]:
# Quick evaluation: restore the saved checkpoint and print a per-class report
# for the validation loader.
import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix

from src.data import get_data_loaders
from src.models import get_model

# Rebuild the network and restore the saved weights on CPU.
# NOTE(review): 'resnet18' and 10 are hard-coded — presumably the architecture
# and class count used when best_model.pth was written; confirm they match.
model = get_model('resnet18', 10, pretrained=False)
state_dict = torch.load('best_model.pth', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

# Only the validation loader and the class names are used here; the first
# returned element is discarded.
_, val_loader, class_names = get_data_loaders()

# Accumulate predicted and true labels batch by batch with gradients disabled.
all_preds, all_labels = [], []
with torch.no_grad():
    for inputs, labels in val_loader:
        all_preds.extend(model(inputs).argmax(1).numpy())
        all_labels.extend(labels.numpy())

report = classification_report(all_labels, all_preds, target_names=class_names)
print(report)