# MIL (Multiple Instance Learning) Tutorial

This tutorial demonstrates how to use Multiple Instance Learning (MIL) models with Slideflow for digital pathology tasks. 

## Setting up the environment

First, let's import the necessary libraries and set up our project.

In [2]:
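import os

# Set the backend to PyTorch. Slideflow picks up SF_BACKEND when it is imported,
# so set it before the import below rather than after.
os.environ['SF_BACKEND'] = 'torch'

import slideflow as sf
from slideflow.mil import mil_config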

# Initialize the project
project_root = '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT'
P = sf.Project(root=project_root)



## Save features bags as torch tensors

When generating feature bags, make sure to save them as torch tensors (here with `features.to_torch()`).
The extraction code is commented out because the features for this project have already been saved.

In [12]:
# extractor = 'ctranspath'
# weights_path = f'/mnt/labshare/MODELS/{extractor}/pytorch_model.bin'
# dataset = P.dataset(tile_px=299, tile_um=302)
# extractor_model = sf.model.build_feature_extractor(extractor, center_crop=True)
# features = sf.DatasetFeatures(extractor_model, dataset=dataset, normalizer='reinhard')
# features.to_torch(project_root + f'/features/{extractor}/torch')
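Before training, it can help to check which slides already have a saved bag. Below is a minimal sketch. It assumes the bags directory holds one `<slide name>.pt` file per slide, which is our understanding of what `to_torch()` writes. `find_bags` is a hypothetical helper, not part of Slideflow.

```python
import tempfile
from pathlib import Path


def find_bags(bags_dir, slides):
    """Split slide names into (found, missing), assuming bags are saved as '<slide>.pt'."""
    bags_dir = Path(bags_dir)
    found, missing = {}, []
    for slide in slides:
        path = bags_dir / f"{slide}.pt"
        if path.is_file():
            found[slide] = path
        else:
            missing.append(slide)
    return found, missing


# Demo with empty placeholder files in a temporary directory (no real features involved).
with tempfile.TemporaryDirectory() as tmp:
    for name in ("slide_A", "slide_B"):
        (Path(tmp) / f"{name}.pt").touch()
    found, missing = find_bags(tmp, ["slide_A", "slide_B", "slide_C"])

print(sorted(found))  # ['slide_A', 'slide_B']
print(missing)        # ['slide_C']
```

In a real project, once `dataset` and `extractor` are defined, you could call `find_bags(project_root + f'/features/{extractor}/torch', dataset.slides())`. `Dataset.slides()` returns the slide names in a Slideflow dataset.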

## Splitting Data for Training and Validation

There are different strategies for splitting data into training and validation sets. Two common approaches are filtering on an annotation column and k-fold cross-validation:

In [3]:

# Option 1: Using filters
train_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'train'})
val_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'val'})

# Option 2: Using k-fold cross-validation
dataset = P.dataset(tile_px=299, tile_um=302)
splits = dataset.kfold_split(k=2, splits='./splits.json')

# k-fold cross-validation training
extractor = 'ctranspath'
model = 'attention_mil'
config = mil_config(model, bag_size=4, batch_size=2, epochs=2)
# NOTE: k, bag_size, batch_size, and epochs are set to very low values so that
#       the tutorial runs quickly. Use larger values for real experiments.

for i, (train, val) in enumerate(splits):
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}_fold{i}',
        outcomes='cohort',
        train_dataset=train,
        val_dataset=val,
        bags=project_root + f'/features/{extractor}/torch'
    )

| epoch | train_loss | valid_loss | roc_auc_score | time |
|---|---|---|---|---|
| 0 | 0.526444 | 0.371114 | 0.0 | 00:01 |
| 1 | 0.385448 | 0.374279 | 0.333333 | 00:00 |


Better model found at epoch 0 with valid_loss value: 0.37111377716064453.


| epoch | train_loss | valid_loss | roc_auc_score | time |
|---|---|---|---|---|
| 0 | 0.49336 | 0.361857 | 0.666667 | 00:00 |
| 1 | 0.425078 | 0.359169 | 0.666667 | 00:00 |


Better model found at epoch 0 with valid_loss value: 0.3618566393852234.
Better model found at epoch 1 with valid_loss value: 0.35916879773139954.
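The loop above relies on `kfold_split` returning one `(train, val)` pair per fold, with each fold serving as the validation set exactly once. The pure-Python sketch below illustrates that idea only; it is not Slideflow's implementation. In practice, all slides from the same patient should land in the same fold to avoid leakage, and this sketch ignores that. `kfold_pairs` is a hypothetical helper.

```python
import random


def kfold_pairs(slides, k, seed=0):
    """Yield (train, val) slide lists so that each fold is used once for validation."""
    shuffled = list(slides)
    if k < 2 or k > len(shuffled):
        raise ValueError("k must be between 2 and the number of slides")
    random.Random(seed).shuffle(shuffled)
    # Deal slides round-robin into k folds so fold sizes differ by at most one.
    folds = [shuffled[i::k] for i in range(k)]
    for i in range(k):
        val = folds[i]
        train = [s for j, fold in enumerate(folds) if j != i for s in fold]
        yield train, val


slides = [f"slide_{n}" for n in range(6)]
for i, (train, val) in enumerate(kfold_pairs(slides, k=2)):
    print(f"fold {i}: {len(train)} train / {len(val)} val")
# fold 0: 3 train / 3 val
# fold 1: 3 train / 3 val
```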


## Training a MIL Model

Now, let's train a single MIL model with the Attention MIL architecture, using the filter-based training and validation datasets from Option 1.

In [21]:
extractor = 'ctranspath'
model = 'attention_mil'

config = mil_config(model, epochs=2, bag_size=4, batch_size=2)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

| epoch | train_loss | valid_loss | roc_auc_score | time |
|---|---|---|---|---|
| 0 | 0.351417 | 0.394075 | 0.666667 | 00:00 |
| 1 | 0.238186 | 0.388286 | 0.666667 | 00:00 |


Better model found at epoch 0 with valid_loss value: 0.3940750062465668.
Better model found at epoch 1 with valid_loss value: 0.3882855176925659.


<fastai.learner.Learner at 0x7f3e819c61c0>
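The core idea behind Attention MIL is attention-based pooling, described by Ilse et al. (2018). Each tile embedding h_i gets a score w · tanh(V h_i). A softmax over the tiles turns these scores into weights a_i, and the bag embedding is the weighted sum of the h_i.

The sketch below shows only that pooling step, using made-up values for `V` and `w`. The real `attention_mil` model learns these during training and has additional layers, so treat this as an illustration, not Slideflow's code.

```python
import math


def attention_pool(bag, V, w):
    """Attention-based MIL pooling: z = sum_i a_i * h_i, with a = softmax_i(w . tanh(V h_i)).

    bag: list of tile feature vectors h_i (each of length d)
    V:   hidden_dim x d matrix (list of rows); w: vector of length hidden_dim
    Returns the pooled bag embedding and the per-tile attention weights.
    """
    scores = []
    for h in bag:
        hidden = [math.tanh(sum(v * x for v, x in zip(row, h))) for row in V]
        scores.append(sum(wj * hj for wj, hj in zip(w, hidden)))
    # Softmax over tiles; subtracting the max keeps exp() numerically stable.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    pooled = [sum(a * h[k] for a, h in zip(weights, bag)) for k in range(len(bag[0]))]
    return pooled, weights


# Toy bag: 3 tiles with 2 features each; V and w are made-up, not learned.
bag = [[0.0, 1.0], [2.0, 0.0], [-1.0, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0]]
w = [1.0, 0.0]
pooled, weights = attention_pool(bag, V, w)
print([round(a, 3) for a in weights])
```

Because the weights sum to 1, the pooled embedding is a weighted average of the tiles. Inspecting the weights is what makes attention heatmaps possible.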

## Evaluating a MIL Model

After training, it's important to evaluate the model's performance on a separate test set.

In [26]:
test_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'test'})

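# The model directory must contain mil_params.json. train_mil() saves each run
# as mil/<run number>-<exp_label>, so adjust the path to the model trained above.
# Use the same feature extractor (bags) and outcome that the model was trained with.
df = P.evaluate_mil(
    f'mil/00000-{extractor}_{model}',
    outcomes='cohort',
    dataset=test_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

If the path does not point to a trained model directory, `evaluate_mil` raises an error such as:

ModelError: Could not find `mil_params.json` at mil/00000-virchow2_attention_mil. Check the provided model/weights path, or provide a configuration with 'config'.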
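The training logs report `roc_auc_score`, and you may want the same metric on the test set. The function below computes ROC AUC as the probability that a randomly chosen positive example scores higher than a randomly chosen negative one, with ties counting as half. It is a hypothetical helper for a binary outcome encoded as 0/1.

```python
def roc_auc(labels, scores):
    """ROC AUC via pairwise comparison of positive and negative scores (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(labels, scores))  # 0.75
```

With the DataFrame returned by `evaluate_mil`, something like `roc_auc(df['y_true'], df['y_pred1'])` might work. These column names are an assumption, so check `df.columns` first.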

## Different MIL Model Architectures

Slideflow supports various MIL architectures. Here are some examples:

- Attention MIL
- TransMIL
- CLAM-SB

Let's train models using different architectures:

In [None]:
models = ['attention_mil', 'transmil', 'clam_sb']

for model in models:
    config = mil_config(model, epochs=2)
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}',
        outcomes='adsq',
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        bags=project_root + f'/features/{extractor}/torch'
    )

## Aggregation Level

The `aggregation_level` parameter determines how bags are formed and how outcomes are assigned:

- If `aggregation_level` is `'slide'` (the default), each bag consists of the tiles from one slide, and each slide has its own outcome.
- If `aggregation_level` is `'patient'`, each bag consists of the tiles from all slides belonging to one patient, and each patient has a single outcome. This only makes a difference when at least some patients have more than one slide.
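Here is a minimal sketch of how the two settings group tiles, using made-up slide and patient names and tiny feature vectors. The slide-to-patient mapping would presumably come from the `patient` column of the project annotations. `build_bags` is a hypothetical helper, not Slideflow's code.

```python
from collections import defaultdict


def build_bags(tile_features, slide_to_patient, level="slide"):
    """Group per-slide tile features into bags at slide or patient level."""
    if level == "slide":
        return {slide: list(tiles) for slide, tiles in tile_features.items()}
    if level == "patient":
        bags = defaultdict(list)
        for slide, tiles in tile_features.items():
            # Concatenate the tiles of every slide belonging to the same patient.
            bags[slide_to_patient[slide]].extend(tiles)
        return dict(bags)
    raise ValueError(f"unknown level: {level!r}")


tile_features = {
    "slide_1": [[0.1, 0.2], [0.3, 0.4]],
    "slide_2": [[0.5, 0.6]],
    "slide_3": [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
}
slide_to_patient = {"slide_1": "patient_A", "slide_2": "patient_A", "slide_3": "patient_B"}

for level in ("slide", "patient"):
    bags = build_bags(tile_features, slide_to_patient, level)
    print(level, {name: len(tiles) for name, tiles in bags.items()})
# slide {'slide_1': 2, 'slide_2': 1, 'slide_3': 3}
# patient {'patient_A': 3, 'patient_B': 3}
```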

## Save Monitor

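We can also change `save_monitor` from its default, the validation loss (`valid_loss`, the metric reported in the "Better model found" messages above), to `'roc_auc_score'`. In that case, the model from the epoch with the best validation AUC is saved.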

In [None]:
# If the previous cell was run, `model` is now 'clam_sb' (the last value from its loop).
config = mil_config(
    model,
    epochs=2,
    aggregation_level='patient',
    save_monitor='roc_auc_score'
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_patient_auc',
    outcomes='adsq',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

# Note: aggregation_level and save_monitor are independent settings; they are
#       combined in one example only to keep the tutorial short.
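Conceptually, `save_monitor` decides which epoch's weights are kept: the epoch with the lowest `valid_loss`, or the epoch with the highest `roc_auc_score`. The sketch below applies that rule to the values logged for the first k-fold run earlier in this tutorial. `best_epoch` is a hypothetical helper; the actual saving is done by the training callback that prints the "Better model found" messages.

```python
def best_epoch(values, mode):
    """Return the index of the best epoch: lowest value for 'min', highest for 'max'."""
    if mode == "min":
        return min(range(len(values)), key=lambda i: values[i])
    if mode == "max":
        return max(range(len(values)), key=lambda i: values[i])
    raise ValueError("mode must be 'min' or 'max'")


# Per-epoch values from the first k-fold run logged earlier in this tutorial.
valid_loss = [0.371114, 0.374279]
roc_auc = [0.0, 0.333333]

print(best_epoch(valid_loss, "min"))  # 0
print(best_epoch(roc_auc, "max"))     # 1
```

This matches the log for that run, which kept epoch 0 by validation loss. Monitoring AUC instead would have kept epoch 1.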

## Other Configurable Parameters

Besides `epochs`, which we have kept small so the cells run quickly, MIL models have many other parameters that can be adjusted to optimize performance.
Here's an example that sets the learning rate (`lr`), `bag_size`, `epochs`, and `batch_size`:

In [None]:
config = mil_config(
    model,
    lr=1e-3,
    bag_size=4,
    epochs=3,
    batch_size=2
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_custom',
    outcomes='adsq',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)
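`bag_size` appears in several cells above but is not explained. As we understand it, during training each bag is brought to exactly `bag_size` tiles so that bags can be stacked into batches of `batch_size`. Larger bags are randomly subsampled and smaller ones are zero-padded. The pure-Python sketch below illustrates that idea; `to_fixed_size_bag` is a hypothetical helper, not Slideflow's API.

```python
import random


def to_fixed_size_bag(tiles, bag_size, rng):
    """Return (bag, n_real): exactly bag_size tile feature vectors and the number of real tiles.

    Larger bags are randomly subsampled; smaller bags are zero-padded at the end.
    """
    if not tiles:
        raise ValueError("bag must contain at least one tile")
    n_features = len(tiles[0])
    picked = rng.sample(tiles, min(bag_size, len(tiles)))
    padding = [[0.0] * n_features for _ in range(bag_size - len(picked))]
    return picked + padding, len(picked)


rng = random.Random(0)
large_bag = [[float(i), float(i)] for i in range(10)]  # 10 tiles, 2 features each
small_bag = [[1.0, 2.0], [3.0, 4.0]]                   # 2 tiles, 2 features each

bag, n_real = to_fixed_size_bag(large_bag, bag_size=4, rng=rng)
print(len(bag), n_real)  # 4 4
bag, n_real = to_fixed_size_bag(small_bag, bag_size=4, rng=rng)
print(len(bag), n_real)  # 4 2
```

A very small `bag_size` such as 4 means each training step sees only a handful of tiles per slide. This is one reason the low values used in this tutorial are not suitable for real experiments.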