In [None]:
import os
os.chdir("../")

In this tutorial we will run training and validation using a small demo dataset. The dataset was randomly generated and simulates samples (cores) in tumour microarrays (TMAs). The demo should complete within a few minutes.

## Config file

Parameters are defined in a config file (``./configs/config_demo.json`` for this demo). Important parameters include:

- ``data_sources``: locations of data for training and validation.
- ``data_sources.fp_genes``: file path of a .txt file containing the list of genes to input to the model.
- ``data_sources_external``: locations of data from a dataset external to the one used for training and validation.

In this config, non-response (NR) is class 1 and response (R) is class 0.
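
Before training, it can help to check which data locations the config actually points to. The following is a minimal sketch, not part of the repository: `summarise_config` is a hypothetical helper, and it assumes that the config is plain JSON, that `fp_genes` is nested inside `data_sources` (as the dotted name above suggests), and that `total_epochs` is a top-level key.

```python
import json
import os


def summarise_config(config_path):
    """Return the entries of a config file that this tutorial refers to.

    Hypothetical helper: the key layout is an assumption based on the
    parameter names listed above, not on the repository's own loader.
    """
    with open(config_path) as f:
        config = json.load(f)

    data_sources = config.get("data_sources")
    # Only look for the nested key if data_sources is a JSON object.
    fp_genes = data_sources.get("fp_genes") if isinstance(data_sources, dict) else None

    # Keys that are absent from the file are reported as None.
    return {
        "data_sources": data_sources,
        "fp_genes": fp_genes,
        "data_sources_external": config.get("data_sources_external"),
        "total_epochs": config.get("total_epochs"),
    }


demo_config = "configs/config_demo.json"
if os.path.exists(demo_config):
    for key, value in summarise_config(demo_config).items():
        print(key, "->", value)
```

Any entry printed as `None` is a key that the sketch expected but did not find at that position in the file.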


## Training

```sh
python train.py --config_file configs/FILENAME.json --resume_epoch EPOCH --fold_id FOLD --gpu_id GPU_NUM
```

- ``--config_file``: path to the config file.
- ``--resume_epoch``: whether to train from scratch or resume from a checkpoint, e.g., ``--resume_epoch 10`` resumes from the checkpoint saved at epoch 10. Set to 0 to train from scratch. Training ends when the epoch number reaches `total_epochs`, as specified in the config file.
- ``--fold_id``: the cross-validation fold (1, 2, 3...).
- ``--gpu_id``: which GPU to use (0, 1, 2...).
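
To train several cross-validation folds in sequence, the command above can be assembled programmatically. This is a sketch under stated assumptions: `build_train_command` and `train_folds` are hypothetical helpers, only the four flags documented above are used, and the fold IDs `[1, 2, 3]` are illustrative (the number of folds in the demo data is not stated here). With `dry_run=True` the commands are printed rather than executed.

```python
import subprocess
import sys


def build_train_command(config_file, fold_id, resume_epoch=0, gpu_id=0):
    """Build the argument list for one train.py run (hypothetical helper)."""
    return [
        sys.executable, "train.py",
        "--config_file", config_file,
        "--resume_epoch", str(resume_epoch),
        "--fold_id", str(fold_id),
        "--gpu_id", str(gpu_id),
    ]


def train_folds(config_file, fold_ids, gpu_id=0, dry_run=True):
    """Run (or, if dry_run, only print) one training command per fold."""
    commands = [
        build_train_command(config_file, fold_id, gpu_id=gpu_id)
        for fold_id in fold_ids
    ]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # check=True stops the loop if a training run exits with an error.
            subprocess.run(cmd, check=True)
    return commands


# Fold IDs are an assumption for illustration.
commands = train_folds("configs/config_demo.json", fold_ids=[1, 2, 3])
```

Passing the arguments as a list rather than a single shell string avoids quoting problems if the config path contains spaces.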

In [None]:
!python train.py --config_file configs/config_demo.json --resume_epoch 0 --fold_id 1 --gpu_id 0

A folder is created under `./experiments` where model checkpoints and outputs will be saved.

## Validation

```sh
python predict.py --config_file configs/FILENAME.json --epoch EPOCH --mode val --fold_id FOLD --gpu_id GPU_NUM --save_outputs
```

- ``--epoch``: which epoch to test, e.g., ``10`` to use the model from epoch 10, `last` for the most recent epoch, or `all` for all epochs.
- ``--save_outputs``: save the embeddings (.pt files) per neighbourhood and the attention value (resistance score) per cell.
- ``--mode``: either ``val`` (use the validation split) or ``predict`` (use data external to the data used for training and validation). With ``predict``, data is read from the locations in ``data_sources_predict`` in the config file.

In [None]:
!python inference.py --config_file configs/config_demo.json --epoch last --mode val --fold_id 1 --gpu_id 0

In this tutorial we trained for only 1 epoch to keep the demo short. Since we're using randomly generated data, we do not expect the model to train well.

## Outputs

The predictions are saved to ``experiments/{timestamp}/val_output/``. ``predictions.txt`` contains the ground truth and predicted response status for each sample, together with the predicted probability. If ``--save_outputs`` is passed, embeddings and attention values for each sample are saved as files ending in ``embeddings.pt`` and ``attn.csv``. Each row corresponds to a neighbourhood: the centre cell ID is indicated in the ``attn.csv`` files, and the attention value is in the ``A_raw`` column.
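
As an illustration of working with these outputs, the sketch below lists the neighbourhoods with the highest attention values in each ``attn.csv`` file. It is not part of the repository: `top_attention_rows` is a hypothetical helper, and the only column it relies on is ``A_raw``; the other columns (including the one holding the centre cell ID) are passed through unchanged because their names are not given here.

```python
import csv
import glob


def top_attention_rows(attn_csv_path, k=5):
    """Return the k rows with the largest A_raw value (hypothetical helper)."""
    with open(attn_csv_path, newline="") as f:
        rows = list(csv.DictReader(f))
    # csv reads every field as a string, so convert A_raw before sorting.
    rows.sort(key=lambda row: float(row["A_raw"]), reverse=True)
    return rows[:k]


# Search every timestamped experiment folder for files ending in attn.csv.
pattern = "experiments/*/val_output/**/*attn.csv"
for path in sorted(glob.glob(pattern, recursive=True)):
    print(path)
    for row in top_attention_rows(path, k=3):
        print("  ", row)
```

If no files are found, check that validation was run with ``--save_outputs``.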