# Mitosis Detector Notebook

Import libraries and mount Google Drive

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms, models
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
import seaborn as sns
import time
from google.colab import drive

drive.mount('/content/drive')

# Prefer the GPU when the runtime provides one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

Mounted at /content/drive
Using device: cuda


In [None]:
!unzip -q -n "/content/drive/MyDrive/mitosis_detector_BTTAI/dataset_root.zip" -d "/content/dataset_root"

Configure paths and parameters

In [None]:
# Persistent project folder on Google Drive (checkpoints and figures are written here).
BASE_PATH = '/content/drive/MyDrive/mitosis_detector_BTTAI'

# Local copy of the dataset, extracted from the zip by the unzip cell.
DATA_DIR = '/content/dataset_root/dataset_root'
MODEL_SAVE_PATH = os.path.join(BASE_PATH, 'best_mitosis_model.pth')
PLOT_SAVE_PATH = os.path.join(BASE_PATH, 'training_curves.png')

BATCH_SIZE = 32          # images per DataLoader batch
LEARNING_RATE = 0.0001   # initial Adam learning rate (ReduceLROnPlateau lowers it)
WEIGHT_DECAY = 1e-5      # Adam L2 penalty
NUM_EPOCHS = 50          # upper bound; early stopping may end training sooner
RANDOM_SEED = 123
PATIENCE = 10            # epochs without val-loss improvement before early stopping

# Seed every RNG library imported above. The trailing ';' hides the Generator repr.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED);

<torch._C.Generator at 0x7c281772cdb0>

In [None]:
IMG_SIZE = 224
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Data augmentation for the training set
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    normalize,
])

# Only resizing and normalization for validation and test sets
val_test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    normalize,
])


class TransformedSubset(torch.utils.data.Dataset):
    """View of ``base`` restricted to ``indices`` that applies its own ``transform``.

    ``random_split`` returns ``Subset`` objects that all share ONE underlying dataset,
    so assigning ``subset.dataset.transform`` for one split silently changes it for
    every split (the last assignment wins). Keeping the transform on the view gives
    each split its own preprocessing while loading the image folder only once.
    """

    def __init__(self, base, indices, transform):
        self.base = base
        self.indices = list(indices)  # same attribute name as torch.utils.data.Subset
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.base[self.indices[idx]]
        return self.transform(image), label


# Load the full dataset once (no transform: TransformedSubset applies one per split)
print(f"Loading data from: {DATA_DIR}")
full_dataset = datasets.ImageFolder(DATA_DIR)
class_names = full_dataset.classes
print(f"Found {len(full_dataset)} images belonging to {len(class_names)} classes: {class_names}")

# Define split sizes (70/15/15; the test split absorbs rounding leftovers)
train_size = int(0.7 * len(full_dataset))
val_size = int(0.15 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

# A dedicated seeded generator makes the split identical every time this cell is run,
# instead of depending on how much of the global RNG earlier cells consumed.
split_generator = torch.Generator().manual_seed(RANDOM_SEED)
train_split, val_split, test_split = random_split(
    full_dataset, [train_size, val_size, test_size], generator=split_generator
)

# Apply the transformations per split (augmentation only for training)
train_dataset = TransformedSubset(full_dataset, train_split.indices, train_transforms)
val_dataset = TransformedSubset(full_dataset, val_split.indices, val_test_transforms)
test_dataset = TransformedSubset(full_dataset, test_split.indices, val_test_transforms)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Loading data from: /content/dataset_root/dataset_root
Found 26350 images belonging to 2 classes: ['1', '2']
Training samples: 18445
Validation samples: 3952
Test samples: 3953


Model definition

In [None]:
def get_model(num_classes=2):
    """Build a COCO-pretrained DeepLabV3-ResNet101 repurposed as an image classifier.

    Every pretrained parameter is frozen except the backbone's ``layer4``. The
    segmentation head is replaced by a small trainable head that reduces the
    2048-channel backbone map to ``num_classes`` channels, global-average-pools it,
    and flattens it to ``[B, num_classes]`` logits.

    Args:
        num_classes: number of output classes.

    Returns:
        The modified torchvision DeepLabV3 model (not yet moved to a device).
    """
    net = models.segmentation.deeplabv3_resnet101(
        weights=models.segmentation.DeepLabV3_ResNet101_Weights.DEFAULT
    )

    # Freeze everything, then re-enable gradients only for the last ResNet stage.
    net.requires_grad_(False)
    net.backbone.layer4.requires_grad_(True)

    # Layer order is kept so checkpoint keys stay classifier.0 ... classifier.6.
    head = nn.Sequential(
        nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv2d(256, num_classes, kernel_size=1, stride=1),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(start_dim=1),
    )
    head.requires_grad_(True)
    net.classifier = head

    return net

# Build the classifier (one output per class folder) and place it on the selected device.
model = get_model(num_classes=len(class_names)).to(device)

print("Model architecture loaded and modified for transfer learning.")

Downloading: "https://download.pytorch.org/models/deeplabv3_resnet101_coco-586e9e4e.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet101_coco-586e9e4e.pth


100%|██████████| 233M/233M [00:05<00:00, 42.9MB/s]


Model architecture loaded and modified for transfer learning.


In [None]:
def extract_and_save_features(data_loader, model, save_path, device, chunk_batches=20):
    """Run ``model.backbone`` over ``data_loader`` and cache globally pooled features on disk.

    Features are written in chunks so host RAM stays bounded. After every
    ``chunk_batches`` batches, and once more for any remainder, a ``(features, labels)``
    tuple is saved to ``f"{save_path}_part{k}.pt"``. ``features`` is ``[n, C]``: the
    backbone's ``'out'`` map average-pooled over its spatial dimensions.

    Args:
        data_loader: iterable of ``(inputs, labels)`` batches.
        model: module exposing ``backbone(inputs)`` that returns a dict with an ``'out'``
            tensor shaped ``[B, C, H, W]``.
        save_path: path prefix for the chunk files; may be a bare file name with no
            directory component.
        device: device the inputs are moved to before the forward pass.
        chunk_batches: number of batches buffered per chunk file (must be >= 1).

    Returns:
        List of the chunk file paths written, in chunk order.

    Raises:
        ValueError: if ``chunk_batches`` is smaller than 1.
    """
    import glob  # local import keeps the notebook's import cell unchanged

    if chunk_batches < 1:
        raise ValueError(f"chunk_batches must be >= 1, got {chunk_batches}")

    model.eval()
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        # os.makedirs("") raises FileNotFoundError, so only create a directory if one is given.
        os.makedirs(parent_dir, exist_ok=True)

    # Drop chunks from an earlier run so a later glob over the prefix sees only this run's output.
    for stale_path in glob.glob(glob.escape(save_path) + "_part[0-9]*.pt"):
        os.remove(stale_path)

    features_list = []
    labels_list = []
    written_paths = []

    def _flush_chunk():
        """Write the buffered batches as the next chunk file and clear the buffers."""
        chunk_path = f"{save_path}_part{len(written_paths)}.pt"
        torch.save((torch.cat(features_list), torch.cat(labels_list)), chunk_path)
        written_paths.append(chunk_path)
        features_list.clear()
        labels_list.clear()
        torch.cuda.empty_cache()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(data_loader):
            feats = model.backbone(inputs.to(device))['out']           # [B, C, H, W]
            feats = torch.nn.functional.adaptive_avg_pool2d(feats, (1, 1))
            features_list.append(feats.flatten(start_dim=1).cpu())      # [B, C]
            labels_list.append(labels)

            # Every `chunk_batches` batches -> write to disk to free RAM
            if (i + 1) % chunk_batches == 0:
                _flush_chunk()

        # Save remaining data
        if features_list:
            _flush_chunk()

    print(f"Saved all features in chunks to {save_path}_part*.pt")
    return written_paths


Extract pooled backbone features for the training and validation splits and cache them on disk in chunks.


In [None]:
# Cache pooled backbone features for the train split first, then the validation split.
for split_name, split_loader in (("train", train_loader), ("val", val_loader)):
    extract_and_save_features(split_loader, model, f"{split_name}_features", device)

In [None]:
import glob
import re


def _load_feature_chunks(prefix):
    """Load and concatenate the ``{prefix}_part<k>.pt`` chunks written by ``extract_and_save_features``.

    Chunks are ordered by their numeric index ``k``; plain string sorting would put
    ``part10`` before ``part2``. Each chunk is a ``(features, labels)`` tuple, so
    feature rows stay aligned with their labels.

    Raises:
        FileNotFoundError: if no chunk file exists for ``prefix``.
    """
    indexed_paths = []
    for path in glob.glob(glob.escape(prefix) + "_part*.pt"):
        match = re.fullmatch(re.escape(prefix) + r"_part(\d+)\.pt", path)
        if match:
            indexed_paths.append((int(match.group(1)), path))
    if not indexed_paths:
        raise FileNotFoundError(
            f"No feature chunks found for prefix {prefix!r}; run the feature-extraction cell first."
        )

    feature_parts, label_parts = [], []
    for _, path in sorted(indexed_paths):
        chunk_features, chunk_labels = torch.load(path)
        feature_parts.append(chunk_features)
        label_parts.append(chunk_labels)
    return torch.cat(feature_parts), torch.cat(label_parts)


train_features, train_labels = _load_feature_chunks("train_features")
val_features, val_labels = _load_feature_chunks("val_features")

# Wrap them in TensorDatasets
train_feat_dataset = TensorDataset(train_features, train_labels)
val_feat_dataset = TensorDataset(val_features, val_labels)

# These loaders yield pooled [N, C] feature rows, not images. They get their own names so
# they do not replace train_loader/val_loader, which the training loop feeds through model.backbone.
train_feat_loader = DataLoader(train_feat_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_feat_loader = DataLoader(val_feat_dataset, batch_size=BATCH_SIZE, shuffle=False)

Training and Validation Cycle

In [None]:
# Import the scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Loss function
criterion = nn.CrossEntropyLoss()

# We pass only the parameters that we want to train (unfrozen)
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                       lr=LEARNING_RATE,
                       weight_decay=WEIGHT_DECAY)

# This will reduce the LR when validation loss stops improving
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

# Print a progress line every N training batches (per-batch prints flood the output)
LOG_EVERY_N_BATCHES = 50

# Variables for training history and Early Stopping
history = {'train_loss': [], 'val_loss': [], 'val_acc': []}
best_val_loss = float('inf')  # We want to minimize validation loss
epochs_no_improve = 0

# Per-epoch validation accuracy as plain Python floats, so it can be plotted directly.
acc = []


def _forward_logits(net, inputs):
    """Run an image batch through ``net.backbone`` and ``net.classifier``; returns class logits.

    Raises:
        ValueError: if ``inputs`` is not a 4-D ``[B, C, H, W]`` image batch, e.g. when a
            cached-feature loader was bound to ``train_loader``/``val_loader``.
    """
    if inputs.dim() != 4:
        raise ValueError(
            f"Expected image batches shaped [B, C, H, W], got {tuple(inputs.shape)}; "
            "train_loader/val_loader must yield images, not cached features."
        )
    return net.classifier(net.backbone(inputs)['out'])


print("Starting model training...")
start_time = time.time()

for epoch in range(NUM_EPOCHS):
    # Training
    model.train()
    running_train_loss = 0.0
    n_train_samples = 0
    print(f"Starting Epoch {epoch+1}")
    for batch_idx, (inputs, labels) in enumerate(train_loader, start=1):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = _forward_logits(model, inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        running_train_loss += loss.item() * batch_size
        n_train_samples += batch_size
        if batch_idx % LOG_EVERY_N_BATCHES == 0:
            print(f"  Batch {batch_idx}/{len(train_loader)} | "
                  f"running train loss: {running_train_loss / n_train_samples:.4f}")

    # Average over the samples actually seen, not an assumed dataset length
    epoch_train_loss = running_train_loss / max(n_train_samples, 1)
    history['train_loss'].append(epoch_train_loss)

    # Validation
    model.eval()
    running_val_loss = 0.0
    correct_predictions = 0
    n_val_samples = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = _forward_logits(model, inputs)
            loss = criterion(outputs, labels)
            running_val_loss += loss.item() * inputs.size(0)
            correct_predictions += (outputs.argmax(dim=1) == labels).sum().item()
            n_val_samples += inputs.size(0)

    epoch_val_loss = running_val_loss / max(n_val_samples, 1)
    epoch_val_acc = correct_predictions / max(n_val_samples, 1)
    history['val_loss'].append(epoch_val_loss)
    history['val_acc'].append(epoch_val_acc)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {epoch_train_loss:.4f} | Val Loss: {epoch_val_loss:.4f} | Val Acc: {epoch_val_acc:.4f}")
    acc.append(epoch_val_acc)

    scheduler.step(epoch_val_loss)

    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        epochs_no_improve = 0
        # Save a full checkpoint
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': best_val_loss,
        }, MODEL_SAVE_PATH)
        print(f"New best model saved to {MODEL_SAVE_PATH} with Val Loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"No improvement in Val Loss for {epochs_no_improve} epochs.")

    if epochs_no_improve >= PATIENCE:
        print(f"Early stopping triggered after {epoch+1} epochs.")
        break


end_time = time.time()
training_time = (end_time - start_time) / 60
print(f"Training finished in {training_time:.2f} minutes.")

Starting model training...
Starting Epoch 1
Batch 1: Loading data...
Batch 2: Loading data...
Batch 3: Loading data...
Batch 4: Loading data...
Batch 5: Loading data...
Batch 6: Loading data...
Batch 7: Loading data...
Batch 8: Loading data...
Batch 9: Loading data...
Batch 10: Loading data...
Batch 11: Loading data...
Batch 12: Loading data...
Batch 13: Loading data...
Batch 14: Loading data...
Batch 15: Loading data...
Batch 16: Loading data...
Batch 17: Loading data...
Batch 18: Loading data...
Batch 19: Loading data...
Batch 20: Loading data...
Batch 21: Loading data...
Batch 22: Loading data...
Batch 23: Loading data...
Batch 24: Loading data...
Batch 25: Loading data...
Batch 26: Loading data...
Batch 27: Loading data...
Batch 28: Loading data...
Batch 29: Loading data...
Batch 30: Loading data...
Batch 31: Loading data...
Batch 32: Loading data...
Batch 33: Loading data...
Batch 34: Loading data...
Batch 35: Loading data...
Batch 36: Loading data...
Batch 37: Loading data...
Bat

KeyboardInterrupt: 

Evaluation on the Test Set

In [None]:
print("Evaluating model on the test set...")

# Rebuild the architecture and load the best checkpoint written by the training loop.
model = get_model(num_classes=len(class_names)).to(device)
# map_location lets a checkpoint saved in a GPU session load on a CPU-only runtime too.
checkpoint = torch.load(MODEL_SAVE_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        features = model.backbone(inputs)['out']
        outputs = model.classifier(features)

        preds = outputs.argmax(dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Class index 0 is the positive ("mitotic") class; the tick labels below map it to folder '1'.
POSITIVE_LABEL = 0
preds_arr = np.asarray(all_preds)
labels_arr = np.asarray(all_labels)

# Calculate metrics (zero_division=0 gives the default 0.0 without a warning on empty denominators)
test_acc = float(np.mean(preds_arr == labels_arr)) if labels_arr.size else float('nan')
precision = precision_score(labels_arr, preds_arr, average='binary', pos_label=POSITIVE_LABEL, zero_division=0)
recall = recall_score(labels_arr, preds_arr, average='binary', pos_label=POSITIVE_LABEL, zero_division=0)
f1 = f1_score(labels_arr, preds_arr, average='binary', pos_label=POSITIVE_LABEL, zero_division=0)

print(f"Accuracy:  {test_acc:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1-Score:  {f1:.4f}")


# Confusion Matrix. Passing labels= keeps it square and aligned with the tick labels
# even if a class never appears in the predictions or the ground truth.
cm = confusion_matrix(labels_arr, preds_arr, labels=list(range(len(class_names))))
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Mitotic (1)', 'Non-Mitotic (2)'],
            yticklabels=['Mitotic (1)', 'Non-Mitotic (2)'],
            ax=ax)
ax.set_xlabel('Predicted Label')
ax.set_ylabel('True Label')
ax.set_title('Confusion Matrix on Test Set')
plt.show()

Evaluating model on the test set...


NameError: name 'get_model' is not defined

In [None]:
# 1-based epoch numbers for the x-axis (len(acc) alone is an int, which plt.plot cannot pair with a list).
epochs = range(1, len(acc) + 1)
# float() accepts plain floats and 0-dim tensors (including CUDA tensors from older runs).
val_acc_values = [float(a) for a in acc]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(epochs, val_acc_values, marker='o', label='Validation accuracy')
ax.set_title('Validation Accuracy per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
ax.grid(True)
# Persist the figure next to the checkpoint, at the path configured at the top of the notebook.
fig.savefig(PLOT_SAVE_PATH, bbox_inches='tight')
plt.show()