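```python
import os

# Move up one level to the directory containing train.py, predict.py and configs/,
# so that the relative paths used in the commands below resolve
os.chdir("../")
```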

In this tutorial we run training and validation on a small demo dataset. The dataset was randomly generated and simulates samples (cores) from tumour tissue microarrays (TMAs). The demo should complete within a few minutes.

## Config file

Parameters are defined in a config file (``./configs/config_demo.json`` for this demo). Important parameters include:

- ``data_sources``: locations of data for training and validation.
- ``data_sources.fp_genes``: file path of a .txt file containing the list of genes to input to the model.
- ``data_sources_external``: locations of data from a dataset external to the one used for training and validation.
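
As a rough sketch, the snippet below reads the config and the gene list it points to. It assumes the config is plain JSON with ``fp_genes`` nested under ``data_sources`` (as the dotted name suggests) and that the gene file lists one gene name per line; the helper ``load_gene_list`` is hypothetical and not part of the repository.

```python
import json
from pathlib import Path


def load_gene_list(config_path):
    """Return the genes listed in the file referenced by data_sources.fp_genes.

    Assumes (hypothetically) that fp_genes is nested under data_sources and
    that the .txt file contains one gene name per line.
    """
    with open(config_path) as f:
        config = json.load(f)
    fp_genes = Path(config["data_sources"]["fp_genes"])
    # Strip surrounding whitespace and skip blank lines
    return [line.strip() for line in fp_genes.read_text().splitlines() if line.strip()]
```

For example, ``len(load_gene_list("configs/config_demo.json"))`` would be expected to agree with the ``Num genes`` line printed when training starts, provided the model uses every listed gene.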

In this config, non-response (NR) is class 1 and response (R) is class 0.
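
A minimal sketch of this label encoding, assuming labels are stored as the strings ``R`` and ``NR``. The class order matches the ``Classes: ['R', 'NR']`` line printed by the scripts; the names ``CLASSES``, ``CLASS_TO_INDEX`` and ``encode_labels`` are illustrative only.

```python
# Class order as printed by train.py/predict.py: index 0 = R, index 1 = NR
CLASSES = ["R", "NR"]
CLASS_TO_INDEX = {name: idx for idx, name in enumerate(CLASSES)}


def encode_labels(labels):
    """Map response labels to class indices (R -> 0, NR -> 1)."""
    return [CLASS_TO_INDEX[label] for label in labels]


print(encode_labels(["R", "NR", "NR"]))  # [0, 1, 1]
```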


## Training

```sh
python train.py --config_file configs/FILENAME.json --resume_epoch EPOCH --fold_id FOLD --gpu_id GPU_NUM
```

- ``--config_file``: path to the config file.
- ``--resume_epoch``: epoch to resume from. Set to 0 to train from scratch, or e.g. ``--resume_epoch 10`` to resume from the checkpoint saved at epoch 10. Training ends when the epoch number reaches ``total_epochs`` in the config file.
- ``--fold_id``: cross-validation fold (1, 2, 3...).
- ``--gpu_id``: which GPU to use (0, 1, 2...).
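
To train several folds one after another, the command above can be wrapped in a short script. This is a sketch only: it assumes folds are numbered from 1 to ``n_folds``, where ``n_folds`` (a hypothetical parameter) must match the number of folds in your data split, and the helpers ``build_train_cmd`` and ``train_all_folds`` are not part of the repository.

```python
import subprocess
import sys


def build_train_cmd(config_file, fold_id, gpu_id=0, resume_epoch=0):
    """Build the train.py argument list for one cross-validation fold."""
    return [
        sys.executable, "train.py",
        "--config_file", config_file,
        "--resume_epoch", str(resume_epoch),
        "--fold_id", str(fold_id),
        "--gpu_id", str(gpu_id),
    ]


def train_all_folds(config_file, n_folds, gpu_id=0):
    """Train folds 1..n_folds from scratch in sequence, stopping at the first failure."""
    for fold_id in range(1, n_folds + 1):
        cmd = build_train_cmd(config_file, fold_id, gpu_id)
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)  # raises CalledProcessError on a non-zero exit
```

Passing the argument list directly to ``subprocess.run`` (rather than a single shell string) avoids quoting problems with paths that contain spaces.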

```
!python train.py --config_file configs/config_demo.json --resume_epoch 0 --fold_id 1 --gpu_id 0
```

```text
Using GPUs: 0
test output
2025-10-09 12:37:11,249 INFO Initialising model
Classes: ['R', 'NR']
Num genes: 100
2025-10-09 12:37:11,638 INFO Preparing data
2025-10-09 12:37:11,709 INFO Total number of training batches: 16
2025-10-09 12:37:13,102 INFO Begin training
Epoch: 1
Epoch[1/1], Loss:22.5567
L_Response:11.4203, Acc:0.4375
L_A_raw:11.1364
2025-10-09 12:41:24,824 INFO Model saved: experiments/fold1_2025_10_09_12_37_11/models/epoch_1_model.pth
2025-10-09 12:41:24,855 INFO Model saved: experiments/fold1_2025_10_09_12_37_11/models/epoch_1_model_gma.pth
2025-10-09 12:41:24,907 INFO Optimiser saved: experiments/fold1_2025_10_09_12_37_11/models/epoch_1_optim.pth
2025-10-09 12:41:24,908 INFO Training finished
```


Training creates a timestamped folder under `./experiments` (here `experiments/fold1_2025_10_09_12_37_11`), where model checkpoints (in `models/`) and outputs are saved.
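
The folder name in the training log combines the fold and the start time (e.g. ``fold1_2025_10_09_12_37_11``). Assuming that naming pattern holds, a sketch like the following finds the most recent run for a given fold; ``latest_experiment_dir`` is a hypothetical helper, not part of the repository.

```python
import re
from pathlib import Path


def latest_experiment_dir(experiments_root, fold_id):
    """Return the newest experiments/fold{N}_YYYY_MM_DD_HH_MM_SS folder, or None.

    Assumes the naming pattern seen in the training log. Zero-padded
    timestamps sort chronologically as strings, so the lexicographic
    maximum is the most recent run.
    """
    # e.g. fold_id=1 -> ^fold1_\d{4}(_\d{2}){5}$ (does not match fold10_...)
    pattern = re.compile(rf"^fold{fold_id}_\d{{4}}(_\d{{2}}){{5}}$")
    candidates = [
        p for p in Path(experiments_root).iterdir()
        if p.is_dir() and pattern.match(p.name)
    ]
    return max(candidates, key=lambda p: p.name) if candidates else None
```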

## Validation

```sh
python predict.py --config_file configs/FILENAME.json --epoch EPOCH --mode val --fold_id FOLD --gpu_id GPU_NUM --save_outputs
```

- ``--config_file``, ``--fold_id`` and ``--gpu_id``: as for training.
- ``--epoch``: which epoch's model to evaluate, e.g. ``10`` for the model from epoch 10, ``last`` for the most recent, or ``all`` for every saved epoch.
- ``--save_outputs``: save the embeddings (.pt files) per neighbourhood and the attention value (resistance score) per cell.
- ``--mode``: ``val`` to use the validation split, or ``predict`` to use data external to those used for training and validation. With ``predict``, data are read from the locations in ``data_sources_predict`` in the config file.
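
The ``--epoch`` options correspond to the checkpoint files written during training (``epoch_1_model.pth`` and companions in the ``models`` folder). The sketch below illustrates that mapping, assuming checkpoints follow the ``epoch_{N}_model.pth`` naming seen in the log; it is not the repository's implementation, and ``resolve_epochs`` is a hypothetical name.

```python
import re
from pathlib import Path

# Matches epoch_{N}_model.pth only, not *_model_gma.pth or *_optim.pth
CKPT_PATTERN = re.compile(r"^epoch_(\d+)_model\.pth$")


def resolve_epochs(models_dir, epoch_spec):
    """Translate an --epoch value (an integer, 'last' or 'all') into epoch numbers."""
    saved = []
    for path in Path(models_dir).iterdir():
        match = CKPT_PATTERN.match(path.name)
        if match:
            saved.append(int(match.group(1)))
    saved.sort()  # numeric order, so epoch 10 comes after epoch 2
    if not saved:
        raise FileNotFoundError(f"No epoch_*_model.pth files in {models_dir}")
    if epoch_spec == "all":
        return saved
    if epoch_spec == "last":
        return [saved[-1]]
    epoch = int(epoch_spec)
    if epoch not in saved:
        raise ValueError(f"No checkpoint saved for epoch {epoch}")
    return [epoch]
```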

```
!python predict.py --config_file configs/config_demo.json --epoch last --mode val --fold_id 1 --gpu_id 0 --save_outputs
```

```text
Using GPUs: 0
Classes: ['R', 'NR']
Num genes: 100
Predict using experiments/fold1_2025_10_09_12_37_11/models/epoch_1_model.pth
Predict using experiments/fold1_2025_10_09_12_37_11/models/epoch_1_model_gma.pth
Epoch[1], ACC:0.6667, F1:0.6250, AUC:0.5556
***best epoch***
Best epoch 1: F1 0.625, ACC 0.6666666666666666, AUC 0.5555555555555556
```

Note: because the demo data are randomly generated, we do not expect the model to train well; the validation AUC above (0.556) is close to the chance level of 0.5.
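
For reference, the sketch below shows one way to compute the three metrics in the validation log (accuracy, F1 and AUC) from ground-truth class indices, predicted class indices and predicted NR probabilities. It is plain Python so it runs without extra dependencies. Using macro-averaged F1 and treating NR (class 1) as the positive class for AUC are assumptions; ``predict.py`` may compute these differently.

```python
def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted class equals the true class."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def f1_for_class(y_true, y_pred, cls):
    """F1 score treating `cls` as the positive class."""
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def macro_f1(y_true, y_pred, classes=(0, 1)):
    """Unweighted mean of the per-class F1 scores."""
    return sum(f1_for_class(y_true, y_pred, c) for c in classes) / len(classes)


def roc_auc(y_true, prob_pos):
    """AUC as the probability that a random positive outranks a random negative.

    Ties count as half a win. Class 1 (NR) is treated as positive.
    """
    pos = [s for t, s in zip(y_true, prob_pos) if t == 1]
    neg = [s for t, s in zip(y_true, prob_pos) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))
```

The pairwise definition of AUC used here is equivalent to the area under the ROC curve and is adequate for small validation sets; for large sets, a rank-based or library implementation is faster.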

## Outputs

The predictions are saved to ``experiments/{timestamp}/val_output/`` (in this run, ``experiments/fold1_2025_10_09_12_37_11/val_output/``). ``predictions.txt`` lists, for each sample, the ground-truth and predicted response status together with the predicted probability. With ``--save_outputs``, the embeddings and attention values for each sample are saved in files ending in ``embeddings.pt`` and ``attn.csv``, respectively. Each row of an ``attn.csv`` file corresponds to a neighbourhood, identified by its centre cell ID, and the attention value is in the ``A_raw`` column.
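
As an illustration, the sketch below reads one ``attn.csv`` file with the standard library and returns the neighbourhoods with the highest attention (resistance score). Only the ``A_raw`` column name comes from this tutorial; the name of the centre-cell ID column is not given here, so it is a parameter (``id_column``, defaulting to a hypothetical ``cell_id``) that should be set to match the header of your files.

```python
import csv


def top_attention(attn_csv_path, k=10, id_column="cell_id", score_column="A_raw"):
    """Return the k (centre cell ID, attention) pairs with the highest attention.

    `id_column` defaults to a hypothetical name; check the header of your attn.csv.
    """
    with open(attn_csv_path, newline="") as f:
        rows = [(row[id_column], float(row[score_column])) for row in csv.DictReader(f)]
    # Highest attention first
    rows.sort(key=lambda item: item[1], reverse=True)
    return rows[:k]
```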