Skip to content
Merged
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,6 @@ _build/
.vscode/

_build

.DS_Store
tutorials/multiomics-cancer-classification/outputs
61 changes: 61 additions & 0 deletions tutorials/multiomics-cancer-classification/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Default configurations according to the MOGONET method described in 'MOGONET integrates multi-omics data using
graph convolutional networks allowing patient classification and biomarker identification'
- Wang, T., Shao, W., Huang, Z., Tang, H., Zhang, J., Ding, Z., Huang, K. (2021).

https://github.com/txWang/MOGONET/blob/main/main_mogonet.py
"""

from yacs.config import CfgNode

# ---------------------------------------------------------
# Config definition
# ---------------------------------------------------------

# Root node that holds every default value; get_cfg_defaults() returns clones of it.
_C = CfgNode()

# ---------------------------------------------------------
# Dataset
# ---------------------------------------------------------
_C.DATASET = CfgNode()
# Dataset root directory (presumably where the archive from URL is stored -- confirm against the data loader).
_C.DATASET.ROOT = "dataset/"
_C.DATASET.NAME = "TCGA_BRCA"
_C.DATASET.URL = "https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip"
# NOTE(review): not read by model.py; presumably toggles a random train/test split in the dataset code -- confirm.
_C.DATASET.RANDOM_SPLIT = False
_C.DATASET.NUM_MODALITIES = 3  # Number of omics modalities in the dataset
# Number of target classes; model.py uses it as the output size of each per-modality classifier.
_C.DATASET.NUM_CLASSES = 5

# ---------------------------------------------------------
# Solver
# ---------------------------------------------------------
_C.SOLVER = CfgNode()
_C.SOLVER.SEED = 2023
# Epoch budgets for the pretraining stage and the main training stage, respectively
# (consumed by the training script, which is not part of this file).
_C.SOLVER.MAX_EPOCHS_PRETRAIN = 500
_C.SOLVER.MAX_EPOCHS = 2500

# -----------------------------------------------------------------------------
# Model (MOGONET) configs
# -----------------------------------------------------------------------------
_C.MODEL = CfgNode()
_C.MODEL.EDGE_PER_NODE = (
    10  # Predefined number of edges per nodes in computing adjacency matrix
)
# NOTE(review): not read by model.py; presumably controls sample weighting in the dataset/loss -- confirm.
_C.MODEL.EQUAL_WEIGHT = False
# GCN learning rates: model.py passes GCN_LR_PRETRAIN to the trainer when pretrain=True and GCN_LR otherwise.
_C.MODEL.GCN_LR_PRETRAIN = 1e-3
_C.MODEL.GCN_LR = 5e-4
# Passed to every per-modality GCN encoder as its ``dropout`` argument.
_C.MODEL.GCN_DROPOUT_RATE = 0.5
# Passed to every GCN encoder as ``hidden_channels``; the last entry is also the input size of each
# per-modality linear classifier.
_C.MODEL.GCN_HIDDEN_DIM = [400, 400, 200]

# The View Correlation Discovery Network (VCDN) to learn the higher-level intra-view and cross-view correlations
# in the label space. See the MOGONET paper for more information.
# Learning rate handed to the trainer as ``vcdn_lr``.
_C.MODEL.VCDN_LR = 1e-3

# ---------------------------------------------------------
# Misc options
# ---------------------------------------------------------
_C.OUTPUT = CfgNode()
# Output directory, relative to the working directory (the tutorial's ``outputs`` folder is git-ignored).
_C.OUTPUT.OUT_DIR = "./outputs"


def get_cfg_defaults():
    """Return a fresh copy of the default configuration.

    Returns:
        CfgNode: A clone of the module-level defaults ``_C``. Callers receive their own copy, so values
        merged into it afterwards do not alter the defaults defined in this module.
    """
    defaults = _C.clone()
    return defaults
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
DATASET:
  NAME: "ROSMAP"
  URL: "https://github.com/pykale/data/raw/main/multiomics/ROSMAP.zip"
  NUM_MODALITIES: 3
  NUM_CLASSES: 2

SOLVER:
  MAX_EPOCHS_PRETRAIN: 5  # For quick testing
  MAX_EPOCHS: 10  # For quick testing

MODEL:
  EDGE_PER_NODE: 2
  GCN_HIDDEN_DIM: [200, 200, 100]
13 changes: 13 additions & 0 deletions tutorials/multiomics-cancer-classification/experiments/base.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
DATASET:
  NAME: "TCGA_BRCA"
  URL: "https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip"
  NUM_MODALITIES: 3
  NUM_CLASSES: 5

SOLVER:
  MAX_EPOCHS_PRETRAIN: 5  # For quick testing
  MAX_EPOCHS: 100  # For quick testing

MODEL:
  EDGE_PER_NODE: 10
  GCN_HIDDEN_DIM: [400, 400, 200]
102 changes: 102 additions & 0 deletions tutorials/multiomics-cancer-classification/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import List, Optional

from torch.nn import CrossEntropyLoss
from yacs.config import CfgNode

from kale.embed.mogonet import MogonetGCN
from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset
from kale.pipeline.multiomics_trainer import MultiomicsTrainer
from kale.predict.decode import LinearClassifier, VCDN


class MogonetModel:
    r"""Build the MOGONET components described by a YACS config.

    One :class:`MogonetGCN` encoder and one :class:`LinearClassifier` decoder are created per modality, plus a
    :class:`VCDN` decoder when the config declares at least two modalities.

    Args:
        cfg (CfgNode): A YACS config object.
        dataset (SparseMultiomicsDataset): The input dataset created in form of :class:`~torch_geometric.data.Dataset`.
    """

    def __init__(self, cfg: CfgNode, dataset: SparseMultiomicsDataset) -> None:
        self.cfg = cfg
        self.dataset = dataset
        # Populated by ``_create_model``; entry ``i`` of each list belongs to modality ``i``.
        self.unimodal_encoder: List[MogonetGCN] = []
        self.unimodal_decoder: List[LinearClassifier] = []
        # Stays ``None`` when fewer than two modalities are configured.
        self.multimodal_decoder: Optional[VCDN] = None
        # No reduction is applied here, so the loss is returned per sample.
        self.loss_function = CrossEntropyLoss(reduction="none")
        self._create_model()

    def _create_model(self) -> None:
        """Instantiate the per-modality encoders and decoders and, if applicable, the VCDN."""
        modality_count = self.cfg.DATASET.NUM_MODALITIES
        class_count = self.cfg.DATASET.NUM_CLASSES
        dropout = self.cfg.MODEL.GCN_DROPOUT_RATE
        hidden_dims = self.cfg.MODEL.GCN_HIDDEN_DIM
        # The VCDN hidden size is num_classes raised to the number of modalities.
        fused_dim = class_count**modality_count

        # Encoder and decoder are built alternately, modality by modality, so the order in which
        # modules are constructed stays fixed.
        for idx in range(modality_count):
            encoder = MogonetGCN(
                in_channels=self.dataset.get(idx).num_features,
                hidden_channels=hidden_dims,
                dropout=dropout,
            )
            self.unimodal_encoder.append(encoder)

            # The classifier consumes the output of the last GCN layer.
            decoder = LinearClassifier(in_dim=hidden_dims[-1], out_dim=class_count)
            self.unimodal_decoder.append(decoder)

        if modality_count < 2:
            # A single modality has nothing to fuse, so no VCDN is created.
            return

        self.multimodal_decoder = VCDN(
            num_modalities=modality_count,
            num_classes=class_count,
            hidden_dim=fused_dim,
        )

    def get_model(self, pretrain: bool = False) -> MultiomicsTrainer:
        """Wrap the prepared components in a trainer for the requested stage.

        The encoder and decoder lists held by this object are handed to the trainer as-is, so trainers
        returned by separate calls refer to the same module instances.

        Args:
            pretrain (bool, optional): If true, build the pretraining trainer: no multimodal decoder is
                passed, it is not trained, and the GCN learning rate is ``MODEL.GCN_LR_PRETRAIN``.
                Otherwise the multimodal decoder is passed and ``MODEL.GCN_LR`` is used.
                (default: ``False``)

        Returns:
            MultiomicsTrainer: The trainer configured for the selected stage.
        """
        modality_count = self.cfg.DATASET.NUM_MODALITIES
        class_count = self.cfg.DATASET.NUM_CLASSES
        pretrain_lr = self.cfg.MODEL.GCN_LR_PRETRAIN
        regular_lr = self.cfg.MODEL.GCN_LR
        fusion_lr = self.cfg.MODEL.VCDN_LR

        # The stage decides whether the VCDN takes part and which GCN learning rate applies.
        fusion_decoder = None if pretrain else self.multimodal_decoder
        train_fusion = not pretrain
        encoder_lr = pretrain_lr if pretrain else regular_lr

        return MultiomicsTrainer(
            dataset=self.dataset,
            num_modalities=modality_count,
            num_classes=class_count,
            unimodal_encoder=self.unimodal_encoder,
            unimodal_decoder=self.unimodal_decoder,
            loss_fn=self.loss_function,
            multimodal_decoder=fusion_decoder,
            train_multimodal_decoder=train_fusion,
            gcn_lr=encoder_lr,
            vcdn_lr=fusion_lr,
        )

    def __str__(self) -> str:
        r"""Describe the model using the string form of its default (non-pretrain) trainer.

        Returns:
            str: The string representation of the trainer returned by :meth:`get_model`.
        """
        return str(self.get_model())
385 changes: 385 additions & 0 deletions tutorials/multiomics-cancer-classification/notebook.ipynb

Large diffs are not rendered by default.