# Cancer Classification

In this tutorial, we demonstrate how to integrate **patient multiomics data** to enhance **cancer classification**.

This notebook builds on the work of **Wang et al. (Nature Communications, 2021)**, which presents a multi-omics integrative method named **M**ulti-**O**mics **G**raph c**O**nvolutional **NET**works (MOGONET). MOGONET jointly explores omics-specific learning and cross-omics correlation learning for multiomics data classification, using mRNA expression, DNA methylation, and microRNA expression data.

---

**Objective**

TBC

# Setup

As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.

Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook’s current directory. The helper scripts are:
- `config.py`: Defines the base configuration settings, which can be overridden using a custom `.yaml` file.
- `parsing.py`: Contains utilities to compile evaluation results from the training process.

In [None]:
import os
import warnings

warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

[Optional] If you are using Google Colab, please use the following code to load the necessary demo data and code files.

In [None]:
!git clone --branch multiomics https://github.com/pykale/embc-mmai25.git
%cd /content/embc-mmai25/tutorials/multiomics-cancer-classification

/content/embc-mmai25/tutorials/multiomics-cancer-classification


## Packages

The main packages required for this tutorial are PyKale and PyTorch Geometric.

**PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.

**PyG** (PyTorch Geometric) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.

Other required packages can be found in `embc-mmai25/requirements.txt`.

In [None]:
!pip install --quiet git+https://github.com/pykale/pykale@main \
    && echo "PyKale installed successfully ✅" \
    || echo "Failed to install PyKale ❌"
!pip install --quiet -r /content/embc-mmai25/requirements.txt \
    && echo "Required packages installed successfully ✅" \
    || echo "Failed to install required packages ❌"
!pip install --upgrade --force-reinstall numpy
import torch
os.environ['TORCH'] = torch.__version__
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \
    && echo "PyG installed successfully ✅" \
    || echo "Failed to install PyG ❌"

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.
PyKale installed successfully ✅
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Required packages installed successfully ✅
Collecting numpy
  Using cached numpy-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (62 kB)
Using cached numpy-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl (16.9 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalli

## Configuration

To minimize the footprint of the notebook when specifying configurations, we provide a `config.py` file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as `experiments/base.yaml`.

In [36]:
from config import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("experiments/base.yaml")
# cfg.freeze()  # left unfrozen so that the values below can still be overridden
cfg.SOLVER.MAX_EPOCHS = 500
cfg.DATASET.NUM_MODALITIES = 3
print(cfg)


DATASET:
  NAME: TCGA_BRCA
  NUM_CLASSES: 5
  NUM_MODALITIES: 3
  RANDOM_SPLIT: False
  ROOT: dataset/
  URL: https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip
MODEL:
  EDGE_PER_NODE: 10
  EQUAL_WEIGHT: False
  GCN_DROPOUT_RATE: 0.5
  GCN_HIDDEN_DIM: [400, 400, 200]
  GCN_LR: 0.0005
  GCN_LR_PRETRAIN: 0.001
  VCDN_LR: 0.001
OUTPUT:
  OUT_DIR: ./outputs
SOLVER:
  MAX_EPOCHS: 500
  MAX_EPOCHS_PRETRAIN: 5
  SEED: 2023


# Data Loading

We use the preprocessed multiomics benchmark BRCA, which has been provided by the authors of the MOGONET paper in [their repository](https://github.com/txWang/MOGONET).
A brief description of the BRCA dataset is shown in the following table.

**Table 1**: Characteristics of the preprocessed BRCA multiomics dataset.

|      Omics       | #Training samples | #Test samples | #Features |
|:----------------:|:-----------------:|:-------------:|:---------:|
| mRNA expression  |        612        |      263      |   1000    |
| DNA methylation  |        612        |      263      |   1000    |
| miRNA expression |        612        |      263      |    503    |

Note: These datasets have been processed following the **Preprocessing** section of the original paper.

In [40]:
import torch
from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset
from kale.prepdata.tabular_transform import ToOneHotEncoding, ToTensor

print("\n==> Preparing dataset...")
file_names = []
for modality in range(1, cfg.DATASET.NUM_MODALITIES + 1):
    file_names.append(f"{modality}_tr.csv")
    file_names.append(f"{modality}_lbl_tr.csv")
    file_names.append(f"{modality}_te.csv")
    file_names.append(f"{modality}_lbl_te.csv")

multiomics_data = SparseMultiomicsDataset(
    root=cfg.DATASET.ROOT,
    raw_file_names=file_names,
    num_modalities=cfg.DATASET.NUM_MODALITIES,
    num_classes=cfg.DATASET.NUM_CLASSES,
    edge_per_node=cfg.MODEL.EDGE_PER_NODE,
    url=cfg.DATASET.URL,
    random_split=cfg.DATASET.RANDOM_SPLIT,
    equal_weight=cfg.MODEL.EQUAL_WEIGHT,
    pre_transform=ToTensor(dtype=torch.float),
    target_pre_transform=ToOneHotEncoding(dtype=torch.float),
)

print(multiomics_data)


==> Preparing dataset...

Dataset info:
   number of modalities: 3
   number of classes: 5

   modality | total samples | num train | num test  | num features
   -----------------------------------------------------------------
   1        | 875           | 612       | 263       | 1000        
   2        | 875           | 612       | 263       | 1000        
   3        | 875           | 612       | 263       | 503         
   -----------------------------------------------------------------
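
The labels are one-hot encoded when the dataset is built (`ToOneHotEncoding`), and `test_step` later recovers integer class indices with `torch.argmax`. The round trip can be sketched in plain Python as follows; `to_one_hot` and `argmax_rows` are hypothetical stand-ins written for this illustration, not the PyKale or PyTorch implementations, and the labels are made up.

```python
from typing import List


def to_one_hot(labels: List[int], num_classes: int) -> List[List[float]]:
    """Encode integer class labels as one-hot rows (illustrative stand-in)."""
    return [[1.0 if c == label else 0.0 for c in range(num_classes)] for label in labels]


def argmax_rows(rows: List[List[float]]) -> List[int]:
    """Recover the index of the largest entry in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


labels = [0, 3, 4, 1]  # hypothetical class labels for four patients
one_hot = to_one_hot(labels, num_classes=5)
recovered = argmax_rows(one_hot)

print(one_hot[1])  # [0.0, 0.0, 0.0, 1.0, 0.0]
print(recovered)   # [0, 3, 4, 1]
```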




In [41]:
# =============================================================================
# Author: Sina Tabakhi, sina.tabakhi@gmail.com
# =============================================================================

"""
Construct a pipeline to run the MOGONET method based on PyTorch Lightning. MOGONET is a multiomics fusion framework for
cancer classification and biomarker identification that utilizes supervised graph convolutional networks for omics
datasets.

This code refactors the original MOGONET code (https://github.com/txWang/MOGONET/blob/main/train_test.py)
using PyTorch Lightning.

Reference:
Wang, T., Shao, W., Huang, Z., Tang, H., Zhang, J., Ding, Z., Huang, K. (2021). MOGONET integrates multi-omics data
using graph convolutional networks allowing patient classification and biomarker identification. Nature communications.
https://www.nature.com/articles/s41467-021-23774-w
"""

from typing import List, Optional, Union

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
from torch import Tensor
from torch.nn import CrossEntropyLoss, ModuleList
from torch.optim.optimizer import Optimizer
from torch_geometric.loader import DataLoader
from torch_sparse import SparseTensor

from kale.embed.mogonet import MogonetGCN
from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset
from kale.predict.decode import LinearClassifier, VCDN


class MultiomicsTrainer(pl.LightningModule):
    r"""The PyTorch Lightning implementation of the MOGONET method, a multiomics fusion method designed for
    classification tasks.

    Args:
        dataset (SparseMultiomicsDataset): The input dataset created in form of :class:`~torch_geometric.data.Dataset`.
        num_modalities (int): The total number of modalities in the dataset.
        num_classes (int): The total number of classes in the dataset.
        unimodal_encoder (List[MogonetGCN]): The list of GCN encoders for each modality.
        unimodal_decoder (List[LinearClassifier]): The list of linear classifier decoders for each modality.
        loss_fn (CrossEntropyLoss): The loss function used to gauge the error between the prediction outputs and the
            provided target values.
        multimodal_decoder (VCDN, optional): The VCDN decoder used in the multiomics dataset.
            (default: ``None``)
        train_multimodal_decoder (bool, optional): Whether to train VCDN module. (default: ``True``)
        gcn_lr (float, optional): The learning rate used in the GCN module. (default: 5e-4)
        vcdn_lr (float, optional): The learning rate used in the VCDN module. (default: 1e-3)
    """

    def __init__(
        self,
        dataset: SparseMultiomicsDataset,
        num_modalities: int,
        num_classes: int,
        unimodal_encoder: List[MogonetGCN],
        unimodal_decoder: List[LinearClassifier],
        loss_fn: CrossEntropyLoss,
        multimodal_decoder: Optional[VCDN] = None,
        train_multimodal_decoder: bool = True,
        gcn_lr: float = 5e-4,
        vcdn_lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.num_modalities = num_modalities
        self.num_classes = num_classes
        self.unimodal_encoder = ModuleList(unimodal_encoder)
        self.unimodal_decoder = ModuleList(unimodal_decoder)
        self.multimodal_decoder = multimodal_decoder
        self.train_multimodal_decoder = train_multimodal_decoder
        self.loss_fn = loss_fn
        self.gcn_lr = gcn_lr
        self.vcdn_lr = vcdn_lr

        # activate manual optimization
        self.automatic_optimization = False

    def configure_optimizers(self) -> List[Optimizer]:
        """Return the optimizers used during training."""
        optimizers = []

        for modality in range(self.num_modalities):
            optimizers.append(
                torch.optim.Adam(
                    list(self.unimodal_encoder[modality].parameters())
                    + list(self.unimodal_decoder[modality].parameters()),
                    lr=self.gcn_lr,
                )
            )

        if self.multimodal_decoder is not None:
            optimizers.append(torch.optim.Adam(self.multimodal_decoder.parameters(), lr=self.vcdn_lr))

        return optimizers

    def forward(
        self, x: List[Tensor], adj_t: List[SparseTensor], multimodal: bool = False
    ) -> Union[Tensor, List[Tensor]]:
        """Same as :meth:`torch.nn.Module.forward()`.

        Raises:
            TypeError: If `multimodal_decoder` is `None` for multiomics datasets.
        """
        output = []

        for modality in range(self.num_modalities):
            output.append(
                self.unimodal_decoder[modality](self.unimodal_encoder[modality](x[modality], adj_t[modality]))
            )

        if not multimodal:
            return output

        if self.multimodal_decoder is not None:
            return self.multimodal_decoder(output)

        raise TypeError("multimodal_decoder must be defined for multiomics datasets.")

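    def training_step(self, train_batch, batch_idx: int):
        """Run one manual optimization step for each unimodal branch and, if enabled, for the multimodal decoder.

        The losses are logged rather than returned, because manual optimization is used.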

        Args:
            train_batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
            batch_idx (``int``): Integer displaying index of this batch.
        """
        optimizer = self.optimizers()

        if not isinstance(optimizer, (list, tuple)):
            optimizer = [optimizer]

        x = []
        adj_t = []
        y = []
        sample_weight = []
        for modality in range(self.num_modalities):
            data = train_batch[modality]
            x.append(data.x[data.train_idx])
            adj_t.append(data.adj_t_train)
            y.append(data.y[data.train_idx])
            sample_weight.append(data.train_sample_weight)

        outputs = self.forward(x, adj_t, multimodal=False)

        for modality in range(self.num_modalities):
            loss = self.loss_fn(outputs[modality], y[modality])
            loss = torch.mean(torch.mul(loss, sample_weight[modality]))
            self.logger.log_metrics({f"train_unimodal_step_loss ({modality + 1})": loss.detach()}, self.global_step)

            optimizer[modality].zero_grad()
            self.manual_backward(loss)
            optimizer[modality].step()

        if self.train_multimodal_decoder and self.multimodal_decoder is not None:
            output = self.forward(x, adj_t, multimodal=True)
            multi_loss = self.loss_fn(output, y[0])
            multi_loss = torch.mean(torch.mul(multi_loss, sample_weight[0]))
            self.logger.log_metrics({"train_multimodal_step_loss": multi_loss.detach()}, self.global_step)

            optimizer[-1].zero_grad()
            self.manual_backward(multi_loss)
            optimizer[-1].step()

    def test_step(self, test_batch, batch_idx: int):
        """Log the test metrics and return the class probabilities predicted for the test samples.

        Args:
            test_batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]):
                The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list.
            batch_idx (int): Integer displaying index of this batch.
        """
        x = []
        adj_t = []
        y = []
        for modality in range(self.num_modalities):
            data = test_batch[modality]
            x.append(data.x)
            adj_t.append(data.adj_t)
            y.append(torch.argmax(data.y[data.test_idx], dim=1))

        if self.multimodal_decoder is not None:
            output = self.forward(x, adj_t, multimodal=True)
        else:
            output = self.forward(x, adj_t, multimodal=False)[0]

        pred_test_data = torch.index_select(output, dim=0, index=test_batch[0].test_idx)
        final_output = F.softmax(pred_test_data, dim=1).detach().cpu().numpy()
        actual_output = y[0].detach().cpu()

        if self.num_classes == 2:
            self.log("Accuracy", round(accuracy_score(actual_output, final_output.argmax(1)), 3))
            self.log("F1", round(f1_score(actual_output, final_output.argmax(1)), 3))
            self.log("AUC", round(roc_auc_score(actual_output, final_output[:, 1]), 3))
        else:
            self.log("Accuracy", round(accuracy_score(actual_output, final_output.argmax(1)), 3))
            self.log("F1 weighted", round(f1_score(actual_output, final_output.argmax(1), average="weighted"), 3))
            self.log("F1 macro", round(f1_score(actual_output, final_output.argmax(1), average="macro"), 3))

        return final_output

    def _custom_data_loader(self) -> DataLoader:
        """Return an iterable or a collection of iterables that specifies all the samples in the dataset."""
        dataloaders = DataLoader(self.dataset, batch_size=1)
        return dataloaders

    def train_dataloader(self) -> DataLoader:
        """Return an iterable or a collection of iterables that specifies training samples in the dataset."""
        return self._custom_data_loader()

    def test_dataloader(self) -> DataLoader:
        """Return an iterable or a collection of iterables that specifies test samples in the dataset."""
        return self._custom_data_loader()

    def __str__(self) -> str:
        r"""Returns a string representation of the multiomics trainer object.

        Returns:
            str: The string representation of the multiomics trainer object.
        """
        model_str = ["\nModel info:\n", "   Unimodal encoder:\n"]

        for modality in range(self.num_modalities):
            model_str.append(f"    ({modality + 1}) {self.unimodal_encoder[modality]}")

        model_str.append("\n\n  Unimodal decoder:\n")
        for modality in range(self.num_modalities):
            model_str.append(f"    ({modality + 1}) {self.unimodal_decoder[modality]}")

        if self.multimodal_decoder is not None:
            model_str.append("\n\n  Multimodal decoder:\n")
            model_str.append(f"    {self.multimodal_decoder}")

        return "".join(model_str)
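
In `training_step`, the per-sample cross-entropy is multiplied by `train_sample_weight` and then averaged, which relies on the loss being created with `reduction="none"`. The following pure-Python sketch mirrors that arithmetic on toy numbers. It is an illustration only: the function names, logits, and weights are hypothetical, and integer targets are used here for brevity where the notebook uses one-hot targets.

```python
import math
from typing import List


def cross_entropy_per_sample(logits: List[List[float]], targets: List[int]) -> List[float]:
    """Unreduced cross-entropy: one loss value per sample."""
    losses = []
    for row, target in zip(logits, targets):
        row_max = max(row)  # subtract the maximum for numerical stability
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        losses.append(log_sum_exp - row[target])
    return losses


def weighted_mean_loss(losses: List[float], weights: List[float]) -> float:
    """Mean of loss * weight over samples, as in torch.mean(torch.mul(loss, weight))."""
    return sum(loss * w for loss, w in zip(losses, weights)) / len(losses)


logits = [[2.0, 0.0], [0.0, 0.0]]  # toy scores for two samples and two classes
targets = [0, 1]
weights = [0.5, 1.5]               # hypothetical sample weights

losses = cross_entropy_per_sample(logits, targets)
weighted = weighted_mean_loss(losses, weights)
print(losses, weighted)
```

Note that the weighted sum is divided by the number of samples, not by the sum of the weights, so the scale of the weights affects the scale of the loss.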


In [42]:
from typing import List, Optional

from torch.nn import CrossEntropyLoss
from yacs.config import CfgNode

from kale.embed.mogonet import MogonetGCN
from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset
# from kale.pipeline.multiomics_trainer import MultiomicsTrainer
from kale.predict.decode import LinearClassifier, VCDN


class MogonetModel:
    r"""Setup the MOGONET model via the config file.

    Args:
        cfg (CfgNode): A YACS config object.
        dataset (SparseMultiomicsDataset): The input dataset created in form of :class:`~torch_geometric.data.Dataset`.
    """

    def __init__(self, cfg: CfgNode, dataset: SparseMultiomicsDataset) -> None:
        self.cfg = cfg
        self.dataset = dataset
        self.unimodal_encoder: List[MogonetGCN] = []
        self.unimodal_decoder: List[LinearClassifier] = []
        self.multimodal_decoder: Optional[VCDN] = None
        self.loss_function = CrossEntropyLoss(reduction="none")
        self._create_model()

    def _create_model(self) -> None:
        """Create the MOGONET model via the config file."""
        num_modalities = self.cfg.DATASET.NUM_MODALITIES
        num_classes = self.cfg.DATASET.NUM_CLASSES
        gcn_dropout_rate = self.cfg.MODEL.GCN_DROPOUT_RATE
        gcn_hidden_dim = self.cfg.MODEL.GCN_HIDDEN_DIM
        vcdn_hidden_dim = pow(num_classes, num_modalities)

        for modality in range(num_modalities):
            self.unimodal_encoder.append(
                MogonetGCN(
                    in_channels=self.dataset.get(modality).num_features,
                    hidden_channels=gcn_hidden_dim,
                    dropout=gcn_dropout_rate,
                )
            )

            self.unimodal_decoder.append(LinearClassifier(in_dim=gcn_hidden_dim[-1], out_dim=num_classes))

        if num_modalities >= 2:
            self.multimodal_decoder = VCDN(
                num_modalities=num_modalities, num_classes=num_classes, hidden_dim=vcdn_hidden_dim
            )

    def get_model(self, pretrain: bool = False) -> MultiomicsTrainer:
        """Return the prepared MOGONET model based on user preference.

        Args:
            pretrain (bool, optional): Whether to return the pretrain model. (default: ``False``)

        Returns:
            MultiomicsTrainer: The prepared MOGONET model.
        """
        num_modalities = self.cfg.DATASET.NUM_MODALITIES
        num_classes = self.cfg.DATASET.NUM_CLASSES
        gcn_lr_pretrain = self.cfg.MODEL.GCN_LR_PRETRAIN
        gcn_lr = self.cfg.MODEL.GCN_LR
        vcdn_lr = self.cfg.MODEL.VCDN_LR

        if pretrain:
            # Pretraining updates only the unimodal branches, using the pretraining learning rate.
            multimodal_model = None
            train_multimodal_decoder = False
            gcn_lr = gcn_lr_pretrain
        else:
            multimodal_model = self.multimodal_decoder
            train_multimodal_decoder = True

        model = MultiomicsTrainer(
            dataset=self.dataset,
            num_modalities=num_modalities,
            num_classes=num_classes,
            unimodal_encoder=self.unimodal_encoder,
            unimodal_decoder=self.unimodal_decoder,
            loss_fn=self.loss_function,
            multimodal_decoder=multimodal_model,
            train_multimodal_decoder=train_multimodal_decoder,
            gcn_lr=gcn_lr,
            vcdn_lr=vcdn_lr,
        )

        return model

    def __str__(self) -> str:
        r"""Returns a string representation of the model object.

        Returns:
            str: The string representation of the model object.
        """
        return self.get_model().__str__()

# Setup Model



In [34]:
# from model import MogonetModel
import pytorch_lightning as pl
mogonet_model = MogonetModel(cfg, dataset=multiomics_data)
print(mogonet_model)


Model info:
   Unimodal encoder:
    (1) MogonetGCN(
  (conv1): MogonetGCNConv(1000, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)    (2) MogonetGCN(
  (conv1): MogonetGCNConv(1000, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)    (3) MogonetGCN(
  (conv1): MogonetGCNConv(503, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)

  Unimodal decoder:
    (1) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)    (2) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)    (3) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)

  Multimodal decoder:
    VCDN(
  (model): Sequential(
    (0): Linear(in_features=125, out_features=125, bias=True)
    (1): LeakyReLU(negative_slope=0.25)
    (2): Linear(in_features=125, out_features=5, bias=True)
  )
)
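
The `in_features=125` of the VCDN's first layer equals `pow(num_classes, num_modalities)`, here 5 to the power of 3. In the MOGONET paper, this input is described as a cross-omics discovery tensor: the outer product of the per-modality class predictions for a patient, flattened into a vector. The sketch below shows only that construction in plain Python; `cross_omics_tensor` is a hypothetical helper and the scores are made up, and PyKale's `VCDN` may differ in details such as the activation applied to the unimodal outputs beforehand.

```python
from itertools import product
from math import prod
from typing import List


def cross_omics_tensor(predictions: List[List[float]]) -> List[float]:
    """Flattened outer product of one patient's per-modality class scores."""
    # product(*predictions) yields one tuple per combination of class indices,
    # taking exactly one score from each modality.
    return [prod(combo) for combo in product(*predictions)]


# Hypothetical class probabilities for one patient from three modalities.
predictions = [
    [0.6, 0.1, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.1, 0.1, 0.1],
    [0.7, 0.1, 0.1, 0.05, 0.05],
]

flat = cross_omics_tensor(predictions)
print(len(flat))  # 125
```

Each entry corresponds to one combination of classes across the three modalities, which is how the decoder can pick up cross-omics label correlations.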


# Setup Trainer

In [43]:
network = mogonet_model.get_model(pretrain=False)
print(network)
trainer = pl.Trainer(
    max_epochs=cfg.SOLVER.MAX_EPOCHS,
    default_root_dir=cfg.OUTPUT.OUT_DIR,
    accelerator="auto",
    devices="auto",
    enable_model_summary=False,
    log_every_n_steps=1,
)

INFO:pytorch_lightning.utilities.rank_zero:Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs



Model info:
   Unimodal encoder:
    (1) MogonetGCN(
  (conv1): MogonetGCNConv(1000, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)    (2) MogonetGCN(
  (conv1): MogonetGCNConv(1000, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)    (3) MogonetGCN(
  (conv1): MogonetGCNConv(503, 400)
  (conv2): MogonetGCNConv(400, 400)
  (conv3): MogonetGCNConv(400, 200)
)

  Unimodal decoder:
    (1) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)    (2) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)    (3) LinearClassifier(
  (fc): Linear(in_features=200, out_features=5, bias=True)
)

  Multimodal decoder:
    VCDN(
  (model): Sequential(
    (0): Linear(in_features=125, out_features=125, bias=True)
    (1): LeakyReLU(negative_slope=0.25)
    (2): Linear(in_features=125, out_features=5, bias=True)
  )
)


In [46]:
trainer.fit(network)

Training: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=500` reached.


In [45]:
print("\n==> Testing model...")
_ = trainer.test(network)


==> Testing model...


Testing: |          | 0/? [00:00<?, ?it/s]
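
For a multi-class task such as this one, `test_step` logs accuracy together with weighted and macro F1. The difference between the two averages can be seen in the sketch below, written in plain Python on hypothetical labels. It is meant to match the usual definitions (macro: unweighted mean of per-class F1; weighted: mean weighted by class support) and is not a replacement for `sklearn.metrics.f1_score`.

```python
from typing import Dict, List


def per_class_f1(y_true: List[int], y_pred: List[int]) -> Dict[int, float]:
    """F1 score of each class, treating that class as the positive label."""
    scores = {}
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        scores[label] = 2 * tp / denom if denom else 0.0
    return scores


y_true = [0, 0, 0, 1, 1, 2]  # hypothetical ground-truth classes
y_pred = [0, 0, 1, 1, 1, 2]  # hypothetical predictions

f1 = per_class_f1(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
macro_f1 = sum(f1.values()) / len(f1)
weighted_f1 = sum(f1[label] * y_true.count(label) for label in f1) / len(y_true)

print(round(accuracy, 3), round(macro_f1, 3), round(weighted_f1, 3))
```

Macro F1 treats every class equally, so it is more sensitive to performance on rare classes, whereas weighted F1 follows the class distribution of the test set.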