# MIL (Multiple Instance Learning) Tutorial

This tutorial demonstrates how to use Multiple Instance Learning (MIL) models with Slideflow for digital pathology tasks. 

## Setting up the environment

First, let's import the necessary libraries and set up our project.

In [1]:
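import os

# Select the PyTorch backend. This is set before importing slideflow,
# since the backend is typically read when the library is loaded.
os.environ['SF_BACKEND'] = 'torch'

import slideflow as sf
from slideflow.mil import mil_config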

# Initialize the project
project_root = '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT'
P = sf.Project(root=project_root)



## Save feature bags as torch tensors

When generating feature bags, make sure to save them as torch tensors.
The code below is commented out because the features for this test project should already be saved.

In [2]:
# extractor = 'ctranspath'
# weights_path = f'/mnt/labshare/MODELS/{extractor}/pytorch_model.bin'
# dataset = P.dataset(tile_px=299, tile_um=302)
# extractor_model = sf.model.build_feature_extractor(extractor, center_crop=True)
# features = sf.DatasetFeatures(extractor_model, dataset=dataset, normalizer='reinhard')
# features.to_torch(project_root + f'/features/{extractor}/torch')
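
Before training, it can help to confirm that a bag exists for every slide. The sketch below assumes that bags are stored as one `<slide>.pt` file per slide in the output directory; that layout is an assumption, not something this tutorial states. The helper `find_missing_bags` and the slide names are hypothetical.

```python
import os
import tempfile

def find_missing_bags(bags_dir, slides):
    """Return the slides that have no '<slide>.pt' file in bags_dir.

    Assumes one .pt file per slide (an assumption about the bag layout).
    """
    present = {
        os.path.splitext(name)[0]
        for name in os.listdir(bags_dir)
        if name.endswith('.pt')
    }
    return sorted(s for s in slides if s not in present)

# Demonstration with a temporary directory and made-up slide names.
with tempfile.TemporaryDirectory() as demo_dir:
    for slide in ('slide_a', 'slide_b'):
        open(os.path.join(demo_dir, slide + '.pt'), 'wb').close()
    missing = find_missing_bags(demo_dir, ['slide_a', 'slide_b', 'slide_c'])
    print(missing)
```

In a real project, `bags_dir` would be the directory passed to `features.to_torch(...)` above.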

## Splitting Data for Training and Validation

There are different strategies for splitting data between training and validation sets. Two are shown below: selecting slides with filters on the `dataset` column, and k-fold cross-validation.

In [3]:

# Option 1: Using filters
train_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'train'})
val_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'val'})

# Option 2: Using k-fold cross-validation
dataset = P.dataset(tile_px=299, tile_um=302)
splits = dataset.kfold_split(k=2, splits='./splits.json')
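
To make the k-fold idea concrete, here is a minimal pure-Python sketch of how items are partitioned so that each one lands in exactly one validation fold. It is an illustration only: `kfold_indices` and the patient IDs are hypothetical, and Slideflow's own splitting operates on its dataset objects and may apply additional rules (such as balancing by label) that this sketch ignores.

```python
def kfold_indices(items, k):
    """Yield (train, val) lists; each item is in exactly one validation fold."""
    # Deal the items into k folds in round-robin order.
    folds = [items[i::k] for i in range(k)]
    for i in range(k):
        val = folds[i]
        train = [x for j, fold in enumerate(folds) if j != i for x in fold]
        yield train, val

patients = ['p1', 'p2', 'p3', 'p4', 'p5']  # hypothetical patient IDs
# Materialize the folds as a list so they can be iterated more than once.
demo_splits = list(kfold_indices(patients, k=2))
for i, (train, val) in enumerate(demo_splits):
    print(f'fold {i}: train={train} val={val}')
```

Splitting at the patient level, as in this toy example, avoids having slides from the same patient in both the training and validation sets.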

## Training a MIL Model

Now, let's train a MIL model using the Attention MIL architecture. 

We first train with k-fold cross-validation, then with the filter-based training and validation datasets.


In [4]:
extractor = 'ctranspath'
model = 'attention_mil'

# k-fold cross-validation training
config = mil_config(model, bag_size=4, batch_size=2, epochs=2)
# NOTE: bag_size, batch_size, and epochs are set to very low values only so
#       that the tutorial runs quickly. Do not use these values for real experiments.

for i, (train, val) in enumerate(splits):
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}_fold{i}',
        outcomes='cohort',
        train_dataset=train,
        val_dataset=val,
        bags=project_root + f'/features/{extractor}/torch'
    )

# NOTE: running this notebook saves several MIL models under TEST_PROJECT/mil.
#       Remove them afterwards (see "Clean up" below) so that models do not
#       accumulate in the project.

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.368933,0.353699,0.666667,00:00
1,0.185689,0.340995,0.666667,00:00


Better model found at epoch 0 with valid_loss value: 0.35369935631752014.
Better model found at epoch 1 with valid_loss value: 0.3409951627254486.


epoch,train_loss,valid_loss,roc_auc_score,time
0,0.376723,0.351969,0.666667,00:00
1,0.287166,0.343922,0.333333,00:00


Better model found at epoch 0 with valid_loss value: 0.3519689738750458.
Better model found at epoch 1 with valid_loss value: 0.343921959400177.


In [5]:
P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.714222,0.489486,0.666667,00:00
1,0.522529,0.529432,0.333333,00:00


Better model found at epoch 0 with valid_loss value: 0.48948606848716736.


<fastai.learner.Learner at 0x7f2bbb6dddf0>
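
For intuition about what an attention-based MIL model does with a bag, the sketch below pools a bag of tile feature vectors into a single vector using softmax attention weights. It is a generic illustration of attention pooling and not Slideflow's implementation; `attention_pool`, the toy features, and the parameters `V` and `w` are all hypothetical (in a real model the parameters are learned).

```python
import math

def attention_pool(bag, V, w):
    """Pool a bag of instance vectors into one vector with softmax attention.

    bag: list of n instance feature vectors (each of length d)
    V:   list of h weight rows (each of length d)
    w:   list of h weights
    The score of instance x is w . tanh(V x); the attention weights are a
    softmax over the scores of all instances in the bag.
    """
    scores = []
    for x in bag:
        hidden = [math.tanh(sum(v_i * x_i for v_i, x_i in zip(row, x))) for row in V]
        scores.append(sum(w_j * h_j for w_j, h_j in zip(w, hidden)))
    # Subtract the maximum score for a numerically stable softmax.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(bag[0])
    pooled = [sum(a * x[i] for a, x in zip(weights, bag)) for i in range(dim)]
    return pooled, weights

# Hypothetical 2-D features for a bag of three tiles, with toy parameters.
bag = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
V = [[0.5, -0.5], [1.0, 1.0]]
w = [1.0, 2.0]
pooled, weights = attention_pool(bag, V, w)
print([round(a, 3) for a in weights])
```

The pooled vector is what a classifier head would then map to a bag-level prediction, and the attention weights indicate how much each tile contributed.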

## Evaluating a MIL Model

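After training, evaluate the model on data that was not used for training. Ideally this is a separate held-out test set; in this tutorial, the validation subset (`'dataset': 'val'`) is reused as a stand-in.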

In [6]:
test_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'val'})

df = P.evaluate_mil(
    '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil/00000-ctranspath_attention_mil',
    outcomes='cohort',
    dataset=test_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)
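
When reading metrics from a dataset this small, keep in mind that AUC can only take a few distinct values. The sketch below computes AUC as the fraction of (positive, negative) pairs that are ranked correctly. The labels and scores are hypothetical; with one positive and three negatives there are only three pairs, so the only possible values without ties are 0, 1/3, 2/3, and 1. This is consistent with the `roc_auc_score` values in the training logs, although the tutorial does not state the actual composition of the validation set.

```python
def pairwise_auc(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair.
    """
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError('AUC needs at least one positive and one negative')
    correct = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(pos) * len(neg))

# Hypothetical labels and scores for four slides: one positive, three negatives.
y_true = [1, 0, 0, 0]
y_score = [0.6, 0.2, 0.7, 0.4]
auc = pairwise_auc(y_true, y_score)
print(round(auc, 6))  # 0.666667
```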

## Different MIL Model Architectures

Slideflow supports various MIL architectures. Here are some examples:

- Attention MIL
- TransMIL
- CLAM-SB

Let's train two of them, TransMIL and Attention MIL:

In [7]:
models = ['transmil', 'attention_mil']

for model in models:
    config = mil_config(model, bag_size=4, batch_size=2, epochs=2)
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}',
        outcomes='cohort',
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        bags=project_root + f'/features/{extractor}/torch'
    )

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.331518,0.402687,0.0,00:00
1,0.333517,0.538786,0.0,00:01


Better model found at epoch 0 with valid_loss value: 0.40268686413764954.


epoch,train_loss,valid_loss,roc_auc_score,time
0,0.399557,0.392156,0.666667,00:01
1,0.237676,0.391844,0.666667,00:00


Better model found at epoch 0 with valid_loss value: 0.39215579628944397.
Better model found at epoch 1 with valid_loss value: 0.3918438255786896.


## Aggregation Level

The `aggregation_level` parameter determines how the bags are formed and how outcomes are assigned:

- If `aggregation_level` is 'slide' (default), each bag consists of the tiles from one slide, and there is one outcome per slide.
- If `aggregation_level` is 'patient', each bag consists of the tiles from all slides belonging to a patient, and there is a single outcome per patient. This only has an effect when at least some patients have multiple slides.
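
The two options above can be sketched in plain Python. The slide-to-patient mapping, the tile counts, and `build_bags` are hypothetical and only illustrate how the grouping key changes which tiles end up in the same bag.

```python
from collections import defaultdict

# Hypothetical slide-to-patient mapping and per-slide tile counts.
slide_to_patient = {'s1': 'pA', 's2': 'pA', 's3': 'pB'}
tiles_per_slide = {'s1': 120, 's2': 80, 's3': 200}

def build_bags(level):
    """Return {bag_id: number_of_tiles} for 'slide' or 'patient' aggregation."""
    if level not in ('slide', 'patient'):
        raise ValueError(f'unknown aggregation level: {level}')
    bags = defaultdict(int)
    for slide, n_tiles in tiles_per_slide.items():
        # The bag key is the slide itself, or the patient the slide belongs to.
        key = slide if level == 'slide' else slide_to_patient[slide]
        bags[key] += n_tiles
    return dict(bags)

slide_bags = build_bags('slide')
patient_bags = build_bags('patient')
print(slide_bags)    # {'s1': 120, 's2': 80, 's3': 200}
print(patient_bags)  # {'pA': 200, 'pB': 200}
```

Patient `pA` has two slides, so its patient-level bag merges their tiles; patient `pB` has one slide, so its bag is the same under either setting.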



## Save Monitor

We can also change `save_monitor` from its default, the validation loss (reported as `valid_loss` in the logs above), to 'roc_auc_score', in which case the model with the best validation AUC is saved.

In [8]:
model = 'attention_mil'
config = mil_config(
    model, bag_size=4, batch_size=2, epochs=2,

    # aggregation_level='patient',  # uncomment to aggregate bags by patient
    save_monitor='roc_auc_score'
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_patient_auc',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

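# Note: aggregation level and save monitor are independent settings; both
#       appear in this cell only to keep the tutorial short. `aggregation_level`
#       is left commented out, so this run still aggregates by slide (the default).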

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.339066,0.397487,1.0,00:00
1,0.320226,0.398026,1.0,00:00


Better model found at epoch 0 with roc_auc_score value: 1.0.


<fastai.learner.Learner at 0x7f2d0ac69850>
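
The choice of monitor can change which epoch is kept. The sketch below is a simplified model of "keep the best epoch" logic, not the actual training callback; `best_epoch` is hypothetical. The metric values are copied from the second k-fold run earlier in this tutorial, where the validation loss improves in epoch 1 while the AUC drops.

```python
# Validation metrics per epoch, copied from the second k-fold run above.
history = [
    {'epoch': 0, 'valid_loss': 0.351969, 'roc_auc_score': 0.666667},
    {'epoch': 1, 'valid_loss': 0.343922, 'roc_auc_score': 0.333333},
]

def best_epoch(history, monitor):
    """Return the epoch whose monitored value is best.

    Losses are minimized; every other metric is maximized (a simplification).
    Only strict improvements replace the current best, so ties keep the
    earlier epoch.
    """
    lower_is_better = 'loss' in monitor
    best = None
    for row in history:
        value = row[monitor]
        if best is None:
            best = row
        elif lower_is_better and value < best[monitor]:
            best = row
        elif not lower_is_better and value > best[monitor]:
            best = row
    return best['epoch']

print(best_epoch(history, 'valid_loss'))     # 1
print(best_epoch(history, 'roc_auc_score'))  # 0
```

Keeping the earlier epoch on ties matches the log above, where both epochs reach an AUC of 1.0 but only epoch 0 is reported as a better model.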

## Other Configurable Parameters

Aside from `epochs`, `bag_size`, and `batch_size`, which we set to small values so that
the cells run quickly on a small test dataset,
MIL models have many other parameters that can be adjusted to optimize performance.
Here is an example that also sets the learning rate (`lr`):

In [10]:
config = mil_config(
    model,
    lr=1e-3,
    bag_size=4,
    epochs=3,
    batch_size=2
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_custom',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.467339,0.385258,0.666667,00:00
1,0.36686,0.383177,1.0,00:00
2,0.285895,0.382503,0.666667,00:00


Better model found at epoch 0 with valid_loss value: 0.38525834679603577.
Better model found at epoch 1 with valid_loss value: 0.3831765651702881.
Better model found at epoch 2 with valid_loss value: 0.38250330090522766.


<fastai.learner.Learner at 0x7f2ceaee3820>
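
Of the parameters used throughout this tutorial, `bag_size` is perhaps the least self-explanatory. A common way to batch bags of different lengths is to bring each bag to a fixed number of instances by random subsampling, with zero-padding when a bag is too short. The sketch below shows that idea under the assumption that `bag_size` plays this role; how Slideflow samples or pads bags internally is not stated in this tutorial, and `to_fixed_size_bag` is a hypothetical helper.

```python
import random

def to_fixed_size_bag(bag, bag_size, feature_dim, seed=None):
    """Randomly subsample a bag to bag_size instances, zero-padding if short.

    Returns (fixed_bag, n_real), where n_real counts the non-padding rows.
    """
    rng = random.Random(seed)
    chosen = rng.sample(bag, min(bag_size, len(bag)))
    padding = [[0.0] * feature_dim for _ in range(bag_size - len(chosen))]
    return chosen + padding, len(chosen)

# Hypothetical bag with six tiles and 3-D features.
bag = [[float(i)] * 3 for i in range(6)]
small, n_small = to_fixed_size_bag(bag, bag_size=4, feature_dim=3, seed=0)
large, n_large = to_fixed_size_bag(bag, bag_size=8, feature_dim=3, seed=0)
print(len(small), n_small)  # 4 4
print(len(large), n_large)  # 8 6
```

Under this assumption, a very small `bag_size` such as the value of 4 used here means the model sees only a handful of tiles per slide in each step, which is one reason these settings are unsuitable for real experiments.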

## Clean up
Remove all folders in `/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil` except
`00000-ctranspath_attention_mil`, which the evaluation step above depends on.

In [12]:
import shutil

# Path to the folder that holds the trained MIL models
folder_path = '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil'
keep = '00000-ctranspath_attention_mil'

# List all directories in the folder
directories = [
    d for d in os.listdir(folder_path)
    if os.path.isdir(os.path.join(folder_path, d))
]
print(directories)

# Remove every directory except the one to keep
for directory in directories:
    if directory != keep:
        shutil.rmtree(os.path.join(folder_path, directory))

# Print what remains in the folder
print(os.listdir(folder_path))

['00022-ctranspath_attention_mil_custom', '00020-ctranspath_attention_mil', '00019-ctranspath_transmil', '00016-ctranspath_attention_mil_fold0', '00017-ctranspath_attention_mil_fold1', '00000-ctranspath_attention_mil', '00018-ctranspath_attention_mil', '00021-ctranspath_attention_mil_patient_auc', '00023-ctranspath_attention_mil_custom']
['00000-ctranspath_attention_mil']
