# MIL (Multiple Instance Learning) Tutorial

This tutorial demonstrates how to use Multiple Instance Learning (MIL) models with Slideflow for digital pathology tasks. 
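In MIL, each whole-slide image is treated as a *bag* of tile-level *instances*, and only the bag (slide or patient) carries a label. The sketch below illustrates the classic "standard MIL assumption", under which a bag is positive if at least one of its instances is positive. The tile scores and the 0.5 threshold are made-up illustrative values. The models trained later in this tutorial learn their own aggregation instead of applying a fixed rule like this one.

```python
from typing import Sequence


def bag_label(instance_scores: Sequence[float], threshold: float = 0.5) -> int:
    """Standard MIL assumption: a bag is positive (1) if any instance is positive."""
    return int(any(score >= threshold for score in instance_scores))


# Hypothetical tile-level scores for two slides (bags)
slide_a = [0.1, 0.2, 0.05]  # no tile reaches the threshold
slide_b = [0.1, 0.9, 0.3]   # one tile does

print(bag_label(slide_a))  # 0
print(bag_label(slide_b))  # 1
```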

## Setting up the environment

First, let's import the necessary libraries and set up our project.

In [None]:
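# Silence TensorFlow's C++ log messages. This must be set before slideflow
# (and therefore TensorFlow) is imported; valid values are '0' to '3'.
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys
import logging
import slideflow as sf
from slideflow.mil import mil_config

# Show slideflow INFO-level log messages
logging.getLogger('slideflow').setLevel(logging.INFO)

# Route stderr to stdout so log output appears in the notebook cell
sys.stderr = sys.__stdout__

# Check that slideflow was installed properly
sf.about()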

# Initialize the project
project_root = '/mnt/labshare/PROJECTS/TEST_PROJECTS/TEST_PROJECT'
P = sf.Project(root=project_root)

<a id='project'></a>
### Getting Started with a Slideflow Project

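This tutorial assumes that you have already created a project folder. Once the project has been created and you have specified the paths to datasets, annotation files, etc., we begin by initializing a Slideflow `Project` object. The cell below builds the project path from a user-specific scratch directory; change `username` and the other paths to match your setup.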

In [None]:
# Set root paths
username = "skochanny"  # change me
root_path = f'/scratch/{username}/PROJECTS'
labshare_path = '/gpfs/data/pearson-lab/'
project_name = "TEST_PROJECT"

# Reuse the name `project_root`, which later cells use to locate the feature bags
project_root = f"{root_path}/{project_name}"

# Initialize the Project class object
P = sf.Project(project_root)

### Save feature bags as torch tensors

When generating feature bags, make sure to save them as torch tensors (see the `feature_extraction.ipynb` tutorial).

I commented out the code below because the features should already have been saved.

In [12]:
# extractor = 'ctranspath' # change me
# weights_path = f'/mnt/labshare/MODELS/{extractor}/pytorch_model.bin'
# dataset = P.dataset(tile_px=299, tile_um=302) # update to tile size extracted
# extractor_model = sf.model.build_feature_extractor(extractor, center_crop=True)
# features = sf.DatasetFeatures(extractor_model, dataset=dataset, normalizer='reinhard')
# features.to_torch(project_root + f'/features/{extractor}/torch')
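Conceptually, a feature bag for one slide is a matrix with one row per tile and one column per feature dimension. The pure-Python sketch below groups hypothetical (slide, feature vector) pairs into one bag per slide to show that structure. The slide names and 4-dimensional features are invented for illustration; the real bags written by `to_torch()` are torch tensors whose width depends on the feature extractor.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

Bag = List[List[float]]


def build_bags(tiles: List[Tuple[str, List[float]]]) -> Dict[str, Bag]:
    """Group (slide_name, feature_vector) pairs into one bag per slide."""
    bags: Dict[str, Bag] = defaultdict(list)
    for slide, features in tiles:
        bags[slide].append(features)
    return dict(bags)


def bag_shape(bag: Bag) -> Tuple[int, int]:
    """Return (n_tiles, n_features), mirroring the shape of a 2D tensor."""
    n_features = len(bag[0]) if bag else 0
    if any(len(row) != n_features for row in bag):
        raise ValueError("All tiles in a bag must have the same feature length")
    return (len(bag), n_features)


# Hypothetical tile features: 3 tiles from slide_1, 2 tiles from slide_2
tiles = [
    ("slide_1", [0.1, 0.4, 0.0, 0.7]),
    ("slide_1", [0.3, 0.2, 0.5, 0.1]),
    ("slide_2", [0.9, 0.8, 0.1, 0.0]),
    ("slide_1", [0.0, 0.6, 0.2, 0.2]),
    ("slide_2", [0.4, 0.4, 0.4, 0.4]),
]
bags = build_bags(tiles)
for slide, bag in bags.items():
    print(slide, bag_shape(bag))
# slide_1 (3, 4)
# slide_2 (2, 4)
```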

### Splitting Data for Training and Validation

There are different strategies for splitting data between training and validation sets. Here are some common approaches:

In [None]:
# Create dataset
dataset = P.dataset(tile_px=299, tile_um=302)

# Choose ONE of the options below; each redefines the datasets created by the others.

# Option 1: Using filters (assumes your annotations file has a 'dataset' column)
train_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'train'})
val_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'val'})

# Option 2: k-fold cross-validation with kfold_split(), which returns one
# (train, val) pair of datasets per fold; fold assignments are kept in splits.json
splits = dataset.kfold_split(k=2, splits='./splits.json')

# Option 3: Fixed fractions. First hold out 10% of the data as a test set...
train_n_val_dataset, test_dataset = dataset.split(
    model_type='classification', # Categorical labels
    labels='storrc_group',       # Label to balance between datasets
    val_fraction=0.1             # Fraction held out for testing
)

# ...then split the remaining 90% into training and validation sets
train_dataset, val_dataset = train_n_val_dataset.split(
    model_type='classification', # Categorical labels
    labels='storrc_group',       # Label to balance between datasets
    val_fraction=0.3             # Fraction of the remaining data used for validation
)

# k-fold cross-validation training (requires `splits` from Option 2)
extractor = 'ctranspath'
model = 'attention_mil'

# NOTE: bag_size, batch_size, and epochs are set to very small values so the
#       tutorial runs quickly. Do not use these values for real experiments.
config = mil_config(model, bag_size=4, batch_size=2, epochs=2)

for i, (train, val) in enumerate(splits):
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}_fold{i}',
        outcomes='cohort',
        train_dataset=train,  # this fold's training data
        val_dataset=val,      # this fold's validation data
        bags=project_root + f'/features/{extractor}/torch'
    )
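The idea behind balancing a label across splits (what the `labels` argument to `split()` is for) can be sketched in plain Python: group patients by label, shuffle within each group, then deal patients round-robin into k folds so every fold gets a similar label mix. This is a simplified illustration with hypothetical patient IDs and labels, not Slideflow's splitting algorithm. It assigns whole patients to folds so that slides from one patient never appear in both training and validation.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def stratified_kfold(
    patient_labels: Dict[str, str], k: int, seed: int = 0
) -> List[Tuple[List[str], List[str]]]:
    """Assign whole patients to k folds with a similar label mix per fold.

    Returns one (train_patients, val_patients) pair per fold.
    """
    if k < 2:
        raise ValueError("k must be at least 2")

    # Group patients by label (sorted first so the shuffle is reproducible)
    by_label: Dict[str, List[str]] = defaultdict(list)
    for patient, label in sorted(patient_labels.items()):
        by_label[label].append(patient)

    # Shuffle within each label, then deal patients round-robin into folds
    rng = random.Random(seed)
    folds: List[List[str]] = [[] for _ in range(k)]
    position = 0
    for label in sorted(by_label):
        patients = by_label[label]
        rng.shuffle(patients)
        for patient in patients:
            folds[position % k].append(patient)
            position += 1

    # Each fold is the validation set once; the other folds form the training set
    splits = []
    for f in range(k):
        val = sorted(folds[f])
        train = sorted(p for g, fold in enumerate(folds) if g != f for p in fold)
        splits.append((train, val))
    return splits


# Hypothetical patients: 4 with label "A", 2 with label "B"
labels = {"p1": "A", "p2": "A", "p3": "A", "p4": "A", "p5": "B", "p6": "B"}
for i, (train, val) in enumerate(stratified_kfold(labels, k=2)):
    print(f"fold {i}: train={train} val={val}")
```

With k=2, each validation fold ends up with two "A" patients and one "B" patient, matching the 2:1 label ratio of the whole set.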

## Training a MIL Model

Now, let's train a MIL model using the Attention MIL architecture.

In [None]:
extractor = 'ctranspath'
model = 'attention_mil'

config = mil_config(model, epochs=2, bag_size=4, batch_size=2)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

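Attention MIL (Ilse et al., 2018) replaces a fixed pooling rule with learned weights: each tile embedding h_k gets a score w^T tanh(V h_k), the scores are softmax-normalized into attention weights a_k, and the slide embedding is the weighted sum of the tile embeddings. The pure-Python sketch below computes that forward pass with small hand-picked matrices `V` and `w`. These values are hypothetical (a real model learns them), and this is an illustration of the technique rather than Slideflow's implementation.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def softmax(xs: Vector) -> Vector:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(bag: List[Vector], V: Matrix, w: Vector) -> Tuple[Vector, Vector]:
    """Return (attention weights, pooled slide embedding) for one bag."""
    # score_k = w^T tanh(V h_k)
    scores = [sum(wi * math.tanh(x) for wi, x in zip(w, matvec(V, h))) for h in bag]
    weights = softmax(scores)
    # z = sum_k a_k * h_k
    dim = len(bag[0])
    pooled = [sum(a * h[j] for a, h in zip(weights, bag)) for j in range(dim)]
    return weights, pooled


bag = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # 3 tiles x 2 features (hypothetical)
V = [[1.0, -1.0], [0.5, 0.5]]               # hypothetical learned projection
w = [2.0, 1.0]                               # hypothetical learned scoring vector
weights, pooled = attention_pool(bag, V, w)
print([round(a, 3) for a in weights])
print([round(x, 3) for x in pooled])
```

If `V` and `w` are all zeros, every tile gets the same weight and attention pooling reduces to mean pooling, which is one way to see it as a learned generalization of averaging.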

## Evaluating a MIL Model

Once a model is trained, evaluate it on a held-out test set with `P.evaluate_mil()`. The cell below builds the test set with a filter (this assumes your annotations file has a 'dataset' column containing 'test' values) and stores the results in `df`. The model being evaluated must have been trained on bags from the same feature extractor as the `bags` passed here.

In [None]:
test_dataset = P.dataset(tile_px=299, tile_um=302, filters={'dataset': 'test'})

df = P.evaluate_mil(
    'mil/00000-ctranspath_attention_mil', # path to model, or name of model in project/mil; must match `extractor`
    outcomes='cohort',
    dataset=test_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)
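For a binary outcome, the ROC AUC equals the probability that a randomly chosen positive case is scored higher than a randomly chosen negative case, with ties counting as one half. The sketch below computes it directly from that definition using plain lists. The labels and probabilities are made up; to apply it to `df` you would first pull out the true labels and the predicted probability for the positive class (check `df.columns` for the exact column names, which are not shown here).

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC via pairwise comparisons (Mann-Whitney U); labels are 0 or 1."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("Need at least one positive and one negative case")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count half
    return wins / (len(pos) * len(neg))


y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]
print(roc_auc(y_true, y_prob))  # 0.75
```

The pairwise loop is O(n_pos × n_neg), which is fine for a sanity check; for larger evaluations, `sklearn.metrics.roc_auc_score` computes the same quantity efficiently.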

## Different MIL Model Architectures

Slideflow supports various MIL architectures. Here are some examples:

- Attention MIL
- TransMIL
- CLAM-SB

Let's train models using different architectures:

In [None]:
models = ['attention_mil', 'transmil', 'clam_sb']

for model in models:
    config = mil_config(model, epochs=2)
    P.train_mil(
        config=config,
        exp_label=f'{extractor}_{model}',
        outcomes='adsq',
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        bags=project_root + f'/features/{extractor}/torch'
    )
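CLAM-SB (the single-branch variant of CLAM) builds its slide representation on *gated* attention: each tile's score is w^T (tanh(V h) ⊙ sigmoid(U h)), so the sigmoid branch can damp parts of the tanh signal tile by tile. The sketch below shows only that scoring step in pure Python, with hypothetical weights `V`, `U`, and `w`. CLAM's instance-level clustering loss and TransMIL's transformer layers are not shown, and this is not Slideflow's code.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def softmax(xs: Vector) -> Vector:
    """Numerically stable softmax."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gated_attention_scores(bag: List[Vector], V: Matrix, U: Matrix, w: Vector) -> Vector:
    """Unnormalized gated attention scores: w^T (tanh(V h) * sigmoid(U h)) per tile."""
    scores = []
    for h in bag:
        gated = [math.tanh(a) * sigmoid(b) for a, b in zip(matvec(V, h), matvec(U, h))]
        scores.append(sum(wi * g for wi, g in zip(w, gated)))
    return scores


bag = [[1.0, 0.0], [0.0, 1.0]]  # 2 tiles x 2 features (hypothetical)
V = [[1.0, -1.0], [0.5, 0.5]]   # hypothetical tanh-branch weights
U = [[3.0, -3.0], [0.0, 0.0]]   # hypothetical gate weights
w = [2.0, 1.0]                   # hypothetical scoring vector
weights = softmax(gated_attention_scores(bag, V, U, w))
print([round(a, 3) for a in weights])
```

With `U` set to zeros every gate equals sigmoid(0) = 0.5, so each score is exactly half of the ungated attention score; learning `U` lets the model decide per tile and per dimension how much of the tanh signal passes through.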

## Aggregation Level

The `aggregation_level` parameter determines how bags are formed and how outcomes are assigned:

- If `aggregation_level` is 'slide' (the default), each bag contains the tiles from one slide, and there is one outcome per slide.
- If `aggregation_level` is 'patient', each bag contains the tiles from all slides belonging to a patient, and there is a single outcome per patient. This only makes a difference when at least some patients have more than one slide.
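Patient-level aggregation can be pictured as concatenating the tiles of each patient's slides into one bag and collapsing the slide labels into one patient label. The sketch below mimics that grouping with a hypothetical slide-to-patient mapping. It illustrates the idea rather than Slideflow's code, and it assumes all slides of a patient share the same outcome label (it raises an error otherwise).

```python
from collections import defaultdict
from typing import Dict, List

Bag = List[List[float]]


def patient_bags(slide_bags: Dict[str, Bag], slide_to_patient: Dict[str, str]) -> Dict[str, Bag]:
    """Concatenate the tiles of all slides that belong to the same patient."""
    bags: Dict[str, Bag] = defaultdict(list)
    for slide in sorted(slide_bags):
        bags[slide_to_patient[slide]].extend(slide_bags[slide])
    return dict(bags)


def patient_labels(slide_labels: Dict[str, str], slide_to_patient: Dict[str, str]) -> Dict[str, str]:
    """Collapse slide labels to one label per patient, rejecting conflicts."""
    labels: Dict[str, str] = {}
    for slide, label in slide_labels.items():
        patient = slide_to_patient[slide]
        if labels.setdefault(patient, label) != label:
            raise ValueError(f"Slides of {patient} have conflicting labels")
    return labels


# Hypothetical data: patient_A has two slides, patient_B has one
slide_bags = {
    "slide_1": [[0.1, 0.2], [0.3, 0.4]],              # 2 tiles
    "slide_2": [[0.5, 0.6]],                          # 1 tile
    "slide_3": [[0.7, 0.8], [0.9, 1.0], [0.0, 0.1]],  # 3 tiles
}
slide_to_patient = {"slide_1": "patient_A", "slide_2": "patient_A", "slide_3": "patient_B"}

bags = patient_bags(slide_bags, slide_to_patient)
print({patient: len(bag) for patient, bag in bags.items()})  # {'patient_A': 3, 'patient_B': 3}
```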

## Save Monitor

We can also change the `save_monitor` from the default 'loss' to 'roc_auc_score', in which case the model with the best AUC will be saved.

In [None]:
model = 'attention_mil'  # set explicitly; after the loop above, `model` is 'clam_sb'
config = mil_config(
    model,
    epochs=2,
    aggregation_level='patient',
    save_monitor='roc_auc_score'
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_patient_auc',
    outcomes='adsq',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)

# Note: aggregation_level and save_monitor are independent settings; they are
#       combined in a single run only to keep the tutorial short.
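Choosing a `save_monitor` amounts to choosing which epoch's weights to keep: the one with the lowest loss, or the one with the highest AUC. The sketch below applies that rule to a made-up per-epoch history. The metric names mirror the 'loss' and 'roc_auc_score' values mentioned above, but the history format and the `best_epoch` helper are hypothetical, not part of Slideflow.

```python
from typing import Dict, List

# Metrics where larger values are better; anything else (e.g. loss) is minimized
HIGHER_IS_BETTER = {"roc_auc_score"}


def best_epoch(history: List[Dict[str, float]], monitor: str) -> int:
    """Return the index of the epoch a checkpoint monitor would keep (first on ties)."""
    values = [epoch[monitor] for epoch in history]
    pick = max if monitor in HIGHER_IS_BETTER else min
    return values.index(pick(values))


# Hypothetical validation metrics for three epochs
history = [
    {"loss": 0.69, "roc_auc_score": 0.61},
    {"loss": 0.52, "roc_auc_score": 0.74},
    {"loss": 0.48, "roc_auc_score": 0.70},
]
print(best_epoch(history, "loss"))           # 2
print(best_epoch(history, "roc_auc_score"))  # 1
```

As the example shows, the two monitors can keep different epochs, which is why the choice of `save_monitor` can change the saved model.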

## Other Configurable Parameters

Aside from `epochs`, `bag_size`, and `batch_size` (which we have been setting to small values so that the cells run quickly on a small test dataset), MIL models have many other parameters that can be adjusted to optimize performance.
Here's an example of how to modify some of them, including the learning rate (`lr`):

In [None]:
config = mil_config(
    model,
    lr=1e-3,
    bag_size=4,
    epochs=3,
    batch_size=2
)

P.train_mil(
    config=config,
    exp_label=f'{extractor}_{model}_custom',
    outcomes='cohort',
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    bags=project_root + f'/features/{extractor}/torch'
)
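Roughly, `bag_size` controls how many tiles are taken from each slide's bag during training, and `batch_size` is how many bags are grouped into one training step. The sketch below illustrates one common way to do this: random sampling without replacement when a slide has enough tiles, and zero-padding plus a mask when it has too few. The sampling and padding scheme is an assumption made for illustration; Slideflow's own data loading may handle small bags differently.

```python
import random
from typing import List, Sequence, Tuple

Bag = List[List[float]]


def sample_bag(bag: Bag, bag_size: int, rng: random.Random) -> Tuple[Bag, List[bool]]:
    """Return exactly `bag_size` tiles plus a mask (True = real tile, False = padding)."""
    if not bag:
        raise ValueError("Bag has no tiles")
    n = len(bag)
    if n >= bag_size:
        # Enough tiles: randomly sample bag_size distinct tiles
        idx = rng.sample(range(n), bag_size)
        return [bag[i] for i in idx], [True] * bag_size
    # Too few tiles: pad with zero vectors and mark them in the mask
    n_features = len(bag[0])
    padding = [[0.0] * n_features for _ in range(bag_size - n)]
    return [list(row) for row in bag] + padding, [True] * n + [False] * (bag_size - n)


def make_batches(bags: Sequence[Bag], bag_size: int, batch_size: int, seed: int = 0):
    """Sample every bag to bag_size tiles, then group the bags into batches."""
    rng = random.Random(seed)
    sampled = [sample_bag(bag, bag_size, rng) for bag in bags]
    return [sampled[i:i + batch_size] for i in range(0, len(sampled), batch_size)]


# Three hypothetical slides with 10, 2, and 5 tiles of 3 features each
bags = [[[float(t)] * 3 for t in range(n)] for n in (10, 2, 5)]
batches = make_batches(bags, bag_size=4, batch_size=2)
print([len(batch) for batch in batches])  # [2, 1]
print(batches[0][1][1])                   # [True, True, False, False]
```

Because tiles are resampled each time, a small `bag_size` like the 4 used in this tutorial shows the model only a tiny fraction of each slide per step.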

## Clean up
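Remove all folders in `/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil` except
`00000-ctranspath_attention_mil`. `shutil.rmtree` deletes folders permanently, so double-check `folder_path` before running this cell.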

In [12]:
import shutil

# Define the path to the folder and the run to keep
folder_path = '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil'
keep = '00000-ctranspath_attention_mil'

# List all directories in the folder
directories = [d for d in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, d))]
print(directories)

# Remove every directory except the one to keep
for directory in directories:
    if directory != keep:
        shutil.rmtree(os.path.join(folder_path, directory))

# Print what remains in the folder
print(os.listdir(folder_path))

['00022-ctranspath_attention_mil_custom', '00020-ctranspath_attention_mil', '00019-ctranspath_transmil', '00016-ctranspath_attention_mil_fold0', '00017-ctranspath_attention_mil_fold1', '00000-ctranspath_attention_mil', '00018-ctranspath_attention_mil', '00021-ctranspath_attention_mil_patient_auc', '00023-ctranspath_attention_mil_custom']
['00000-ctranspath_attention_mil']
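Because `shutil.rmtree` cannot be undone, a dry-run step is a cheap safeguard. The sketch below wraps the same keep-one-folder logic in a function that only reports what it would delete unless `dry_run=False` is passed. The `clean_mil_dir` name and signature are our own, not part of Slideflow; files directly inside the folder are left alone, as in the cell above.

```python
import shutil
from pathlib import Path
from typing import Iterable, List


def clean_mil_dir(folder: str, keep: Iterable[str], dry_run: bool = True) -> List[str]:
    """Remove subdirectories of `folder` whose names are not in `keep`.

    With dry_run=True (the default) nothing is deleted; the folders that would
    be removed are printed and returned so they can be checked first.
    """
    keep_set = set(keep)
    targets = sorted(p for p in Path(folder).iterdir() if p.is_dir() and p.name not in keep_set)
    for path in targets:
        if dry_run:
            print(f"[dry run] would remove {path}")
        else:
            shutil.rmtree(path)
    return [p.name for p in targets]


# Example usage (commented out so nothing is deleted by accident):
# folder_path = '/mnt/labshare/DL_OTHER/TEST_PROJECTS/TEST_PROJECT/mil'
# clean_mil_dir(folder_path, keep=['00000-ctranspath_attention_mil'])                 # preview
# clean_mil_dir(folder_path, keep=['00000-ctranspath_attention_mil'], dry_run=False)  # delete
```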
