# Quick Start: Train MintFlow on Multiple Tissue Sections

Tutorial for basic training on multiple tissue sections

- **Creator**: Amir Akbarnejad (aa36@sanger.ac.uk)
- **Affiliation**: Wellcome Sanger Institute and University of Cambridge
- **Date of Creation**: 01.07.2025
- **Date of Last Modification**: 01.07.2025

**To run the notebook, you only need to modify the parts marked with `TODO:MODIFY:`. The rest can be left untouched as long as the goal is just to run the notebook.**

This notebook demonstrates how to train MintFlow on multiple tissue sections. It is intended only as a demonstration: to get biologically meaningful results you may need longer training and/or different hyperparameter settings.

In [None]:
import os, sys
import yaml
import mintflow
import pickle
from tqdm.autonotebook import tqdm


import scanpy as sc
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

## 1. Overview
To train on multiple tissue sections, MintFlow expects each tissue section to be stored on disk as a separate AnnData object in its own `.h5ad` file.
For example, if you want to train on 10 tissue sections, you need 10 `.h5ad` files on disk, each satisfying the following requirements:
1. The `adata.X` field of each `.h5ad` file must contain raw counts, without any normalisation (e.g. row-sum or log1p normalisation).
2. The `.obs` field of each `.h5ad` file must contain a column holding the unique identifier of that tissue section. MintFlow uses this column to assign IDs to tissue sections.
    - The name of that column can vary from one tissue section to another.
3. The `.obs` field of each `.h5ad` file must contain a column holding the batch identifier (e.g. a biological or technological batch identifier) of that tissue section. For example, if the first 3 tissue sections come from the same biological batch, you may want to assign them the same batch identifier.
    - The name of that column can vary from one tissue section to another.
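Before handing files to MintFlow, it can help to check the three requirements above yourself. The sketch below is a hypothetical helper, not part of the MintFlow API: it works on plain Python containers, treats "raw counts" as "non-negative whole numbers" (a heuristic, not a guarantee), and also checks that each file holds a single section ID.

```python
# Hypothetical pre-flight check (not part of the MintFlow API). It works on
# plain Python containers, so it can be tried without an AnnData object.
def check_section(X_rows, obs_columns, sliceid_key, batch_key):
    """Return a list of problems; an empty list means all checks passed.

    X_rows      -- expression values as an iterable of rows (cells x genes)
    obs_columns -- mapping from .obs column name to that column's values
    """
    problems = []
    # Requirement 1: raw counts are non-negative whole numbers, whereas
    # log1p-normalised values (e.g. 0.693 for a count of 1) are not.
    for row in X_rows:
        if any(v < 0 or float(v) != int(v) for v in row):
            problems.append("X does not look like raw counts")
            break
    # Requirements 2 and 3: the section-ID and batch columns must be present.
    for key in (sliceid_key, batch_key):
        if key not in obs_columns:
            problems.append(f"missing .obs column: {key!r}")
    # Each file should hold a single tissue section, i.e. a single section ID.
    if sliceid_key in obs_columns and len(set(obs_columns[sliceid_key])) != 1:
        problems.append(f"{sliceid_key!r} holds more than one section ID")
    return problems


print(check_section(
    [[0, 3, 1], [2, 0, 5]],
    {"TissueSectionID_for_MintFlow": [1, 1], "batchID_for_MintFlow": [1, 1]},
    "TissueSectionID_for_MintFlow",
    "batchID_for_MintFlow",
))  # → []
```

With a real AnnData object one could pass a dense subset of rows (e.g. `adata.X[:1000].toarray()` when `adata.X` is a SciPy sparse matrix) and `adata.obs`, since `key in adata.obs` checks the column names of a pandas DataFrame.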
 

## 2. Download the raw anndata objects
Download the 5 sample `.h5ad` files from Google Drive [(link to the folder on Google Drive)](https://drive.google.com/drive/folders/1lqbvUkHj5dan0o4YM9uTS5xdjQzPup_i?usp=sharing)
and place them in a directory of your choice. Then set the variable `path_anndata` below to the path of that directory.

In [None]:
path_anndata = './NonGit/RawData_Tutorial_MultipleTissueSections/'  
# TODO:MODIFY: set to the path where you've put the `.h5ad` file that you downloaded.
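A quick way to catch a wrong `path_anndata` before the processing loop fails halfway through is to check that all five raw files are present. The helper below is a hypothetical sketch; the file-name pattern `tissue_section_{}.h5ad` is the one the processing loop in Section 3 reads.

```python
import os

def missing_sections(path_anndata, num_sections=5, pattern="tissue_section_{}.h5ad"):
    """Return the expected raw .h5ad file names that are not in path_anndata."""
    expected = [pattern.format(i) for i in range(1, num_sections + 1)]
    return [name for name in expected
            if not os.path.isfile(os.path.join(path_anndata, name))]
```

Calling `missing_sections(path_anndata)` after the cell above should return an empty list; any names it returns are files that still need to be downloaded or moved.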

## 3. Process the anndata files
Here we read the raw AnnData objects, modify them to fulfil requirements 2 and 3 explained in the "Overview" section, and store them back on disk so that MintFlow can read and use them.

In [None]:
for index_tissue_section in tqdm(range(1, 6), desc='Processing the .h5ad files'):
    adata = sc.read_h5ad(
        os.path.join(path_anndata, 'tissue_section_{}.h5ad'.format(index_tissue_section))
    )
    adata.obs['TissueSectionID_for_MintFlow'] = index_tissue_section  # the unique ID assigned to this tissue section
    adata.obs['batchID_for_MintFlow'] = index_tissue_section  # i.e. each tissue section is assumed to come from a different batch

    # convert the added columns to type "category" (to avoid the slowdown caused by having integer columns in pandas dataframes)
    adata.obs['TissueSectionID_for_MintFlow'] = adata.obs['TissueSectionID_for_MintFlow'].astype("category")
    adata.obs['batchID_for_MintFlow'] = adata.obs['batchID_for_MintFlow'].astype("category")

    # save the processed anndata object to disk, in the same directory as the raw files
    adata.write_h5ad(
        os.path.join(path_anndata, 'tissue_section_forMintFlow_{}.h5ad'.format(index_tissue_section))
    )
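The loop above gives every tissue section its own batch. If several sections share a batch (as in the Overview example where the first 3 sections come from the same biological batch), a lookup table can be used instead. The grouping below is purely hypothetical; replace it with your own experimental design.

```python
# Hypothetical grouping: sections 1-3 share one biological batch and
# sections 4-5 share another. Replace with your own experimental design.
section_to_batch = {1: "batch_A", 2: "batch_A", 3: "batch_A",
                    4: "batch_B", 5: "batch_B"}

def batch_id_for(index_tissue_section, mapping=section_to_batch):
    """Return the batch ID of a tissue section, failing loudly if it is unmapped."""
    if index_tissue_section not in mapping:
        raise KeyError(f"no batch assigned to tissue section {index_tissue_section}")
    return mapping[index_tissue_section]

print(batch_id_for(2), batch_id_for(5))  # → batch_A batch_B
```

Inside the loop, the batch line would then become `adata.obs['batchID_for_MintFlow'] = batch_id_for(index_tissue_section)`, while `TissueSectionID_for_MintFlow` stays unique per section. Raising on an unmapped section avoids silently writing a missing batch label.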

## 4. Create the four configuration objects

Having prepared and stored the 5 anndata objects on disk, we now have to create 4 configuration objects:
1. `config_data_train` to configure the training data
2. `config_data_evaluation` to configure the evaluation data
3. `config_model` to configure the MintFlow model
4. `config_training` to configure the training

### 4.1. Load the default configuration objects
Instead of creating the configuration objects from scratch, we load the default configuration objects and modify them partially.

In [None]:
config_data_train, config_data_evaluation, config_model, config_training = mintflow.get_default_configurations(
    num_tissue_sections_training=5,
    num_tissue_sections_evaluation=5
)
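To see which fields the default configuration objects contain before modifying them, a small recursive printer can help. This is a hypothetical helper that assumes the configurations behave like nested Python dicts, which is what the string-key indexing used throughout this notebook suggests; lists or other containers are printed as leaf values.

```python
from collections.abc import Mapping

def config_tree_lines(config, indent=0):
    """Return one line per key of a nested mapping, indented by depth."""
    lines = []
    for key, value in config.items():
        prefix = "  " * indent + f"{key}:"
        if isinstance(value, Mapping):
            lines.append(prefix)
            lines.extend(config_tree_lines(value, indent + 1))
        else:
            lines.append(f"{prefix} {value!r}")
    return lines

# Stand-in dict with the same shape as the keys used in this notebook.
example = {"list_tissue": {"anndata1": {"obskey_x": "x_centroid",
                                        "config_dataloader_train": {"width_window": 200}}}}
print("\n".join(config_tree_lines(example)))
```

On the real objects, `print("\n".join(config_tree_lines(config_data_train)))` would list every field available for customisation.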

### 4.2. Customise `config_data_train`
Since we have 5 tissue sections, we have to configure `config_data_train` for each tissue section separately (i.e. `config_data_train['list_tissue']['anndata1']`, `config_data_train['list_tissue']['anndata2']`, ..., and `config_data_train['list_tissue']['anndata5']`).
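Because the five sections differ only in their index, the per-section cells in 4.2.1 to 4.2.5 could alternatively be written as one loop. The sketch below is a hypothetical helper (not a MintFlow function); it assumes the default configuration is a nested dict with the keys used in the cells of this section, and it copies the tutorial's values.

```python
import os

# Values shared by all five sections in this tutorial (copied from the cells below).
SHARED_OBSKEYS = {
    'obskey_cell_type': 'broad_celltypes',
    'obskey_sliceid_to_checkUnique': 'TissueSectionID_for_MintFlow',
    'obskey_x': 'x_centroid',
    'obskey_y': 'y_centroid',
    'obskey_biological_batch_key': 'batchID_for_MintFlow',
}

def configure_sections(config_data, path_anndata, num_sections, dataloader_key, width_window=200):
    """Fill config_data['list_tissue']['anndata1' ... 'anndataN'] in place and return it."""
    for i in range(1, num_sections + 1):
        section = config_data['list_tissue'][f'anndata{i}']
        section['file'] = os.path.join(path_anndata, f'tissue_section_forMintFlow_{i}.h5ad')
        section.update(SHARED_OBSKEYS)
        section[dataloader_key]['width_window'] = width_window
        # Build a new dict per section so that later edits to one section's
        # graph settings do not leak into the others.
        section['config_neighbourhood_graph'] = {
            'n_neighs': 10,
            'set_diag': 'False',
            'delaunay': 'False',
        }
    return config_data

# Stand-in with the same shape as the default config, for illustration only.
demo = {'list_tissue': {f'anndata{i}': {'config_dataloader_train': {}} for i in range(1, 6)}}
configure_sections(demo, './data', 5, 'config_dataloader_train')
print(demo['list_tissue']['anndata3']['file'])
```

Under these assumptions, `configure_sections(config_data_train, path_anndata, 5, 'config_dataloader_train')` reproduces the training cells, and `configure_sections(config_data_evaluation, path_anndata, 5, 'config_dataloader_test')` the evaluation cells. The explicit per-section cells remain useful when sections need different settings.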

#### 4.2.1. Customise `config_data_train` for tissue section 1
For more info about each field, please refer to the comments next to each configuration. As mentioned above, as long as the goal is just to run the notebook, you can leave the configurations untouched.

In [None]:
config_data_train['list_tissue']['anndata1']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_1.h5ad')
#   the path to the anndata object of tissue section 1 on disk.


config_data_train['list_tissue']['anndata1']['obskey_cell_type'] = 'broad_celltypes'
#   meaning that for the 1st tissue section, cell type labels are provided in `broad_celltypes` column of `adata.obs`.


config_data_train['list_tissue']['anndata1']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
#   meaning that for the 1st tissue section, tissue section ID (i.e. slice ID) is provided in `TissueSectionID_for_MintFlow` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['obskey_x'] = 'x_centroid'
#   meaning that for the 1st tissue section, spatial x coordinates are provided in `x_centroid` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['obskey_y'] = 'y_centroid'
#   meaning that for the 1st tissue section, spatial y coordinates are provided in `y_centroid` column of `adata.obs`



config_data_train['list_tissue']['anndata1']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
#   meaning that for the 1st tissue section, batch identifier is provided in `batchID_for_MintFlow` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['config_dataloader_train']['width_window'] = 200
#   For tissue section 1, the window size of the tissue crops used by the customised dataloader
#      (the dataloader described in Supplementary Fig. 16 of the paper).
#   The larger this number, the larger the tissue crops and the bigger the subset of cells in each training iteration.
#      This implies that more GPU memory is required during training.
#   After calling `mintflow.setup_data` in Sec. 6 of this notebook, the crop(s) are shown on the tissue,
#      with some information in the image titles that can help you tune this parameter.
#   In the manuscript we used `width_window` values between 300 and 800, depending on the dataset.


config_data_train['list_tissue']['anndata1']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}
#   The parameters for creating the neighbourhood graph for training tissue section 1
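To get a feel for `width_window` before running `mintflow.setup_data`, one can estimate how many cells a crop would contain. The sketch below is a rough back-of-the-envelope estimate, not MintFlow's own logic: it assumes square crops of side `width_window`, measured in the same units as `x_centroid`/`y_centroid`, and cells spread uniformly over the bounding box of the section.

```python
def mean_cells_per_window(xs, ys, width_window):
    """Rough expected number of cells in one square crop of side width_window.

    Assumes cells are spread uniformly over the bounding box of (xs, ys) and
    that width_window is in the same units as the coordinates.
    """
    if len(xs) == 0 or len(xs) != len(ys):
        raise ValueError("xs and ys must be non-empty and of equal length")
    area = (max(xs) - min(xs)) * (max(ys) - min(ys))
    if area <= 0:
        raise ValueError("the coordinates span zero area")
    density = len(xs) / area  # cells per unit area
    return density * width_window ** 2

# Toy example: a regular 101 x 101 grid of cells, 10 units apart (1000 x 1000 box).
xs = [10 * i for i in range(101) for _ in range(101)]
ys = [10 * j for _ in range(101) for j in range(101)]
print(round(mean_cells_per_window(xs, ys, 200)))  # → 408
```

Under the uniform-density assumption the cell count grows with the square of `width_window`, so doubling it roughly quadruples the number of cells per crop. Real tissues are not uniform, so dense regions will yield larger crops than this estimate. On real data one could call `mean_cells_per_window(adata.obs['x_centroid'].to_numpy(), adata.obs['y_centroid'].to_numpy(), 200)`.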

#### 4.2.2. Customise `config_data_train` for tissue section 2
We've omitted the comments since the configuration mirrors that of tissue section 1. Note that in the cell below:
- `config_data_train['list_tissue']['anndata1']` is changed to `config_data_train['list_tissue']['anndata2']`
- `tissue_section_forMintFlow_1.h5ad` is changed to `tissue_section_forMintFlow_2.h5ad`

In [None]:
config_data_train['list_tissue']['anndata2']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_2.h5ad')
config_data_train['list_tissue']['anndata2']['obskey_cell_type'] = 'broad_celltypes'
config_data_train['list_tissue']['anndata2']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_train['list_tissue']['anndata2']['obskey_x'] = 'x_centroid'
config_data_train['list_tissue']['anndata2']['obskey_y'] = 'y_centroid'
config_data_train['list_tissue']['anndata2']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_train['list_tissue']['anndata2']['config_dataloader_train']['width_window'] = 200
config_data_train['list_tissue']['anndata2']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.2.3. Customise `config_data_train` for tissue section 3

In [None]:
config_data_train['list_tissue']['anndata3']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_3.h5ad')
config_data_train['list_tissue']['anndata3']['obskey_cell_type'] = 'broad_celltypes'
config_data_train['list_tissue']['anndata3']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_train['list_tissue']['anndata3']['obskey_x'] = 'x_centroid'
config_data_train['list_tissue']['anndata3']['obskey_y'] = 'y_centroid'
config_data_train['list_tissue']['anndata3']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_train['list_tissue']['anndata3']['config_dataloader_train']['width_window'] = 200
config_data_train['list_tissue']['anndata3']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.2.4. Customise `config_data_train` for tissue section 4

In [None]:
config_data_train['list_tissue']['anndata4']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_4.h5ad')
config_data_train['list_tissue']['anndata4']['obskey_cell_type'] = 'broad_celltypes'
config_data_train['list_tissue']['anndata4']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_train['list_tissue']['anndata4']['obskey_x'] = 'x_centroid'
config_data_train['list_tissue']['anndata4']['obskey_y'] = 'y_centroid'
config_data_train['list_tissue']['anndata4']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_train['list_tissue']['anndata4']['config_dataloader_train']['width_window'] = 200
config_data_train['list_tissue']['anndata4']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.2.5. Customise `config_data_train` for tissue section 5

In [None]:
config_data_train['list_tissue']['anndata5']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_5.h5ad')
config_data_train['list_tissue']['anndata5']['obskey_cell_type'] = 'broad_celltypes'
config_data_train['list_tissue']['anndata5']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_train['list_tissue']['anndata5']['obskey_x'] = 'x_centroid'
config_data_train['list_tissue']['anndata5']['obskey_y'] = 'y_centroid'
config_data_train['list_tissue']['anndata5']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_train['list_tissue']['anndata5']['config_dataloader_train']['width_window'] = 200
config_data_train['list_tissue']['anndata5']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

### 4.3. Customise `config_data_evaluation`
The set of tissue sections used for evaluation can be the same as the one used for training, in which case the same values can be reused, as we do below.
We recommend that all training tissue sections listed in `config_data_train` are also included in `config_data_evaluation`, to enable evaluation on training tissue sections.

Similar to `config_data_train`, we need to configure `config_data_evaluation` 5 times, once for each tissue section.
Note that in the cells below we have `config_dataloader_test` instead of `config_dataloader_train`.
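To act on the recommendation that every training section also appears in the evaluation set, the file paths of the two configurations can be compared. The helper below is a hypothetical sketch; it assumes both configurations have the `['list_tissue'][...]['file']` layout used in this notebook, so it is best run before Section 5, whose verification functions may restructure the objects.

```python
import os

def sections_missing_from_evaluation(config_data_train, config_data_evaluation):
    """Return training .h5ad paths that are not listed in the evaluation config."""
    def files(config):
        # abspath also normalises the path, so './data/a.h5ad' == 'data/a.h5ad'
        return {os.path.abspath(entry['file']) for entry in config['list_tissue'].values()}
    return sorted(files(config_data_train) - files(config_data_evaluation))

# Stand-ins with the same shape as the configs edited in this notebook.
train = {'list_tissue': {'anndata1': {'file': 'data/a.h5ad'},
                         'anndata2': {'file': 'data/b.h5ad'}}}
evaluation = {'list_tissue': {'anndata1': {'file': './data/a.h5ad'}}}
print([os.path.basename(p) for p in sections_missing_from_evaluation(train, evaluation)])  # → ['b.h5ad']
```

After the cells below have run, `sections_missing_from_evaluation(config_data_train, config_data_evaluation)` should return an empty list.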



#### 4.3.1. Customise `config_data_evaluation` for tissue section 1

In [None]:
config_data_evaluation['list_tissue']['anndata1']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_1.h5ad')
config_data_evaluation['list_tissue']['anndata1']['obskey_cell_type'] = 'broad_celltypes'
config_data_evaluation['list_tissue']['anndata1']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata1']['obskey_x'] = 'x_centroid'
config_data_evaluation['list_tissue']['anndata1']['obskey_y'] = 'y_centroid'
config_data_evaluation['list_tissue']['anndata1']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata1']['config_dataloader_test']['width_window'] = 200
config_data_evaluation['list_tissue']['anndata1']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.3.2. Customise `config_data_evaluation` for tissue section 2

In [None]:
config_data_evaluation['list_tissue']['anndata2']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_2.h5ad')
config_data_evaluation['list_tissue']['anndata2']['obskey_cell_type'] = 'broad_celltypes'
config_data_evaluation['list_tissue']['anndata2']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata2']['obskey_x'] = 'x_centroid'
config_data_evaluation['list_tissue']['anndata2']['obskey_y'] = 'y_centroid'
config_data_evaluation['list_tissue']['anndata2']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata2']['config_dataloader_test']['width_window'] = 200
config_data_evaluation['list_tissue']['anndata2']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.3.3. Customise `config_data_evaluation` for tissue section 3

In [None]:
config_data_evaluation['list_tissue']['anndata3']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_3.h5ad')
config_data_evaluation['list_tissue']['anndata3']['obskey_cell_type'] = 'broad_celltypes'
config_data_evaluation['list_tissue']['anndata3']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata3']['obskey_x'] = 'x_centroid'
config_data_evaluation['list_tissue']['anndata3']['obskey_y'] = 'y_centroid'
config_data_evaluation['list_tissue']['anndata3']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata3']['config_dataloader_test']['width_window'] = 200
config_data_evaluation['list_tissue']['anndata3']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.3.4. Customise `config_data_evaluation` for tissue section 4

In [None]:
config_data_evaluation['list_tissue']['anndata4']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_4.h5ad')
config_data_evaluation['list_tissue']['anndata4']['obskey_cell_type'] = 'broad_celltypes'
config_data_evaluation['list_tissue']['anndata4']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata4']['obskey_x'] = 'x_centroid'
config_data_evaluation['list_tissue']['anndata4']['obskey_y'] = 'y_centroid'
config_data_evaluation['list_tissue']['anndata4']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata4']['config_dataloader_test']['width_window'] = 200
config_data_evaluation['list_tissue']['anndata4']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

#### 4.3.5. Customise `config_data_evaluation` for tissue section 5

In [None]:
config_data_evaluation['list_tissue']['anndata5']['file'] = os.path.join(path_anndata, 'tissue_section_forMintFlow_5.h5ad')
config_data_evaluation['list_tissue']['anndata5']['obskey_cell_type'] = 'broad_celltypes'
config_data_evaluation['list_tissue']['anndata5']['obskey_sliceid_to_checkUnique'] = 'TissueSectionID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata5']['obskey_x'] = 'x_centroid'
config_data_evaluation['list_tissue']['anndata5']['obskey_y'] = 'y_centroid'
config_data_evaluation['list_tissue']['anndata5']['obskey_biological_batch_key'] = 'batchID_for_MintFlow'
config_data_evaluation['list_tissue']['anndata5']['config_dataloader_test']['width_window'] = 200
config_data_evaluation['list_tissue']['anndata5']['config_neighbourhood_graph'] = {
    'n_neighs': 10,
    'set_diag': 'False',
    'delaunay': 'False',
}

### 4.4. Customise `config_model`
`config_model` contains two important parameters for removing batch effects:
- `config_model['coef_xbarint2notbatchID_loss']`: determines to what degree batch mixing (i.e. batch integration) is encouraged in the intrinsic component of expression (its embedded representation `Xbar_int`).
- `config_model['coef_xbarspl2notbatchID_loss']`: determines to what degree batch mixing (i.e. batch integration) is encouraged in the micro-environment component of expression (its embedded representation `Xbar_mic`).

For example, if you know that in your dataset the micro-environment effect is batch- or sample-dependent, you can set `config_model['coef_xbarspl2notbatchID_loss']` to a small number, telling the model that batch mixing/integration should not be strictly enforced for `Xbar_mic`.


In [None]:
config_model['coef_xbarint2notbatchID_loss'] = 1.0
config_model['coef_xbarspl2notbatchID_loss'] = 1.0

### 4.5. Customise `config_training`

A note about wandb: before proceeding, it is highly recommended (though optional) to set up wandb to track and log different values during training.
- To enable wandb: go to [wandb.ai](https://wandb.ai/) and create an account.
- To disable wandb: set `config_training['flag_enable_wandb']` in the cell below to `'False'`.
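If the same notebook is run on machines with and without wandb, the flag can be derived from whether the package is installed. This is a hypothetical convenience helper; note that an installed package does not guarantee you are logged in (that is done separately, e.g. with `wandb login`).

```python
import importlib.util

def wandb_flag():
    """Return 'True' if the wandb package can be imported, otherwise 'False'.

    The string values mirror the 'True'/'False' strings used for the
    config_training flags in this notebook.
    """
    return 'True' if importlib.util.find_spec('wandb') is not None else 'False'

print(wandb_flag())
```

In the cell below, `config_training['flag_enable_wandb'] = wandb_flag()` could then replace the hard-coded `'True'`. Using `find_spec` checks availability without actually importing wandb.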

In [None]:
config_training['num_training_epochs'] = 20
# number of training epochs, i.e. the number of times the model sees the dataset during training.

config_training['flag_use_GPU'] = 'True'
# whether GPU is used.

config_training['flag_enable_wandb'] = 'True'
# if set to True, during training different loss terms are logged to wandb.
# It's highly recommended to enable wandb. Please refer to the wandb website for more info: `wandb.ai`


config_training['wandb_project_name'] = 'MintFlow'
# wandb project name (ignored if `config_training['flag_enable_wandb']` is set to False)

config_training['wandb_run_name'] = 'Mintflow_Tutorial_Notebook3'
# wandb run name (ignored if `config_training['flag_enable_wandb']` is set to False)

## 5. Verify and post-process the four configurations
In this section we verify and postprocess the four configurations.

In [None]:
config_data_train = mintflow.verify_and_postprocess_config_data_train(config_data_train) 

In [None]:
config_data_evaluation = mintflow.verify_and_postprocess_config_data_evaluation(config_data_evaluation)

In [None]:
config_model = mintflow.verify_and_postprocess_config_model(config_model, num_tissue_sections=len(config_data_train))  

In [None]:
config_training = mintflow.verify_and_postprocess_config_training(config_training) 

In [None]:
print("Finished verifying the 4 configuration objects.")

## 6. Setup the Data/Model/Trainer
Having created and verified the 4 configurations, in this section we create the variables `data_mintflow`, `model`, and `trainer`.



In [None]:
dict_all4_configs = {
    'config_data_train':config_data_train,
    'config_data_evaluation':config_data_evaluation,
    'config_model':config_model,
    'config_training':config_training
}

In [None]:
data_mintflow = mintflow.setup_data(dict_all4_configs=dict_all4_configs)

In [None]:
model = mintflow.setup_model(
    dict_all4_configs=dict_all4_configs,
    data_mintflow=data_mintflow
)

In [None]:
trainer = mintflow.Trainer(
    dict_all4_configs=dict_all4_configs,
    model=model,
    data_mintflow=data_mintflow
)

## 7. Train the Model
Set the variable `path_ouptput_files` below to the path where you want the training files (checkpoints, etc.) to be saved.

In [None]:
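path_ouptput_files = "./NonGit/Outputs_TutorialNotebook3"
# TODO:MODIFY: the path where checkpoints and other files are saved during training.

os.makedirs(path_ouptput_files, exist_ok=True)  # create the output directory if it does not exist yet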

In [None]:
for index_epoch in tqdm(range(config_training['num_training_epochs']), desc='Training epoch'):
    # IMPORTANT NOTE: To change the number of epochs, set `config_training['num_training_epochs']` in the
    # earlier cells of this notebook, and do not change this loop to, e.g., `for index_epoch in tqdm(range(10), ...)`,
    # because MintFlow's annealing module assumes that the number of epochs equals `config_training['num_training_epochs']`.
    
    # train for one epoch
    trainer.train_one_epoch()

    # get/save the predictions
    predictions = mintflow.predict(
        device=device,
        dict_all4_configs=dict_all4_configs,
        data_mintflow=data_mintflow,
        model=model,
        evalulate_on_sections="all",
    )
    with open(os.path.join(path_ouptput_files, "predictions_epoch_{}.pkl".format(index_epoch)), 'wb') as f:
        pickle.dump(
            predictions,
            f
        )

    # evaluate the model and save the evaluation result for the checkpoint
    df_evaluation_result = mintflow.evaluate_by_known_signalling_genes(
        device=device,
        dict_all4_configs=dict_all4_configs,
        data_mintflow=data_mintflow,
        model=model,
        evalulate_on_sections='all',
        optional_list_colvaltype_toadd=[['training_epoch', index_epoch, 'category']]
    )
    df_evaluation_result.to_pickle(
        os.path.join(
            path_ouptput_files,
            'df_evaluation_result_epoch_{}.pkl'.format(index_epoch)
        )
    )

    # save the checkpoint
    mintflow.dump_checkpoint(
        model=model,
        data_mintflow=data_mintflow,
        dict_all4_configs=dict_all4_configs,
        path_dump=os.path.join(path_ouptput_files, "checkpoint_epoch_{}.pt".format(index_epoch)),
    )    

## 8. Select the best checkpoint and perform the analysis
This part is identical to the tutorial for training on a single tissue section.
Please refer to the tutorial notebook titled "Quick Start: Train MintFlow on a Single Tissue Section", section 8 onwards in that notebook. 
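When comparing epochs, note that the training loop above writes one file per epoch (`df_evaluation_result_epoch_{}.pkl`, `predictions_epoch_{}.pkl`, `checkpoint_epoch_{}.pt`), and plain alphabetical sorting would put `epoch_10` before `epoch_2`. The hypothetical helper below lists such files ordered by their numeric epoch.

```python
import os
import re

def files_by_epoch(path_output, prefix='df_evaluation_result_epoch_', suffix='.pkl'):
    """Return (epoch, path) pairs sorted by epoch number, not alphabetically."""
    pattern = re.compile(re.escape(prefix) + r'(\d+)' + re.escape(suffix) + r'$')
    found = []
    for name in os.listdir(path_output):
        match = pattern.match(name)
        if match:
            found.append((int(match.group(1)), os.path.join(path_output, name)))
    return sorted(found)
```

For example, `pd.concat([pd.read_pickle(p) for _, p in files_by_epoch(path_ouptput_files)], ignore_index=True)` would gather all evaluation results into one DataFrame, where the `training_epoch` column added in the loop identifies each epoch. Passing `prefix='checkpoint_epoch_', suffix='.pt'` lists the checkpoints instead.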