In [2]:
import argparse
import os
import random
from collections import OrderedDict
from glob import glob

import numpy as np
import pandas as pd
import yaml
import tqdm
from scipy.stats import pearsonr, spearmanr

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader, Subset

import torchvision
from torchvision.utils import save_image

from albumentations.augmentations import transforms
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split

from easydict import EasyDict as edict
from torch.optim import lr_scheduler

# project-specific imports
from archs import gigatime
from losses import *
from utils import *
from prov_data import *

Argument Parser

In [None]:
import argparse

def str2bool(v):
    return v.lower() in ('true', '1')

def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--name', default="gigatime_training")  # experiment name
    parser.add_argument('--output_dir', default="./scratch")  # directory to save outputs
    parser.add_argument('--gpu_ids', nargs='+', type=int)  # GPU IDs to use for training
    parser.add_argument('--metadata', default="./../data/full_metadata.csv")  # path to metadata CSV file
    parser.add_argument('--tiling_dir', default="./../data/gigatime_training_tiles/")  # directory containing tiled images
    parser.add_argument('--epochs', default=1, type=int)  # number of training epochs
    parser.add_argument('--batch_size', default=32, type=int)  # batch size for training

    # model
    parser.add_argument('--arch', default='NestedUNet')  # model architecture
    parser.add_argument('--input_channels', default=3, type=int)  # number of input channels
    parser.add_argument('--num_classes', default=23, type=int)  # number of output classes
    parser.add_argument('--input_w', default=512, type=int)  # input image width to resize to
    parser.add_argument('--input_h', default=512, type=int)  # input image height to resize to

    # loss
    parser.add_argument('--loss', default='BCEDiceLoss')  # loss function to use

    # optimizer
    parser.add_argument('--optimizer', default='Adam', choices=['Adam', 'SGD'])  # optimizer type
    parser.add_argument('--lr', default=1e-3, type=float)  # learning rate
    parser.add_argument('--momentum', default=0.9, type=float)  # momentum for SGD optimizer
    parser.add_argument('--weight_decay', default=1e-4, type=float)  # weight decay for regularization
    parser.add_argument('--nesterov', default=False, type=str2bool)  # enable Nesterov acceleration for SGD

    # scheduler
    parser.add_argument('--scheduler', default='CosineAnnealingLR',
                        choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])  # learning rate scheduler
    parser.add_argument('--min_lr', default=1e-5, type=float)  # minimum learning rate
    parser.add_argument('--factor', default=0.1, type=float)  # factor to reduce learning rate
    parser.add_argument('--patience', default=2, type=int)  # patience for ReduceLROnPlateau scheduler
    parser.add_argument('--milestones', default='1,2', type=str)  # epochs to reduce learning rate for MultiStepLR
    parser.add_argument('--gamma', default=2/3, type=float)  # learning rate decay factor
    parser.add_argument('--early_stopping', default=-1, type=int)  # early stopping patience (-1 to disable)

    parser.add_argument('--num_workers', default=12, type=int)  # number of data loader workers
    parser.add_argument('--window_size', type=int, default=256)  # size of cropping window
    parser.add_argument('--sampling_prob', type=float, default=0.5)  # sampling probability for training data, this is just to speed up training if we want to debug or train on a less data
    parser.add_argument('--val_sampling_prob', type=float, default=0.01)  # sampling probability for validation data, this is just to speed up training if we want to debug or train on a less data
    parser.add_argument('--transformer', type=str2bool, default=False)  # enable transformer architecture
    parser.add_argument('--sigmoid', type=str2bool, default=True)  # apply sigmoid activation to output
    parser.add_argument('--crop', type=str2bool, default=False)  # enable random cropping during training

    return edict(vars(parser.parse_args([])))  # use [] so notebook runs

config = parse_args()


Metrics

In [4]:
mean = torch.tensor([0.485, 0.456, 0.406]).cuda()
std = torch.tensor([0.229, 0.224, 0.225]).cuda()



def calculate_correlations(matrix1, matrix2):
    """
    Calculate Pearson and Spearman correlation coefficients between two matrices.

    Args:
        matrix1 (np.ndarray): The first matrix.
        matrix2 (np.ndarray): The second matrix.

    Returns:
        dict: A dictionary containing Pearson and Spearman correlation coefficients.
    """
    assert matrix1.shape == matrix2.shape, "Matrices must have the same shape"
    b, c, h, w = matrix1.shape

    pearson_correlations = []
    spearman_correlations = []

    for channel in range(c):
        pearson_corrs = []
        spearman_corrs = []

        for batch in range(b):
            flat_matrix1 = matrix1[batch, channel].flatten()
            flat_matrix2 = matrix2[batch, channel].flatten()

            # Remove NaN values
            valid_indices = ~np.isnan(flat_matrix1.detach().cpu().numpy()) & ~np.isnan(flat_matrix2.detach().cpu().numpy())
            flat_matrix1 = flat_matrix1[valid_indices]
            flat_matrix2 = flat_matrix2[valid_indices]

            if len(flat_matrix1) > 0 and len(flat_matrix2) > 0:
                pearson_corr, _ = pearsonr(flat_matrix1.detach().cpu().numpy(), flat_matrix2.detach().cpu().numpy())
                spearman_corr, _ = spearmanr(flat_matrix1.detach().cpu().numpy(), flat_matrix2.detach().cpu().numpy())
            else:
                pearson_corr = np.nan
                spearman_corr = np.nan

            pearson_corrs.append(pearson_corr)
            spearman_corrs.append(spearman_corr)

        # Average correlations across the batch dimension
        pearson_correlations.append(np.nanmean(pearson_corrs))
        spearman_correlations.append(np.nanmean(spearman_corrs))

    return pearson_correlations, spearman_correlations
    

def split_into_boxes(tensor, box_size):
    # Get the dimensions of the tensor
    batch_size, channels, height, width = tensor.shape
    
    # Calculate the number of boxes along each dimension
    num_boxes_y = height // box_size
    num_boxes_x = width // box_size
    
    # Split the tensor into non-overlapping boxes
    boxes = tensor.unfold(2, box_size, box_size).unfold(3, box_size, box_size)
    boxes = boxes.contiguous().view(batch_size, channels, num_boxes_y, num_boxes_x, box_size, box_size)
    
    return boxes

def count_ones(boxes):
    # Count the number of ones in each box
    return boxes.sum(dim=(4, 5))



def get_box_metrics(pred, mask, box_size):
    # Split the images into boxes
    pred_boxes = split_into_boxes(pred, box_size)
    mask_boxes = split_into_boxes(mask, box_size)
    # Count the number of ones in each box
    pred_counts = count_ones(pred_boxes)
    mask_counts = count_ones(mask_boxes)
    
    # Calculate precision and MSE for the matrices
    mse = ((pred_counts.float() - mask_counts.float()) ** 2).mean(dim=0)    
    mean_mse_per_channel = mse.mean(dim=(1,2))

    mean_mse = mse.mean().item()

    pearson, spearman = calculate_correlations(pred_counts, mask_counts)
    
    return mean_mse_per_channel, pearson, spearman 





Data Loader Helpers

In [5]:
def sample_data_loader(data_loader, config, sample_fraction=0.1, deterministic=False, what_split="train"):
    # this just samples some fraction of the data in the dataloader so that we can train on a smaller subset for quick testing
    dataset = data_loader.dataset
    total_size = len(dataset)
    sample_size = int(total_size * sample_fraction)

    if deterministic:
        sample_indices = [i for i in range(sample_size)]
    else:
        sample_indices = random.sample(range(total_size), sample_size)

    subset = Subset(dataset, sample_indices)

    if what_split == "train":
        sample_loader = DataLoader(subset, batch_size=data_loader.batch_size, shuffle=True,
            num_workers=config['num_workers'], prefetch_factor=6, drop_last=True)
    else:
        sample_loader = DataLoader(subset, batch_size=data_loader.batch_size, shuffle=False,
            num_workers=config['num_workers'], prefetch_factor=6, drop_last=False)
    return sample_loader

Training & Validation Loops

In [6]:
def train(config, train_loader, model, criterion, optimizer):
    # Initialize average meters to track loss and Pearson correlation metrics
    avg_meters = {'loss': AverageMeter(), 'pearson': AverageMeter()}
    pearson_per_class_meters = [AverageMeter() for _ in range(config['num_classes'])]
    window_size = config['window_size']
    
    # Set model to training mode
    model.train()

    # Initialize progress bar for training loop
    pbar = tqdm.tqdm(total=len(train_loader))
    for input, target, name in train_loader:
        # Downsample target by factor of 8, then resize to input dimensions to make the target coarse to discount for any pixel level registration error
        downsampled_image = F.interpolate(target, scale_factor=1/8, mode='bilinear', align_corners=False)
        target = F.interpolate(downsampled_image, size=(config["input_h"],config["input_h"]), mode='bilinear', align_corners=False)
        target = target.cuda()
        
        # Forward pass through model
        output_image = model(input.cuda()).cuda()

        # Calculate loss between predicted and target images
        loss = criterion(output_image, target)
        
        # Backpropagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Calculate IoU metrics for overall and per-class evaluation
        _, pearson, _ = get_box_metrics(output_image, target, box_size=8)

        
        # Update per-class Pearson meters
        for class_idx, pearson_value in enumerate(pearson):
            pearson_per_class_meters[class_idx].update(pearson_value, input.size(0))

        # Update average meters with current batch metrics
        avg_meters['loss'].update(loss.item(), input.size(0))
        avg_meters['pearson'].update(np.nanmean(pearson), input.size(0))

        # Update progress bar with current metrics
        pbar.set_postfix({'loss': avg_meters['loss'].avg, 'pearson': avg_meters['pearson'].avg})
        pbar.update(1)
    pbar.close()

    # Return ordered dictionary with training metrics
    return OrderedDict([('loss', avg_meters['loss'].avg), ('pearson', avg_meters['pearson'].avg)] +
                       [(f'class_{i}', m.avg) for i, m in enumerate(pearson_per_class_meters)])

def validate(config, val_loader, model, criterion):
    # Initialize average meters to track validation loss and Pearson correlation metrics
    avg_meters = {'loss': AverageMeter(), 'pearson': AverageMeter()}
    pearson_per_class_meters = [AverageMeter() for _ in range(config['num_classes'])]
    window_size = config['window_size']
    
    # Set model to evaluation mode (disables dropout, batch norm updates)
    model.eval()

    # Disable gradient computation for validation (saves memory and computation)
    with torch.no_grad():
        # Initialize progress bar for validation loop
        pbar = tqdm.tqdm(total=len(val_loader))
        for input, target, name in val_loader:
            # Downsample target by factor of 8, then resize to input dimensions
            downsampled_image = F.interpolate(target, scale_factor=1/8, mode='bilinear', align_corners=False)
            target = F.interpolate(downsampled_image, size=(config["input_h"],config["input_h"]), mode='bilinear', align_corners=False)
            target = target.cuda()
            
            # Forward pass through model
            output_image = model(input.cuda()).cuda()

            # Calculate validation loss
            loss = criterion(output_image, target)
            
            # Calculate IoU metrics for overall and per-class evaluation
            _, pearson, _ = get_box_metrics(output_image, target, box_size=8)
            
            # Update per-class IoU meters
            for class_idx, pearson_value in enumerate(pearson):
                pearson_per_class_meters[class_idx].update(pearson_value, input.size(0))

            # Update average meters with current batch metrics
            avg_meters['loss'].update(loss.item(), input.size(0))
            avg_meters['pearson'].update(np.nanmean(pearson), input.size(0))

            # Update progress bar with current metrics
            pbar.set_postfix({'loss': avg_meters['loss'].avg, 'pearson': avg_meters['pearson'].avg})
            pbar.update(1)
        pbar.close()

    # Return ordered dictionary with validation metrics
    return OrderedDict([('loss', avg_meters['loss'].avg), ('pearson', avg_meters['pearson'].avg)] +
                       [(f'class_{i}', m.avg) for i, m in enumerate(pearson_per_class_meters)])

Model, Loss, Optimizer, Scheduler

In [7]:
# channel names
common_channel_list = [
    'DAPI','TRITC','Cy5','PD-1_1:200 - Cy5','CD14 - Cy5','CD4 - Cy5','T-bet - Cy5',
    'CD34 - Cy5','CD68_1:100 - TRITC','CD16 - Cy5','CD11c - Cy5','CD138 - TRITC',
    'CD20 - TRITC','CD3_1:1000 - Cy5','CD8 - TRITC','PD-L1 - Cy5','CK_1:150 - TRITC',
    'Ki67_1:150 - TRITC','Tryptase - TRITC','Actin-D - TRITC','Caspase3-D - Cy5',
    'PHH3-B - Cy5','Transgelin - TRITC'
]

# loss
if config['loss'] == 'MSELoss':
    criterion = nn.MSELoss().cuda()
elif config['loss'] == 'BCEWithLogitsLoss':
    criterion = nn.BCEWithLogitsLoss().cuda()
elif config['loss'] == 'BCEDiceLoss':
    criterion = BCEDiceLoss().cuda()
else:
    criterion = losses.__dict__[config['loss']]().cuda()

# model
model = gigatime(num_classes=config['num_classes'],
                 sigmoid=config["sigmoid"],
                 loss_type=config["loss"],
                 input_channels=config['input_channels']).cuda()

if config["gpu_ids"] and len(config["gpu_ids"]) > 1:
    model = nn.DataParallel(model, device_ids=config["gpu_ids"])
    print("using multiple GPUs", config["gpu_ids"])

# optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
if config['optimizer'] == 'Adam':
    optimizer = optim.Adam(params, lr=config['lr'], weight_decay=config['weight_decay'])
elif config['optimizer'] == 'SGD':
    optimizer = optim.SGD(params, lr=config['lr'], momentum=config['momentum'],
                          nesterov=config['nesterov'], weight_decay=config['weight_decay'])

# scheduler
if config['scheduler'] == 'CosineAnnealingLR':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
elif config['scheduler'] == 'ReduceLROnPlateau':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'],
                                                     patience=config['patience'], verbose=1,
                                                     min_lr=config['min_lr'])
elif config['scheduler'] == 'MultiStepLR':
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=[int(e) for e in config['milestones'].split(',')],
                                               gamma=config['gamma'])
elif config['scheduler'] == 'ConstantLR':
    scheduler = None

Dataset Creation

In [8]:
# Read metadata CSV file containing experiment information
metadata = pd.read_csv(config["metadata"])

# Convert tiling directory path to Path object for easier file handling
tiliting_dir = Path(config["tiling_dir"])

# Generate dataframe containing tile pair information from metadata and tiling directory
tile_pair_df = generate_tile_pair_df(metadata=metadata, tiling_dir=tiliting_dir)

# Filter tile pairs based on image quality metrics, the exact numbers were decided based on empirical analysis as well as suggestions from domain experts:
# - Black ratio < 0.3
# - Variance > 200
# Applied to both COMET and H&E images
tile_pair_df_filtered = tile_pair_df[tile_pair_df.apply(
    lambda x: (x["img_comet_black_ratio"] < 0.3) &
              (x["img_comet_variance"] > 200) &
              (x["img_he_black_ratio"] < 0.3) &
              (x["img_he_variance"] > 200), axis=1
)]

# Load segmentation metrics from JSON files for each unique directory
dir_names = tile_pair_df_filtered["dir_name"].unique()
segment_metric_dict = {}
for dir_name in dir_names:
    # Read segment metrics JSON file from each directory
    with open(os.path.join(dir_name, "segment_metric.json"), "r") as f:
        segment_metric_list = json.load(f)
    segment_metric_dict[dir_name] = segment_metric_list

# Initialize new columns based on metric keys from first directory's first entry
new_columns = {col: [] for col in next(iter(segment_metric_dict[dir_names[0]].values())).keys()}

# Populate new columns with metrics for each tile pair
for _, row in tile_pair_df_filtered.iterrows():
    # Get metrics for current tile pair from corresponding directory
    metrics = segment_metric_dict[row["dir_name"]][row["pair_name"]]
    # Add each metric value to corresponding column list
    for key, value in metrics.items():
        new_columns[key].append(value)

# Add all metric columns to the filtered dataframe
for key, values in new_columns.items():
    tile_pair_df_filtered[key] = values

# Further filter tile pairs based on dice coefficient > 0.2 for better segmentation quality
tile_pair_df_filtered_dicefilter = tile_pair_df_filtered[tile_pair_df_filtered["dice"] > 0.2]


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tile_pair_df_filtered[key] = values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tile_pair_df_filtered[key] = values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  tile_pair_df_filtered[key] = values
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col

Data loader setup

In [9]:
import albumentations as geometric

# Check if random cropping is enabled for data augmentation
if config['crop']: ## we use this by default as our patches are 512 and we train with random 256*256 crops for better generalization
    # Define training augmentation pipeline with cropping
    train_transform = Compose([
        geometric.RandomRotate90(),  # Randomly rotate images by 90 degrees
        geometric.Flip(),  # Random horizontal/vertical flips
        OneOf([  # Apply one of the following color augmentations
            transforms.HueSaturationValue(),  # Adjust hue, saturation, and value
            transforms.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2),  # Adjust contrast only
            transforms.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),  # Adjust brightness only
        ], p=1),
        geometric.RandomCrop(config['input_h'], config['input_w']),  # Random crop to target size
        transforms.Normalize()  # Normalize pixel values
    ],
        is_check_shapes=False)

    # Define validation transform with only cropping (no augmentation), also note that during validation we process 
    val_transform = Compose([
        geometric.Resize(config['input_h'], config['input_w']),  # Resize to target dimensions
        transforms.Normalize()  # Normalize pixel values
    ],
        is_check_shapes=False)

else:
    # Define training augmentation pipeline with resizing instead of cropping
    train_transform = Compose([
        geometric.RandomRotate90(),  # Randomly rotate images by 90 degrees
        geometric.Flip(),  # Random horizontal/vertical flips
        OneOf([  # Apply one of the following color augmentations
            transforms.HueSaturationValue(),  # Adjust hue, saturation, and value
            transforms.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2),  # Adjust contrast only
            transforms.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0),  # Adjust brightness only
        ], p=1),
        geometric.Resize(config['input_h'], config['input_w']),  # Resize to target dimensions
        transforms.Normalize()  # Normalize pixel values
    ],
        is_check_shapes=False)

    # Define validation transform with only resizing (no augmentation)
    val_transform = Compose([
        geometric.Resize(config['input_h'], config['input_w']),  # Resize to target dimensions
        transforms.Normalize()  # Normalize pixel values
    ],
        is_check_shapes=False)

# Create training dataset with augmentation and cell masking enabled
train_dataset = HECOMETDataset_roi(
        all_tile_pair=tile_pair_df,  # Complete tile pair dataframe
        tile_pair_df=tile_pair_df_filtered_dicefilter,  # Filtered tile pairs based on quality metrics
        transform=train_transform,  # Apply training augmentations
        dir_path = config["tiling_dir"],  # Path to tiled image directory
        window_size = config["window_size"],  # Size of image windows to extract
        split="train",  # Specify training split
        mask_noncell=True,  # Mask non-cellular regions
        cell_mask_label=True,  # Use cell mask labels
    )    

# Create validation dataset with minimal transforms
val_dataset = HECOMETDataset_roi(
    all_tile_pair=tile_pair_df,  # Complete tile pair dataframe
    tile_pair_df=tile_pair_df_filtered_dicefilter,  # Filtered tile pairs based on quality metrics
    transform=val_transform,  # Apply validation transforms (no augmentation)
    dir_path = config["tiling_dir"],  # Path to tiled image directory
    window_size = config["window_size"],  # Size of image windows to extract
    split="valid",  # Specify validation split
    standard = "silver",  # this just sets the configration for validation tiles based on dice, 
    mask_noncell=True,  # Mask non-cellular regions
    cell_mask_label=True,  # Use cell mask labels
)    

# Create training data loader with shuffling and parallel loading
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                          num_workers=config['num_workers'], prefetch_factor=6, drop_last=True)

# Create validation data loader without shuffling
val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False,
                        num_workers=config['num_workers'], prefetch_factor=6, drop_last=False)

# Sample a subset of validation data for faster evaluation during training
val_loader = sample_data_loader(val_loader, config, config['val_sampling_prob'], deterministic=True, what_split="valid")

Training loop

In [11]:

for epoch in range(config['epochs']):
    print(f"Epoch [{epoch+1}/{config['epochs']}]")

    # --- Train ---
    train_log = train(config, train_loader, model, criterion, optimizer)

    # --- Validate ---
    val_log = validate(config, val_loader, model, criterion)

    print(f"Train -> Loss: {train_log['loss']:.4f}, IoU: {train_log['pearson']:.4f}")
    print(f"Val   -> Loss: {val_log['loss']:.4f}, IoU: {val_log['pearson']:.4f}")

    print("End of Epoch 1")
    print("Change settings (epochs) to train fully or use the train.py script to train the model")

Epoch [1/1]


100%|██████████| 879/879 [47:08<00:00,  3.22s/it, loss=0.788, pearson=0.287]
100%|██████████| 4/4 [00:26<00:00,  6.60s/it, loss=0.863, pearson=0.239]

Train -> Loss: 0.7882, IoU: 0.2874
Val   -> Loss: 0.8626, IoU: 0.2393
End of Epoch 1
Change settings (epochs) to train fully or use the train.py script to train the model



