From 1e0c5c042f4f18e901dba4c3c71df5650b78e79b Mon Sep 17 00:00:00 2001 From: "L. M. Riza Rizky" <42672299+zaRizk7@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:13:34 +0100 Subject: [PATCH 1/3] add visualisation outputs to drug tutorial --- .../tutorial-drug.ipynb | 5130 +++++++++++++---- 1 file changed, 3934 insertions(+), 1196 deletions(-) diff --git a/tutorials/drug-target-interaction/tutorial-drug.ipynb b/tutorials/drug-target-interaction/tutorial-drug.ipynb index 1cb3501..9f22516 100644 --- a/tutorials/drug-target-interaction/tutorial-drug.ipynb +++ b/tutorials/drug-target-interaction/tutorial-drug.ipynb @@ -1,1198 +1,3936 @@ { - "nbformat": 4, - "nbformat_minor": 5, - "metadata": { - "kernelspec": { - "display_name": "mmai-drug-tutorial", - "language": "python", - "name": "python3" - } - }, - "cells": [ - { - "metadata": {}, - "source": [ - "# Drug–Target Interaction Prediction\n", - "\n", - "![](images/drugban-pyakle-api.png)\n", - "\n", - "\n", - "In this tutorial, we will train models to predict the interaction between **two data modalities**: **molecules (drug)** and **proteins (target)** using `PyKale`. Drug-target interaction (DTI) plays a key role in drug discovery and identifying potential therapeutic targets. This example is based on the **DrugBAN** framework by [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1).\n", - "\n", - "The DTI prediction problem is formulated as a **binary classification task**, where the goal is to predict whether a given **drug–protein pair interacts or not**. 
The DrugBAN framework tackles this problem using two key ideas:\n", - "\n", - "- **Bilinear Attention Network (BAN)**, which learns detailed feature representations for both drugs and proteins and captures local interaction patterns between them.\n", - "\n", - "- **Adversarial Domain Adaptation**, which helps the model generalise to out-of-distribution datasets, i.e., in clustering-based cross-validation instead of random splits, improving its ability to predict interactions on unseen drug–target pairs.\n", - "\n", - "With `PyKale`, implementing such a multimodal DTI prediction pipeline is straightforward. The library provides ready-to-use modules and configuration support, making it easy to apply advanced techniques with minimal custom coding." - ], - "cell_type": "markdown", - "id": "8c1bf9c7" - }, - { - "metadata": {}, - "source": [ - "## Step 0: Environment Preparation\n", - "\n", - "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial.\n", - "\n", - "To prepare the helper functions and necessary materials, we download them from the GitHub repository.\n", - "\n", - "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. The additional helper script is:\n", - "- [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file." 
- ], - "cell_type": "markdown", - "id": "745ccdcf" - }, - { - "metadata": {}, - "source": [ - "import os\n", - "\n", - "!rm -rf /content/mmai-tutorials\n", - "!git clone --single-branch -b main https://github.com/pykale/mmai-tutorials.git\n", - "%mv /content/mmai-tutorials/tutorials/drug-target-interaction /content/\n", - "%cd /content/drug-target-interaction\n", - "\n", - "print(\"Changed working directory to:\", os.getcwd())" - ], - "cell_type": "code", - "outputs": [], - "id": "a6028209", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "### Package Installation\n", - "\n", - "The main package required for this tutorial is `PyKale`.\n", - "\n", - "`PyKale` is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.\n", - "\n", - "Then, we install `PyG` (PyTorch Geometric) and related packages.\n", - "\n", - "Please **do not** re-run this session after installation completed. Runing this installation multiple times will trigger issues related to `PyG`. If you want to re-run this installation, please click the `Runtime` on the top menu and choose `Disconnect and delete runtime` before installing." - ], - "cell_type": "markdown", - "id": "c52c6334" - }, - { - "metadata": {}, - "source": [ - "!pip install --quiet \\\n", - " git+https://github.com/pykale/pykale@main \\\n", - " yacs==0.1.8 \\\n", - " rdkit \\\n", - " torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \\\n", - " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", - " && echo \"pykale,yacs and wfdb installed successfully ✅\" \\\n", - " || echo \"Failed to install pykale,yacs ❌\"" - ], - "cell_type": "code", - "outputs": [], - "id": "53e3b14e", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "We then hide the warnings messages to get a clear output." 
- ], - "cell_type": "markdown", - "id": "69f50b6a" - }, - { - "metadata": {}, - "source": [ - "import os\n", - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")\n", - "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" - ], - "cell_type": "code", - "outputs": [], - "id": "6e871c63", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "Exercise: Check NumPy Version" - ], - "cell_type": "markdown", - "id": "6606e3fb" - }, - { - "metadata": {}, - "source": [ - "import numpy as np\n", - "\n", - "print(\"NumPy version:\", np.__version__) # numpy should be 2.0.0 or higher" - ], - "cell_type": "code", - "outputs": [], - "id": "0d384020", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "### Configuration\n", - "\n", - "To minimize the footprint of the notebook when specifying configurations, we provide a [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`configs/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs/DA_cross_domain.yaml) as an example." - ], - "cell_type": "markdown", - "id": "cabd3406" - }, - { - "metadata": {}, - "source": [ - "from configs import get_cfg_defaults\n", - "\n", - "%cd /content/drug-target-interaction\n", - "\n", - "cfg = get_cfg_defaults() # Load the default settings from config.py\n", - "cfg.merge_from_file(\n", - " \"configs/DA_cross_domain.yaml\"\n", - ") # Update (or override) some of those settings using a custom YAML file" - ], - "cell_type": "code", - "outputs": [], - "id": "55c13b48", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "In this tutorial, we list the hyperparameters we would like users to play with outside the `.yaml` file:\n", - "- `cfg.SOLVER.MAX_EPOCH`: Number of epochs in training stage. 
You can reduce the number of training epochs to shorten runtime.\n", - "- `cfg.DATA.DATASET`: The dataset used in the study. This can be `bindingdb` or `biosnap`.\n", - "\n", - "As a quick exercise, please take a moment to review and understand the parameters in [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)." - ], - "cell_type": "markdown", - "id": "74ffdbc2" - }, - { - "metadata": {}, - "source": [ - "cfg.SOLVER.MAX_EPOCH = 2" - ], - "cell_type": "code", - "outputs": [], - "id": "424c7286", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "You can also switch to a different dataset." - ], - "cell_type": "markdown", - "id": "97c088fd" - }, - { - "metadata": {}, - "source": [ - "cfg.DATA.DATASET = \"biosnap\"" - ], - "cell_type": "code", - "outputs": [], - "id": "c69376fa", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "Exercise: Now print the full configuration to check all current hyperparameter and dataset settings." - ], - "cell_type": "markdown", - "id": "d3d41633" - }, - { - "metadata": {}, - "source": [ - "print(cfg)" - ], - "cell_type": "code", - "outputs": [], - "id": "45874296", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "## Step 1: Data Loading and Preparation\n", - "\n", - "In this tutorial, we use the **Biosnap** dataset for the main demonstration and the **BindingDB** dataset for the exercise at the end." - ], - "cell_type": "markdown", - "id": "17558d0c" - }, - { - "metadata": {}, - "source": [ - "### Data Downloading\n", - "\n", - "Please run the following cell to download necessary datasets." 
- ], - "cell_type": "markdown", - "id": "6c6071b9" - }, - { - "metadata": {}, - "source": [ - "!rm -rf data\n", - "!mkdir data\n", - "!cd data\n", - "\n", - "!pip install -q gdown\n", - "!gdown --id 1ogOcxZn-1q418LOT-gQ94aHQV0Y1sOmk --output data/drug-target-interaction.zip\n", - "!unzip data/drug-target-interaction.zip -d data/\n", - "!mv data/drug-target-interaction/checkpoint ./" - ], - "cell_type": "code", - "outputs": [], - "id": "56f9f58e", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "Exercise: Check the data is ready" - ], - "cell_type": "markdown", - "id": "c39b3e39" - }, - { - "metadata": {}, - "source": [ - "import os\n", - "import shutil\n", - "\n", - "print(\"Contents of the data folder:\")\n", - "for item in os.listdir(\"data/drug-target-interaction\"):\n", - " print(item)" - ], - "cell_type": "code", - "outputs": [], - "id": "a6258d1f", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "The data content is structured as follows:\n", - "```sh\n", - " ├───data\n", - " │ ├───checkpoint\n", - " │ ├───bindingdb\n", - " │ ├───biosnap" - ], - "id": "9ab0b5f833dc40f8" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "The `data` folder contains two datasets: `bindingdb` and `biosnap`. Each dataset folder contains the following files. 
The `checkpoint` folder contains the saved model checkpoint, which are used later in the interpretation section.", - "id": "5be1dcc62b7d5649" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "print(\"Contents of bindingdb folder:\")\n", - "for item in os.listdir(\"data/drug-target-interaction/bindingdb\"):\n", - " print(item)" - ], - "id": "a93303c51c8b974e" - }, - { - "metadata": {}, - "source": [ - "Each dataset folder follows the structure:\n", - "\n", - "```sh\n", - " ├───dataset_name\n", - " │ ├───cluster\n", - " │ │ ├───source_train.csv\n", - " │ │ ├───target_train.csv\n", - " │ │ ├───target_test.csv\n", - " │ ├───random\n", - " │ │ ├───test.csv\n", - " │ │ ├───train.csv\n", - " │ │ ├───val.csv\n", - " │ ├───full.csv\n", - "```" - ], - "cell_type": "markdown", - "id": "79cbc1c1" - }, - { - "metadata": {}, - "source": [ - "We use the cluster dataset folder for cross-domain prediction, containing three parts:\n", - "\n", - "- Train samples from the source domain: Drug–protein pairs the model learns from.\n", - "\n", - "- Train samples from the target domain: Additional training data from a different distribution to improve generalisation.\n", - "\n", - "- Test samples from the target domain: Unseen drug–protein pairs used to evaluate model performance on new data.\n", - "\n", - "The source and target sets are defined based on the clustering results." 
- ], - "cell_type": "markdown", - "id": "d35e04f9" - }, - { - "metadata": {}, - "source": "### Data Loading", - "cell_type": "markdown", - "id": "98acf744" - }, - { - "metadata": {}, - "source": [ - "Here’s what each csv file looks like in a table format:\n", - "\n", - "| SMILES | Protein Sequence | Y |\n", - "|--------------------|--------------------------|---|\n", - "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n", - "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n", - "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n", - "\n", - "Each row of the dataset contains three key pieces of information:\n", - "\n", - "**Drugs**: \n", - "Drugs are often written as SMILES strings, which are like chemical formulas in text format (for example, `\"CC(=O)OC1=CC=CC=C1C(=O)O\"` is aspirin). \n", - "\n", - "\n", - "**Protein Sequence** \n", - "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n", - "\n", - "\n", - "**Y (Labels)**: \n", - "Each drug–protein pair is given a label:\n", - "- `1` if they interact\n", - "- `0` if they do not\n", - "\n", - "\n", - "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**) — whether or not the drug and protein interact." - ], - "cell_type": "markdown", - "id": "1e5f4f44" - }, - { - "metadata": {}, - "source": "You can load CSV files into Python using tools like `pandas`. 
The output shows a sample of the data, including the SMILES string for the drug, the protein sequence, the interaction label (Y) and the cluster ID.", - "cell_type": "markdown", - "id": "b7590daf" - }, - { - "metadata": {}, - "source": [ - "import pandas as pd\n", - "\n", - "dataFolder = os.path.join(\n", - " f\"data/drug-target-interaction/{cfg.DATA.DATASET}\", str(cfg.DATA.SPLIT)\n", - ")\n", - "\n", - "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", - "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", - "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", - "\n", - "print(\"Sample example:\", df_train_source.iloc[0]))" - ], - "cell_type": "code", - "outputs": [], - "id": "0c709e31", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "### Data Preprocessing\n", - "\n", - "We convert drug SMILES strings into molecular graphs using `kale.loaddata.molecular_datasets.smiles_to_graph`, encoding atom-level features as node attributes and bond types as edges.\n", - "\n", - "\n", - "Protein sequences are transformed into fixed-length integer arrays using `kale.prepdata.chem_transform.integer_label_protein`, with each amino acid mapped to an integer and sequences padded or truncated to a uniform length.\n", - "\n", - "Finally, the `kale.loaddata.molecular_datasets.DTIDataset` class packages drugs, proteins, and labels into a PyTorch-ready dataset." - ], - "cell_type": "markdown", - "id": "542d4e69" - }, - { - "metadata": {}, - "source": [ - "**Note:** If you encounter an error related to requiring numpy `<2.0`, simply ignore it and re-run this block until it completes successfully." 
- ], - "cell_type": "markdown", - "id": "981d5520" - }, - { - "metadata": {}, - "source": [ - "from kale.loaddata.molecular_datasets import DTIDataset\n", - "\n", - "# Create preprocessed datasets\n", - "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", - "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", - "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" - ], - "cell_type": "code", - "outputs": [], - "id": "ae5af8eb", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "We load data in small, manageable pieces called batches to save memory and speed up training. We use `kale.loaddata.sampler.MultiDataLoader` from PyKale to load one batch from the source domain and one from the target domain at each training step." - ], - "cell_type": "markdown", - "id": "a0a510ce" - }, - { - "metadata": {}, - "source": [ - "First, we specify a few DataLoader parameters:\n", - "- Batch size: Number of samples per batch\n", - "- Shuffle: Randomly shuffle data\n", - "- Number of workers: Parallel data loading\n", - "- Drop last: Discard the last incomplete batch for consistent batch sizes\n", - "- Collate function: Use graph_collate_func to batch variable-sized molecular graphs" - ], - "cell_type": "markdown", - "id": "c09084c0" - }, - { - "metadata": {}, - "source": [ - "from torch.utils.data import DataLoader\n", - "from kale.loaddata.molecular_datasets import graph_collate_func\n", - "from kale.loaddata.sampler import MultiDataLoader\n", - "\n", - "params = {\n", - " \"batch_size\": cfg.SOLVER.BATCH_SIZE,\n", - " \"shuffle\": True,\n", - " \"num_workers\": cfg.SOLVER.NUM_WORKERS,\n", - " \"drop_last\": True,\n", - " \"collate_fn\": graph_collate_func,\n", - "}\n", - "\n", - "params" - ], - "cell_type": "code", - "outputs": [], - "id": "94a15868", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "Then, we create a DataLoader from both the source and 
target datasets for training." - ], - "cell_type": "markdown", - "id": "e884ed07" - }, - { - "metadata": {}, - "source": [ - "print(\"Using domain adaptation:\", cfg.DA.USE)\n", - "\n", - "if not cfg.DA.USE:\n", - " training_generator = DataLoader(train_dataset, **params)\n", - "else:\n", - " source_generator = DataLoader(train_dataset, **params)\n", - " target_generator = DataLoader(train_target_dataset, **params)\n", - "\n", - " # Get the number of batches in the longer dataset to align both\n", - " n_batches = max(len(source_generator), len(target_generator))\n", - "\n", - " # Combine the source and target data loaders using MultiDataLoader\n", - " training_generator = MultiDataLoader(\n", - " dataloaders=[source_generator, target_generator], n_batches=n_batches\n", - " )" - ], - "cell_type": "code", - "outputs": [], - "id": "24ba12b5", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "Lastly, we set up DataLoaders for validation and testing. Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly." - ], - "cell_type": "markdown", - "id": "649301de" - }, - { - "metadata": {}, - "source": [ - "# Update parameters for validation/testing (no shuffling, keep all data)\n", - "params.update({\"shuffle\": False, \"drop_last\": False})\n", - "\n", - "# Create validation and test data loaders\n", - "valid_generator = DataLoader(test_target_dataset, **params)\n", - "test_generator = DataLoader(test_target_dataset, **params)" - ], - "cell_type": "code", - "outputs": [], - "id": "b4cf543a", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "### Exercise: Dataset Inspection\n", - "\n", - "Once the dataset is ready, let’s inspect one sample from the training data to check the input graph, protein sequence, and label format." 
- ], - "cell_type": "markdown", - "id": "e474eea2" - }, - { - "metadata": {}, - "source": [ - "# Get the first batch (contains one batch from source and one from target)\n", - "first_batch = next(iter(training_generator))\n", - "\n", - "# Unpack source and target batches\n", - "source_batch, target_batch = first_batch\n", - "\n", - "# Inspect the first sample from the source batch\n", - "print(\"First sample from source batch:\")\n", - "print(\"Drug graph:\", source_batch[0][0])\n", - "print(\"Protein sequence:\", source_batch[1][0])\n", - "print(\"Label:\", source_batch[2][0])" - ], - "cell_type": "code", - "outputs": [], - "id": "31b8a93f", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "This sample is a tuple with three parts:\n", - "\n", - "1. **Drug Graph**\n", - "- `x=[290, 7]`: Feature matrix with 290 atoms (nodes) and 7 features per atom.\n", - "- `edge_index=[2, 58]`: Shows 146 edges, with source and target node indices.\n", - "- `edge_attr=[58, 1]`: Each edge has 1 bond feature, such as bond type.\n", - "- `num_nodes=290`: Confirms the graph has 290 nodes.\n", - "\n", - "2. **Protein Features (array)**\n", - "- Example values: `[11., 1., 18., ..., 0., 0., 0.]`: A fixed-length numeric array representing the protein sequence. Each position holds an integer-encoded amino acid, with zeros for padding.\n", - "\n", - "3. **Label (float)**\n", - "- `0.0`; The ground-truth interaction label indicating no interaction." - ], - "cell_type": "markdown", - "id": "cb0b269b" - }, - { - "metadata": {}, - "source": [ - "## Step 2: Model Definition" - ], - "cell_type": "markdown", - "id": "8eaf5c8f" - }, - { - "metadata": {}, - "source": [ - "### Embed\n", - "\n", - "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. 
The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n", - "\n", - "We define the DrugBAN class in `kale.embed.ban`." - ], - "cell_type": "markdown", - "id": "b2819549" - }, - { - "metadata": {}, - "source": [ - "from kale.embed.ban import DrugBAN\n", - "\n", - "model = DrugBAN(**cfg)\n", - "print(model)" - ], - "cell_type": "code", - "outputs": [], - "id": "1c8f3acc", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "### Predict\n", - "We use the PyKale pipeline API `kale.pipeline.drugban_trainer` to connect dataloaders, encoders and outcoders for model training and evaluation." - ], - "cell_type": "markdown", - "id": "32084f24" - }, - { - "metadata": {}, - "source": [ - "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", - "\n", - "drugban_trainer = DrugbanTrainer(\n", - " model=DrugBAN(**cfg),\n", - " solver_lr=cfg.SOLVER.LEARNING_RATE,\n", - " num_classes=cfg.DECODER.BINARY,\n", - " batch_size=cfg.SOLVER.BATCH_SIZE,\n", - " is_da=cfg.DA.USE,\n", - " solver_da_lr=cfg.SOLVER.DA_LEARNING_RATE,\n", - " da_init_epoch=cfg.DA.INIT_EPOCH,\n", - " da_method=cfg.DA.METHOD,\n", - " original_random=cfg.DA.ORIGINAL_RANDOM,\n", - " use_da_entropy=cfg.DA.USE_ENTROPY,\n", - " da_random_layer=cfg.DA.RANDOM_LAYER,\n", - " da_random_dim=cfg.DA.RANDOM_DIM,\n", - " decoder_in_dim=cfg.DECODER.IN_DIM,\n", - ")" - ], - "cell_type": "code", - "outputs": [], - "id": "46e2b9b4", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score." 
- ], - "cell_type": "markdown", - "id": "a48c86b9" - }, - { - "metadata": {}, - "source": [ - "import pytorch_lightning as pl\n", - "from pytorch_lightning.callbacks import ModelCheckpoint\n", - "\n", - "checkpoint_callback = ModelCheckpoint(\n", - " filename=\"{epoch}-{step}-{val_BinaryAUROC:.4f}\",\n", - " monitor=\"val_BinaryAUROC\",\n", - " mode=\"max\",\n", - ")" - ], - "cell_type": "code", - "outputs": [], - "id": "7754bd38", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "We now create the `Trainer`." - ], - "cell_type": "markdown", - "id": "969beac0" - }, - { - "metadata": {}, - "source": [ - "import torch\n", - "\n", - "trainer = pl.Trainer(\n", - " callbacks=[checkpoint_callback],\n", - " devices=\"auto\",\n", - " accelerator=\"auto\",\n", - " max_epochs=cfg.SOLVER.MAX_EPOCH,\n", - " deterministic=True,\n", - ")" - ], - "cell_type": "code", - "outputs": [], - "id": "e68e07bc", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "## Step 3: Model Training" - ], - "cell_type": "markdown", - "id": "1f9a4714" - }, - { - "metadata": {}, - "source": [ - "### Train\n", - "\n", - "After setting up the model and data loaders, we now start training the full DrugBAN model using the PyTorch Lightning Trainer via calling `trainer.fit()`.\n", - "\n", - "#### What Happens Here?\n", - "- The model receives batches of drug-protein pairs from the training data loader.\n", - "\n", - "- During each step, the GCN, CNN, BAN layer, and MLP classifier are updated to improve interaction prediction.\n", - "\n", - "- Validation is automatically run at the end of each epoch to track performance and save the best model based on AUROC.\n", - "\n", - "\n", - "This code block takes approximately 5 minutes to complete." 
- ], - "cell_type": "markdown", - "id": "b72634ee" - }, - { - "metadata": {}, - "source": [ - "trainer.fit(\n", - " drugban_trainer,\n", - " train_dataloaders=training_generator,\n", - " val_dataloaders=valid_generator,\n", - ")" - ], - "cell_type": "code", - "outputs": [], - "id": "0624b0c6", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "## Step 4: Evaluation\n", - "\n", - "Once training is complete, we evaluate the model on the test set using `trainer.test()`.\n", - "\n", - "### What is included in this step?\n", - "- The best model checkpoint (based on validation AUROC) is automatically loaded.\n", - "\n", - "- The model runs on the test data to generate predictions.\n", - "\n", - "- Final classification metrics, including AUROC, F1 score, accuracy, sensitivity, and specificity, are calculated and logged." - ], - "cell_type": "markdown", - "id": "23b3975c" - }, - { - "metadata": {}, - "source": [ - "trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path=\"best\")" - ], - "cell_type": "code", - "outputs": [], - "id": "c1415c02", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": [ - "### Performance Comparison\n", - "\n", - "The earlier example was a simple demonstration. To properly evaluate DrugBAN against baseline models, we train it for 100 epochs across multiple random seeds.\n", - "\n", - "We provide a checkpoint trained for 100 epochs in the `checkpoint` for your test after the tutorial. 
We will also use the provided checkpoint for the interpretation section for a better visualization.\n" - ], - "id": "bb0a08bec91d2bd9" - }, - { - "metadata": {}, - "source": [ - "The figure below shows the performance of different models on the BioSNAP and BindingDB datasets:\n", - "- Left plot: AUROC (Area Under the ROC Curve)\n", - "- Right plot: AUPRC (Area Under the Precision–Recall Curve)\n", - "\n", - "![](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs42256-022-00605-1/MediaObjects/42256_2022_605_Fig3_HTML.png?as=webp)\n", - "\n", - "The box plots show the median as the centre lines and the mean as green triangles. The minima and lower percentile represent the worst and second-worst scores. The maxima and upper percentile indicate the best and second-best scores. Supplementary Table 2 provides the data statistics of the BindingDB and BioSNAP datasets." - ], - "cell_type": "markdown", - "id": "37dbe9f3" - }, - { - "metadata": {}, - "source": [ - "## Step 5: Interpretation\n", - "\n", - "We interpret the trained models by analyzing the learned attention weights. In this step, we will use PyKale's API to\n", - "1) draw the attention maps of the Bilinear Attention Network (BAN) layer, and\n", - "2) generate molecule images with attention highlights.\n", - "\n", - "This helps us understand which parts of the drug contribute to the interaction with the target protein." - ], - "cell_type": "markdown", - "id": "02e3c73e" - }, - { - "metadata": {}, - "source": [ - "### Extracting Attention Weights\n", - "First, we need to load the test dataset and create a DataLoader for it. This will allow us to process the test samples in batches. We define functions to create the test dataset and DataLoader." 
- ], - "cell_type": "markdown", - "id": "4a56f260141b7368" - }, - { - "metadata": {}, - "source": [ - "def get_test_dataset(dataFolder):\n", - " df_test_target = pd.read_csv(dataFolder)\n", - " test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)\n", - " return test_target_dataset\n", - "\n", - "\n", - "def get_test_dataloader(dataset, batchsize, num_workers, collate_fn):\n", - " test_dataloader = DataLoader(\n", - " dataset,\n", - " batch_size=batchsize,\n", - " num_workers=num_workers,\n", - " collate_fn=collate_fn,\n", - " shuffle=False,\n", - " drop_last=True,\n", - " )\n", - " return test_dataloader" - ], - "cell_type": "code", - "id": "2c67553408592b2", - "outputs": [], - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We load a small subset of samples for testing from the provided `.csv` file. You can create your own `.csv` file with the same format to test your drug–protein pairs.", - "id": "ecdab66ee05da10c" - }, - { - "metadata": {}, - "source": [ - "test_dataFolder = \"/content/drug-target-interaction/data/drug-target-interaction/bindingdb/interpretation_samples.csv\"" - ], - "cell_type": "code", - "outputs": [], - "id": "7ef1867541d2577a", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We then build the test dataset and DataLoader using the functions defined above. 
The `batchsize` is set to 1 to ensure we process one sample at a time for attention visualization later.", - "id": "7fec5dc00a7b4aa4" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "test_dataset = get_test_dataset(test_dataFolder)\n", - "test_dataloader = get_test_dataloader(\n", - " test_dataset,\n", - " batchsize=1,\n", - " num_workers=cfg.SOLVER.NUM_WORKERS,\n", - " collate_fn=graph_collate_func,\n", - ")" - ], - "id": "c99a558c96a1ffd" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Then, we use the following function to load the trained model with the PyKale API.", - "id": "e1ff543d132abc42" - }, - { - "metadata": {}, - "source": [ - "def get_model_from_ckpt(ckpt_path, config):\n", - " return DrugbanTrainer.load_from_checkpoint(\n", - " checkpoint_path=ckpt_path,\n", - " model=DrugBAN(**config),\n", - " solver_lr=config.SOLVER.LEARNING_RATE,\n", - " num_classes=config.DECODER.BINARY,\n", - " batch_size=config.SOLVER.BATCH_SIZE,\n", - " # --- domain adaptation parameters ---\n", - " is_da=config.DA.USE,\n", - " solver_da_lr=config.SOLVER.DA_LEARNING_RATE,\n", - " da_init_epoch=config.DA.INIT_EPOCH,\n", - " da_method=config.DA.METHOD,\n", - " original_random=config.DA.ORIGINAL_RANDOM,\n", - " use_da_entropy=config.DA.USE_ENTROPY,\n", - " da_random_layer=config.DA.RANDOM_LAYER,\n", - " # --- discriminator parameters ---\n", - " da_random_dim=config.DA.RANDOM_DIM,\n", - " decoder_in_dim=config.DECODER.IN_DIM,\n", - " )" - ], - "cell_type": "code", - "outputs": [], - "id": "3b7f12b12b139799", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Once the model and test data are prepared, we extract attention maps from the trained model. 
We set the directory to the provided checkpoint file, load the trained model, and set it to evaluation mode.", - "id": "c0678dddcdf076fc" - }, - { - "metadata": {}, - "source": [ - "checkpoint_path = \"/content/drug-target-interaction/checkpoint/best.ckpt\"\n", - "model = get_model_from_ckpt(checkpoint_path, cfg)\n", - "model.model.eval()" - ], - "cell_type": "code", - "outputs": [], - "id": "d2a8931099b73c01", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We then iterate through the test DataLoader, passing each batch of drug and protein pairs to the model. The model's forward method returns the attention weights. After processing all batches, we concatenate the attention tensors into a single tensor.", - "id": "159d3fa67b29c9e9" - }, - { - "metadata": {}, - "source": [ - "from tqdm import tqdm\n", - "\n", - "all_attentions = []\n", - "for batch in tqdm(test_dataloader):\n", - " drug, protein, _ = batch\n", - " drug, protein = drug.to(model.device), protein.to(model.device)\n", - "\n", - " _, _, _, _, attention = model.model.forward(\n", - " drug, protein, mode=\"eval\"\n", - " ) # [B, H, V, Q]\n", - "\n", - " attention = attention.detach().cpu()\n", - " all_attentions.append(attention)\n", - "\n", - "# Concatenate into one tensor: [N, H, V, Q]\n", - "all_attentions = torch.cat(all_attentions, dim=0)\n", - "torch.save(all_attentions, \"attention_maps.pt\")\n", - "\n", - "all_attentions.shape" - ], - "cell_type": "code", - "outputs": [], - "id": "781a7762c36c72be", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "The attention has shape [B, H, V, Q] (Number of drug-target pairs, Heads of attentions, Drug tokens, Protein tokens).", - "id": "78dc763b6c0eef0" - }, - { - "metadata": {}, - "source": "### Visualize Attention Maps and Molecule Images", - "cell_type": "markdown", - "id": "8f72ea4d93f640cb" - }, - { - "metadata": {}, - "source": "Once attention maps are saved, run the 
visualization script:", - "cell_type": "markdown", - "id": "383c342a7c31d7ae" - }, - { - "metadata": {}, - "source": [ - "This script will:\n", - "\n", - "1) Load the attention weights and the corresponding SMILES + protein data.\n", - "\n", - "2) Plot:\n", - "\n", - " a) A heatmap of attention over drug–protein tokens.\n", - "\n", - " b) Molecular structures with atoms highlighted by attention values.\n", - "\n", - "The output images are saved in the `visualization` directory. You can also modify the `data_file` to use your own input in the same format as `target_test.csv`.\n", - "\n" - ], - "cell_type": "markdown", - "id": "d8a746169def8da5" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We first import the necessary PyKale APIs and set the output directory.", - "id": "aac54bfc67ce32eb" - }, - { - "metadata": {}, - "source": [ - "%pip install nilearn\n", - "from kale.interpret.visualize import draw_attention_map, draw_mol_with_attention\n", - "from kale.prepdata.tensor_reshape import normalize_tensor\n", - "\n", - "out_dir = \"./visualization\"\n", - "os.makedirs(out_dir, exist_ok=True)" - ], - "cell_type": "code", - "outputs": [], - "id": "d3c1d2e4cab69107", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We then load the attention maps, data, and SMILES strings from the test dataset.", - "id": "126b62034111d92a" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "attention = torch.load(\"attention_maps.pt\", map_location=\"cpu\")\n", - "data_df = pd.read_csv(test_dataFolder)\n", - "smiles = data_df[\"SMILES\"]\n", - "proteins = data_df[\"Protein\"]" - ], - "id": "7f70a6810c1c5e60" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We select the first sample from the attention maps and corresponding SMILES and protein sequence for visualization.", - "id": "d1a009bbb9f4a0f9" - }, - { - "metadata": {}, - "cell_type": "code", - 
"outputs": [], - "execution_count": null, - "source": [ - "index = 0\n", - "att_path = os.path.join(out_dir, f\"att_map_{index}.png\")\n", - "mol_path = os.path.join(out_dir, f\"mol_{index}.svg\")" - ], - "id": "e808c255fe862925" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "We crop the attention map to the actual lengths of the drug and protein sequences. This is important because the attention map may include padding tokens.", - "id": "438e6aa218e6b51d" - }, - { - "metadata": {}, - "cell_type": "code", - "outputs": [], - "execution_count": null, - "source": [ - "from rdkit import Chem\n", - "\n", - "\n", - "def get_real_length(smile, protein_sequence):\n", - " \"\"\"Get the real length of the drug and protein sequences.\"\"\"\n", - " mol = Chem.MolFromSmiles(smile)\n", - " return mol.GetNumAtoms(), len(protein_sequence)\n", - "\n", - "\n", - "att = attention[index] # [H, V, Q]\n", - "smile = smiles[index]\n", - "protein = proteins[index]\n", - "real_drug_len, real_prot_len = get_real_length(smile, protein)\n", - "att = att[:, :real_drug_len, :real_prot_len].mean(0) # [V, Q]\n", - "\n", - "# Normalize\n", - "att = normalize_tensor(att)" - ], - "id": "af15baa1c8caabc0" - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "Finally, we save the attention map and the molecule image with attention highlights.", - "id": "60a4ce71146a721e" - }, - { - "metadata": {}, - "source": [ - "draw_attention_map(\n", - " att,\n", - " att_path,\n", - " title=f\"Drug {index} Attention\",\n", - " xlabel=\"Drug Tokens\",\n", - " ylabel=\"Protein Tokens\",\n", - ")" - ], - "cell_type": "code", - "outputs": [], - "id": "403f77ada0ecc446", - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "draw_mol_with_attention(att.mean(dim=1), smile, mol_path)" - ], - "cell_type": "code", - "outputs": [], - "id": "b1003372361a66d6", - "execution_count": null - }, - { - "metadata": {}, - "cell_type": "markdown", - "source": "The output images are 
saved in the `visualization` directory. The attention map shows how much each drug token attends to each protein token, while the molecule image highlights the atoms based on their attention values.", - "id": "1999d67d6d14b263" - }, - { - "metadata": {}, - "source": [ - "## Extension Tasks" - ], - "cell_type": "markdown", - "id": "eeb308c3" - }, - { - "metadata": {}, - "source": [ - "### Task 1\n", - "\n", - "To use the BindingDB dataset, modify the relevant line in the Configuration section of Step 0 as shown below.\n", - "\n", - "```python\n", - "cfg.DATA.DATASET = \"bindingdb\"\n", - "```\n", - "\n", - "Reload the dataset and re-run training and testing.\n", - "\n", - "> Tip: See if the model struggles more or less with the new dataset. It can reveal how generalisable DrugBAN is.\n" - ], - "cell_type": "markdown", - "id": "aa2a83d8" - }, - { - "metadata": {}, - "source": [ - "### Task 2\n", - "\n", - "Turn off domain adaptation by updating the config file and re-running training and testing.\n", - "\n", - "Replace `configs/DA_cross_domain.yaml` with `configs/non_DA_cross_domain.yaml` in the Configuration section of Step 0 as shown below.\n", - "\n", - "```python\n", - "cfg.merge_from_file(\"configs/non_DA_cross_domain.yaml\")\n", - "```\n", - ">Tip: Compare the results with and without domain adaptation to see how it affects model performance." - ], - "cell_type": "markdown", - "id": "c94f174c" - } - ] + "cells": [ + { + "cell_type": "markdown", + "id": "8c1bf9c7", + "metadata": { + "id": "8c1bf9c7" + }, + "source": [ + "# Drug–Target Interaction Prediction\n", + "\n", + "![](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/drug-target-interaction/images/drugban-pyakle-api.png?raw=1)\n", + "\n", + "\n", + "In this tutorial, we will train models to predict the interaction between **two data modalities**: **molecules (drug)** and **proteins (target)** using `PyKale`. 
Drug-target interaction (DTI) plays a key role in drug discovery and identifying potential therapeutic targets. This example is based on the **DrugBAN** framework by [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1).\n", + "\n", + "The DTI prediction problem is formulated as a **binary classification task**, where the goal is to predict whether a given **drug–protein pair interacts or not**. The DrugBAN framework tackles this problem using two key ideas:\n", + "\n", + "- **Bilinear Attention Network (BAN)**, which learns detailed feature representations for both drugs and proteins and captures local interaction patterns between them.\n", + "\n", + "- **Adversarial Domain Adaptation**, which helps the model generalise to out-of-distribution datasets, i.e., in clustering-based cross-validation instead of random splits, improving its ability to predict interactions on unseen drug–target pairs.\n", + "\n", + "With `PyKale`, implementing such a multimodal DTI prediction pipeline is straightforward. The library provides ready-to-use modules and configuration support, making it easy to apply advanced techniques with minimal custom coding." + ] + }, + { + "cell_type": "markdown", + "id": "745ccdcf", + "metadata": { + "id": "745ccdcf" + }, + "source": [ + "## Step 0: Environment Preparation\n", + "\n", + "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial.\n", + "\n", + "To prepare the helper functions and necessary materials, we download them from the GitHub repository.\n", + "\n", + "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. 
The additional helper script is:\n", + "- [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a6028209", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a6028209", + "outputId": "bdba9c0d-e5a6-4dba-981d-b12e21d2c463" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'mmai-tutorials'...\n", + "remote: Enumerating objects: 610, done.\u001b[K\n", + "remote: Counting objects: 100% (242/242), done.\u001b[K\n", + "remote: Compressing objects: 100% (146/146), done.\u001b[K\n", + "remote: Total 610 (delta 156), reused 121 (delta 96), pack-reused 368 (from 2)\u001b[K\n", + "Receiving objects: 100% (610/610), 23.16 MiB | 16.03 MiB/s, done.\n", + "Resolving deltas: 100% (309/309), done.\n", + "mv: cannot move '/content/mmai-tutorials/tutorials/drug-target-interaction' to '/content/drug-target-interaction': Directory not empty\n", + "/content/drug-target-interaction\n", + "Changed working directory to: /content/drug-target-interaction\n" + ] + } + ], + "source": [ + "import os\n", + "\n", + "!rm -rf /content/mmai-tutorials\n", + "!git clone --single-branch -b main https://github.com/pykale/mmai-tutorials.git\n", + "%mv /content/mmai-tutorials/tutorials/drug-target-interaction /content/\n", + "%cd /content/drug-target-interaction\n", + "\n", + "print(\"Changed working directory to:\", os.getcwd())" + ] + }, + { + "cell_type": "markdown", + "id": "c52c6334", + "metadata": { + "id": "c52c6334" + }, + "source": [ + "### Package Installation\n", + "\n", + "The main package required for this tutorial is `PyKale`.\n", + "\n", + "`PyKale` is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and 
scientific domains.\n", + "\n", + "Then, we install `PyG` (PyTorch Geometric) and related packages.\n", + "\n", + "Please **do not** re-run this cell after the installation has completed. Running this installation multiple times will trigger issues related to `PyG`. If you need to re-run it, click `Runtime` in the top menu and choose `Disconnect and delete runtime` before installing again." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "53e3b14e", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "53e3b14e", + "outputId": "7890c8c4-6c6c-4156-bdf4-55ffbf9ee2b5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "pykale, gdown, pyg and yacs installed successfully ✅\n" + ] + } + ], + "source": [ + "%pip install --quiet \\\n", + " \"pykale[example]@git+https://github.com/pykale/pykale@main\" \\\n", + " gdown==5.2.0 torch-geometric==2.6.0 torch_sparse torch_scatter \\\n", + " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", + " && echo \"pykale, gdown, pyg and yacs installed successfully ✅\" \\\n", + " || echo \"Failed to install pykale, gdown, pyg and yacs ❌\"" + ] + }, + { + "cell_type": "markdown", + "id": "69f50b6a", + "metadata": { + "id": "69f50b6a" + }, + "source": [ + "We then suppress warning messages to keep the output clean." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "6e871c63", + "metadata": { + "id": "6e871c63" + }, + "outputs": [], + "source": [ + "import os\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" + ] + }, + { + "cell_type": "markdown", + "id": "6606e3fb", + "metadata": { + "id": "6606e3fb" + }, + "source": [ + "Exercise: Check NumPy Version" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0d384020", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0d384020", + "outputId": "d1ecfffa-1567-4b1c-dc8e-cbbeda3728c0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "NumPy version: 2.0.2\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "print(\"NumPy version:\", np.__version__) # numpy should be 2.0.0 or higher" + ] + }, + { + "cell_type": "markdown", + "id": "cabd3406", + "metadata": { + "id": "cabd3406" + }, + "source": [ + "### Configuration\n", + "\n", + "To keep configuration code in the notebook short, we provide a [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`configs/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs/DA_cross_domain.yaml)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "55c13b48", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "55c13b48", + "outputId": "dd3df032-9263-4bcf-d47c-3059d1f66830" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/content/drug-target-interaction\n" + ] + } + ], + "source": [ + "from configs import get_cfg_defaults\n", + "\n", + "%cd /content/drug-target-interaction\n", + "\n", + "cfg = get_cfg_defaults() # Load the default settings from configs.py\n", + "cfg.merge_from_file(\n", + " \"configs/DA_cross_domain.yaml\"\n", + ") # Update (or override) some of those settings using a custom YAML file" + ] + }, + { + "cell_type": "markdown", + "id": "74ffdbc2", + "metadata": { + "id": "74ffdbc2" + }, + "source": [ + "The following hyperparameters can be adjusted directly in the notebook, outside the `.yaml` file:\n", + "- `cfg.SOLVER.MAX_EPOCH`: Number of epochs in the training stage. You can reduce it to shorten runtime.\n", + "- `cfg.DATA.DATASET`: The dataset used in the study. This can be `bindingdb` or `biosnap`.\n", + "\n", + "As a quick exercise, please take a moment to review and understand the parameters in [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "424c7286", + "metadata": { + "id": "424c7286" + }, + "outputs": [], + "source": [ + "cfg.SOLVER.MAX_EPOCH = 2" + ] + }, + { + "cell_type": "markdown", + "id": "97c088fd", + "metadata": { + "id": "97c088fd" + }, + "source": [ + "You can also switch to a different dataset." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c69376fa", + "metadata": { + "id": "c69376fa" + }, + "outputs": [], + "source": [ + "cfg.DATA.DATASET = \"biosnap\"" + ] + }, + { + "cell_type": "markdown", + "id": "d3d41633", + "metadata": { + "id": "d3d41633" + }, + "source": [ + "Exercise: Now print the full configuration to check all current hyperparameter and dataset settings." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "45874296", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "45874296", + "outputId": "5c4738ff-85e8-463f-dac7-6e183643a223" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BCN:\n", + " HEADS: 2\n", + "COMET:\n", + " API_KEY: \n", + " EXPERIMENT_NAME: DA_cross_domain\n", + " PROJECT_NAME: drugban-23-May\n", + " TAG: DrugBAN_CDAN\n", + " USE: False\n", + "DA:\n", + " INIT_EPOCH: 10\n", + " LAMB_DA: 1\n", + " METHOD: CDAN\n", + " ORIGINAL_RANDOM: True\n", + " RANDOM_DIM: 256\n", + " RANDOM_LAYER: True\n", + " TASK: True\n", + " USE: True\n", + " USE_ENTROPY: False\n", + "DATA:\n", + " DATASET: biosnap\n", + " SPLIT: cluster\n", + "DECODER:\n", + " BINARY: 2\n", + " HIDDEN_DIM: 512\n", + " IN_DIM: 256\n", + " NAME: MLP\n", + " OUT_DIM: 128\n", + "DRUG:\n", + " HIDDEN_LAYERS: [128, 128, 128]\n", + " MAX_NODES: 290\n", + " NODE_IN_EMBEDDING: 128\n", + " NODE_IN_FEATS: 7\n", + " PADDING: True\n", + "PROTEIN:\n", + " EMBEDDING_DIM: 128\n", + " KERNEL_SIZE: [3, 6, 9]\n", + " NUM_FILTERS: [128, 128, 128]\n", + " PADDING: True\n", + "RESULT:\n", + " SAVE_MODEL: True\n", + "SOLVER:\n", + " BATCH_SIZE: 32\n", + " DA_LEARNING_RATE: 5e-05\n", + " LEARNING_RATE: 0.0001\n", + " MAX_EPOCH: 2\n", + " NUM_WORKERS: 0\n", + " SEED: 20\n" + ] + } + ], + "source": [ + "print(cfg)" + ] + }, + { + "cell_type": "markdown", + "id": "17558d0c", + "metadata": { + "id": "17558d0c" + }, + "source": [ + "## Step 1: Data Loading and Preparation\n", + "\n", + "In 
this tutorial, we use the **Biosnap** dataset for the main demonstration and the **BindingDB** dataset for the exercise at the end." + ] + }, + { + "cell_type": "markdown", + "id": "6c6071b9", + "metadata": { + "id": "6c6071b9" + }, + "source": [ + "### Data Downloading\n", + "\n", + "Please run the following cell to download the necessary datasets." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "56f9f58e", + "metadata": { + "id": "56f9f58e" + }, + "outputs": [], + "source": [ + "!rm -rf data\n", + "!mkdir data\n", + "!cd data\n", + "\n", + "!pip install -q gdown\n", + "!gdown --id 1ogOcxZn-1q418LOT-gQ94aHQV0Y1sOmk --output data/drug-target-interaction.zip\n", + "!unzip data/drug-target-interaction.zip -d data/\n", + "!mv data/drug-target-interaction/checkpoint ./" + ] + }, + { + "cell_type": "markdown", + "id": "c39b3e39", + "metadata": { + "id": "c39b3e39" + }, + "source": [ + "Exercise: Check that the data is ready" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a6258d1f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a6258d1f", + "outputId": "2813ad1f-5971-44ca-f9b6-c90a496fe4b6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Contents of the data folder:\n", + "biosnap\n", + "bindingdb\n", + "checkpoint\n" + ] + } + ], + "source": [ + "import os\n", + "import shutil\n", + "\n", + "print(\"Contents of the data folder:\")\n", + "for item in os.listdir(\"data/drug-target-interaction\"):\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "9ab0b5f833dc40f8", + "metadata": { + "id": "9ab0b5f833dc40f8" + }, + "source": [ + "The data content is structured as follows:\n", + "```sh\n", + " ├───data\n", + " │ ├───checkpoint\n", + " │ ├───bindingdb\n", + " │ ├───biosnap\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "5be1dcc62b7d5649", + "metadata": { + "id": "5be1dcc62b7d5649" + }, + "source": [ + "The `data` folder contains two datasets: 
`bindingdb` and `biosnap`. The contents of each dataset folder are listed below. The `checkpoint` folder contains the saved model checkpoint, which is used later in the interpretation section." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a93303c51c8b974e", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a93303c51c8b974e", + "outputId": "93e0c029-b489-4f4d-8889-8b99f868039c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Contents of bindingdb folder:\n", + "random\n", + "full.csv\n", + "interpretation_samples.csv\n", + "cluster\n" + ] + } + ], + "source": [ + "print(\"Contents of bindingdb folder:\")\n", + "for item in os.listdir(\"data/drug-target-interaction/bindingdb\"):\n", + " print(item)" + ] + }, + { + "cell_type": "markdown", + "id": "79cbc1c1", + "metadata": { + "id": "79cbc1c1" + }, + "source": [ + "Each dataset folder follows the structure:\n", + "\n", + "```sh\n", + " ├───dataset_name\n", + " │ ├───cluster\n", + " │ │ ├───source_train.csv\n", + " │ │ ├───target_train.csv\n", + " │ │ ├───target_test.csv\n", + " │ ├───random\n", + " │ │ ├───test.csv\n", + " │ │ ├───train.csv\n", + " │ │ ├───val.csv\n", + " │ ├───full.csv\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "d35e04f9", + "metadata": { + "id": "d35e04f9" + }, + "source": [ + "We use the `cluster` folder for cross-domain prediction. It contains three parts:\n", + "\n", + "- Train samples from the source domain: Drug–protein pairs the model learns from.\n", + "\n", + "- Train samples from the target domain: Additional training data from a different distribution to improve generalisation.\n", + "\n", + "- Test samples from the target domain: Unseen drug–protein pairs used to evaluate model performance on new data.\n", + "\n", + "The source and target sets are defined based on the clustering results." 
+ ] + }, + { + "cell_type": "markdown", + "id": "98acf744", + "metadata": { + "id": "98acf744" + }, + "source": [ + "### Data Loading" + ] + }, + { + "cell_type": "markdown", + "id": "1e5f4f44", + "metadata": { + "id": "1e5f4f44" + }, + "source": [ + "Here’s what each CSV file looks like in table format:\n", + "\n", + "| SMILES | Protein Sequence | Y |\n", + "|--------------------|--------------------------|---|\n", + "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n", + "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n", + "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n", + "\n", + "Each row of the dataset contains three key pieces of information:\n", + "\n", + "**Drugs**: \n", + "Drugs are often written as SMILES strings, which are like chemical formulas in text format (for example, `\"CC(=O)OC1=CC=CC=C1C(=O)O\"` is aspirin). \n", + "\n", + "\n", + "**Protein Sequence** \n", + "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n", + "\n", + "\n", + "**Y (Labels)**: \n", + "Each drug–protein pair is given a label:\n", + "- `1` if they interact\n", + "- `0` if they do not\n", + "\n", + "\n", + "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**): whether or not the drug and protein interact." + ] + }, + { + "cell_type": "markdown", + "id": "b7590daf", + "metadata": { + "id": "b7590daf" + }, + "source": [ + "You can load CSV files into Python using tools like `pandas`. The output shows a sample of the data, including the SMILES string for the drug, the protein sequence, the interaction label (Y), and the drug and target cluster IDs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0c709e31", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0c709e31", + "outputId": "5e663c15-f59f-4a0f-acc3-b545b6392a1e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sample example: SMILES CC1=CN=C2N1C=CN=C2NCC1=CC=NC=C1\n", + "Protein MARSLLLPLQILLLSLALETAGEEAQGDKIIDGAPCARGSHPWQVA...\n", + "Y 0.0\n", + "drug_cluster 1904\n", + "target_cluster 1528\n", + "Name: 0, dtype: object\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "dataFolder = os.path.join(\n", + " f\"data/drug-target-interaction/{cfg.DATA.DATASET}\", str(cfg.DATA.SPLIT)\n", + ")\n", + "\n", + "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", + "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", + "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", + "\n", + "print(\"Sample example:\", df_train_source.iloc[0])" + ] + }, + { + "cell_type": "markdown", + "id": "542d4e69", + "metadata": { + "id": "542d4e69" + }, + "source": [ + "### Data Preprocessing\n", + "\n", + "We convert drug SMILES strings into molecular graphs using `kale.loaddata.molecular_datasets.smiles_to_graph`, encoding atom-level features as node attributes and bond types as edges.\n", + "\n", + "\n", + "Protein sequences are transformed into fixed-length integer arrays using `kale.prepdata.chem_transform.integer_label_protein`, with each amino acid mapped to an integer and sequences padded or truncated to a uniform length.\n", + "\n", + "Finally, the `kale.loaddata.molecular_datasets.DTIDataset` class packages drugs, proteins, and labels into a PyTorch-ready dataset." 
+ ] + }, + { + "cell_type": "markdown", + "id": "981d5520", + "metadata": { + "id": "981d5520" + }, + "source": [ + "**Note:** If you encounter an error related to requiring numpy `<2.0`, simply ignore it and re-run this block until it completes successfully." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ae5af8eb", + "metadata": { + "id": "ae5af8eb" + }, + "outputs": [], + "source": [ + "from kale.loaddata.molecular_datasets import DTIDataset\n", + "\n", + "# Create preprocessed datasets\n", + "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", + "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", + "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" + ] + }, + { + "cell_type": "markdown", + "id": "a0a510ce", + "metadata": { + "id": "a0a510ce" + }, + "source": [ + "We load data in small, manageable pieces called batches to save memory and speed up training. We use `kale.loaddata.sampler.MultiDataLoader` from PyKale to load one batch from the source domain and one from the target domain at each training step." 
+ ] + }, + { + "cell_type": "markdown", + "id": "c09084c0", + "metadata": { + "id": "c09084c0" + }, + "source": [ + "First, we specify a few DataLoader parameters:\n", + "- Batch size: Number of samples per batch\n", + "- Shuffle: Randomly shuffle data\n", + "- Number of workers: Parallel data loading\n", + "- Drop last: Discard the last incomplete batch for consistent batch sizes\n", + "- Collate function: Use graph_collate_func to batch variable-sized molecular graphs" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "94a15868", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "94a15868", + "outputId": "a4c14890-db12-45b7-bcbe-1a94b4f846b2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'batch_size': 32,\n", + " 'shuffle': True,\n", + " 'num_workers': 0,\n", + " 'drop_last': True,\n", + " 'collate_fn': }" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from kale.loaddata.molecular_datasets import graph_collate_func\n", + "from kale.loaddata.sampler import MultiDataLoader\n", + "\n", + "params = {\n", + " \"batch_size\": cfg.SOLVER.BATCH_SIZE,\n", + " \"shuffle\": True,\n", + " \"num_workers\": cfg.SOLVER.NUM_WORKERS,\n", + " \"drop_last\": True,\n", + " \"collate_fn\": graph_collate_func,\n", + "}\n", + "\n", + "params" + ] + }, + { + "cell_type": "markdown", + "id": "e884ed07", + "metadata": { + "id": "e884ed07" + }, + "source": [ + "Then, we create a DataLoader from both the source and target datasets for training." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "24ba12b5", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "24ba12b5", + "outputId": "47a5c085-d2d0-48f8-cad2-4483e3fe2efa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using domain adaptation: True\n" + ] + } + ], + "source": [ + "print(\"Using domain adaptation:\", cfg.DA.USE)\n", + "\n", + "if not cfg.DA.USE:\n", + " training_generator = DataLoader(train_dataset, **params)\n", + "else:\n", + " source_generator = DataLoader(train_dataset, **params)\n", + " target_generator = DataLoader(train_target_dataset, **params)\n", + "\n", + " # Get the number of batches in the longer dataset to align both\n", + " n_batches = max(len(source_generator), len(target_generator))\n", + "\n", + " # Combine the source and target data loaders using MultiDataLoader\n", + " training_generator = MultiDataLoader(\n", + " dataloaders=[source_generator, target_generator], n_batches=n_batches\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "649301de", + "metadata": { + "id": "649301de" + }, + "source": [ + "Lastly, we set up DataLoaders for validation and testing. Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b4cf543a", + "metadata": { + "id": "b4cf543a" + }, + "outputs": [], + "source": [ + "# Update parameters for validation/testing (no shuffling, keep all data)\n", + "params.update({\"shuffle\": False, \"drop_last\": False})\n", + "\n", + "# Create validation and test data loaders\n", + "valid_generator = DataLoader(test_target_dataset, **params)\n", + "test_generator = DataLoader(test_target_dataset, **params)" + ] + }, + { + "cell_type": "markdown", + "id": "e474eea2", + "metadata": { + "id": "e474eea2" + }, + "source": [ + "### Exercise: Dataset Inspection\n", + "\n", + "Once the dataset is ready, let’s inspect one sample from the training data to check the input graph, protein sequence, and label format." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "31b8a93f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "31b8a93f", + "outputId": "74ae5660-3d74-4c5e-f5ed-065c1f099516" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First sample from source batch:\n", + "Drug graph: Data(x=[290, 7], edge_index=[2, 106], edge_attr=[106, 1], num_nodes=290)\n", + "Protein sequence: tensor([11., 7., 18., ..., 0., 0., 0.], dtype=torch.float64)\n", + "Label: tensor(0., dtype=torch.float64)\n" + ] + } + ], + "source": [ + "# Get the first batch (contains one batch from source and one from target)\n", + "first_batch = next(iter(training_generator))\n", + "\n", + "# Unpack source and target batches\n", + "source_batch, target_batch = first_batch\n", + "\n", + "# Inspect the first sample from the source batch\n", + "print(\"First sample from source batch:\")\n", + "print(\"Drug graph:\", source_batch[0][0])\n", + "print(\"Protein sequence:\", source_batch[1][0])\n", + "print(\"Label:\", source_batch[2][0])" + ] + }, + { + "cell_type": "markdown", + "id": "cb0b269b", + "metadata": { + "id": "cb0b269b" + }, + "source": [ + 
"This sample is a tuple with three parts:\n", + "\n", + "1. **Drug Graph**\n", + "- `x=[290, 7]`: Feature matrix with 290 nodes (atoms, padded up to `cfg.DRUG.MAX_NODES`) and 7 features per node.\n", + "- `edge_index=[2, 106]`: Shows 106 edges, stored as pairs of source and target node indices.\n", + "- `edge_attr=[106, 1]`: Each edge has 1 bond feature, such as bond type.\n", + "- `num_nodes=290`: Confirms the graph has 290 nodes.\n", + "\n", + "2. **Protein Features (array)**\n", + "- Example values: `[11., 7., 18., ..., 0., 0., 0.]`: A fixed-length numeric array representing the protein sequence. Each position holds an integer-encoded amino acid, with zeros for padding.\n", + "\n", + "3. **Label (float)**\n", + "- `0.0`: The ground-truth interaction label, here indicating no interaction." + ] + }, + { + "cell_type": "markdown", + "id": "8eaf5c8f", + "metadata": { + "id": "8eaf5c8f" + }, + "source": [ + "## Step 2: Model Definition" + ] + }, + { + "cell_type": "markdown", + "id": "b2819549", + "metadata": { + "id": "b2819549" + }, + "source": [ + "### Embed\n", + "\n", + "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n", + "\n", + "The DrugBAN class is defined in `kale.embed.ban`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "1c8f3acc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "1c8f3acc", + "outputId": "65f82225-7219-4391-abf2-2194a00c3af0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DrugBAN(\n", + " (drug_extractor): MolecularGCN(\n", + " (init_transform): Linear(in_features=7, out_features=128, bias=False)\n", + " (gcn_layers): ModuleList(\n", + " (0-2): 3 x GCNConv(128, 128)\n", + " )\n", + " )\n", + " (protein_extractor): ProteinCNN(\n", + " (embedding): Embedding(26, 128, padding_idx=0)\n", + " (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,))\n", + " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv2): Conv1d(128, 128, kernel_size=(6,), stride=(1,))\n", + " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv3): Conv1d(128, 128, kernel_size=(9,), stride=(1,))\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (bcn): BANLayer(\n", + " (v_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (q_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (p_net): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))\n", + " (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (mlp_classifier): MLPDecoder(\n", + " (fc1): Linear(in_features=256, out_features=512, bias=True)\n", + " (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", 
+ " (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc3): Linear(in_features=512, out_features=128, bias=True)\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc4): Linear(in_features=128, out_features=2, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "source": [ + "from kale.embed.ban import DrugBAN\n", + "\n", + "model = DrugBAN(**cfg)\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "id": "32084f24", + "metadata": { + "id": "32084f24" + }, + "source": [ + "### Predict\n", + "We use the PyKale pipeline API `kale.pipeline.drugban_trainer` to connect the dataloaders, encoders and decoders for model training and evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "46e2b9b4", + "metadata": { + "id": "46e2b9b4" + }, + "outputs": [], + "source": [ + "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", + "\n", + "drugban_trainer = DrugbanTrainer(\n", + " model=DrugBAN(**cfg),\n", + " solver_lr=cfg.SOLVER.LEARNING_RATE,\n", + " num_classes=cfg.DECODER.BINARY,\n", + " batch_size=cfg.SOLVER.BATCH_SIZE,\n", + " is_da=cfg.DA.USE,\n", + " solver_da_lr=cfg.SOLVER.DA_LEARNING_RATE,\n", + " da_init_epoch=cfg.DA.INIT_EPOCH,\n", + " da_method=cfg.DA.METHOD,\n", + " original_random=cfg.DA.ORIGINAL_RANDOM,\n", + " use_da_entropy=cfg.DA.USE_ENTROPY,\n", + " da_random_layer=cfg.DA.RANDOM_LAYER,\n", + " da_random_dim=cfg.DA.RANDOM_DIM,\n", + " decoder_in_dim=cfg.DECODER.IN_DIM,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a48c86b9", + "metadata": { + "id": "a48c86b9" + }, + "source": [ + "We want to save the best model during training so that we can reuse it later without retraining. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score."
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7754bd38", + "metadata": { + "id": "7754bd38" + }, + "outputs": [], + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filename=\"{epoch}-{step}-{val_BinaryAUROC:.4f}\",\n", + " monitor=\"val_BinaryAUROC\",\n", + " mode=\"max\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "969beac0", + "metadata": { + "id": "969beac0" + }, + "source": [ + "We now create the PyTorch Lightning `Trainer`, which registers the checkpoint callback, selects the device automatically, and caps training at `cfg.SOLVER.MAX_EPOCH` epochs." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "e68e07bc", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "e68e07bc", + "outputId": "08a9b744-3fb5-48c7-863f-dd64afc4dd80" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint_callback],\n", + " devices=\"auto\",\n", + " accelerator=\"auto\",\n", + " max_epochs=cfg.SOLVER.MAX_EPOCH,\n", + " deterministic=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1f9a4714", + "metadata": { + "id": "1f9a4714" + }, + "source": [ + "## Step 3: Model Training" + ] + }, + { + "cell_type": "markdown", + "id": "b72634ee", + "metadata": { + "id": "b72634ee" + }, + "source": [ + "### Train\n", + "\n", + "With the model and data loaders set up, we now train the full DrugBAN model with the PyTorch Lightning `Trainer` by calling `trainer.fit()`.\n", + "\n", + "#### What Happens Here?\n", + "- The model receives batches of drug-protein pairs from the training data loader.\n", + "\n", + "-
During each step, the GCN, CNN, BAN layer, and MLP classifier are updated to improve interaction prediction.\n", + "\n", + "- Validation is automatically run at the end of each epoch to track performance and save the best model based on AUROC.\n", + "\n", + "\n", + "This code block takes approximately 5 minutes to complete." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "0624b0c6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 424, + "referenced_widgets": [ + "bc415b5a5635482eb20a65601866febf", + "98992d70c9d74d3fa794edf0dc9333b1", + "32de4bd89e034c35adc198247950d4bf", + "12c8f4410cc74d4f822c779c724bce94", + "fe6bdb21df9c4415b9899aaa96969502", + "65bb75f11b1f4064b715a166bed1e215", + "a893676474004fb0a24023403f5acf46", + "941d8002127f499c9868abffea2a2429", + "2fc478f35c3e48f2b0c13bd9a73f3dc6", + "79368bfcf5cc4067a59a46c242a77e2d", + "8632ffdd1b5844f183cc28e824cd117e", + "9a8deecaaef543fbb18d0f50b83b5abc", + "b113d7ba3f50417fb93de1118b4e4dec", + "4b0bcd88167b469da36e8433a4d47377", + "1ed13da64943461ab42ab18495a6246b", + "7d4fd3d9c5ed4cc6af20be7384a758ac", + "635b3ed7aa264fa6907187152359a63f", + "9d5ebfa060ac49d6bd7a8c4837b4fc29", + "a80aad59381c42edaa2adb89d781ddd6", + "b71c219972474c2c89ed04fa48ec637d", + "8f16c1ee593d406c88c9bfcf32fdd3df", + "16cdd1f5e14e405490575177c52f4408", + "07d966a71f604f63b00707e6d3a0bfe6", + "37f06572693f4972b468c3997fd0687c", + "a0cec4295099427bb8e97229bd49d620", + "931700f44cb0491ca187cbc58dc67476", + "dd7ad2a05e22470ab00e2118ff994147", + "1aa1b891fe6b447586d3e87ee62768a2", + "8618e343c4f1435ebf099bd70605606b", + "33a0e7cd7f7d49ffa57e962ca2bf0b66", + "f223529cf08b4235b8e74ad1049c5a60", + "304a1cadd7c048a2a34568301fb1c4dd", + "e809f1b951f740ddabc3a3e56d4dd903", + "f60cfca35cab44e5801bbb2b04f9cc6f", + "2f468dcdec6d4c6db229fcab061fd0d8", + "530bec1df42241a9be6a34c561be4a36", + "700bde0d67f64f9d8680c483d232fa5b", + "abf0b363155c4b1fb8de818ae41606eb", + 
"301141b0856941cba6a0a1f496267c4e", + "b4f64adc9af34262bb35ea42f0f58899", + "3c34cab162424e51b4dcf0008c39c7a0", + "b2e51980648f49f2963ba66745b82b57", + "b9b44b6940884c359ebb60834280b401", + "807b512710c848eabb6816cf8e9ecc75" + ] + }, + "id": "0624b0c6", + "outputId": "8a7fb16d-707f-4a4e-edd8-dec451f5fa5f" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO:pytorch_lightning.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | model | DrugBAN | 1.0 M | train\n", + "1 | domain_discriminator | DomainNetSmallImage | 133 K | train\n", + "2 | random_layer | RandomLayer | 66.0 K | train\n", + "3 | valid_metrics | MetricCollection | 0 | train\n", + "4 | test_metrics | MetricCollection | 0 | train\n", + "---------------------------------------------------------------------\n", + "1.2 M Trainable params\n", + "0 Non-trainable params\n", + "1.2 M Total params\n", + "4.847 Total estimated model params size (MB)\n", + "64 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bc415b5a5635482eb20a65601866febf", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_BinaryAUROC 0.48449382185935974 │\n", + "│ test_BinaryAccuracy 0.5071665048599243 │\n", + "│ test_BinaryF1Score 0.0933062881231308 │\n", + "│ test_BinaryRecall 0.050549451261758804 │\n", + "│ test_BinarySpecificity 0.9668141603469849 │\n", + "│ test_accuracy_sklearn 0.5038588643074036 │\n", + "│ test_auroc_sklearn 0.48449382185935974 │\n", + "│ test_f1_sklearn 0.6671618223190308 │\n", + "│ test_loss 0.8901852369308472 │\n", + "│ test_optim_threshold 0.07649494707584381 │\n", + "│ test_sensitivity 0.006637168116867542 │\n", + "│ test_specificity 0.997802197933197 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5071665048599243 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0933062881231308 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.050549451261758804 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9668141603469849 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m 
\u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5038588643074036 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6671618223190308 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8901852369308472 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.07649494707584381 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.006637168116867542 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.997802197933197 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[{'test_loss': 0.8901852369308472,\n", + " 'test_auroc_sklearn': 0.48449382185935974,\n", + " 'test_accuracy_sklearn': 0.5038588643074036,\n", + " 'test_f1_sklearn': 0.6671618223190308,\n", + " 'test_optim_threshold': 0.07649494707584381,\n", + " 'test_sensitivity': 0.006637168116867542,\n", + " 'test_specificity': 0.997802197933197,\n", + " 'test_BinaryAUROC': 0.48449382185935974,\n", + " 'test_BinaryF1Score': 0.0933062881231308,\n", + " 'test_BinaryRecall': 0.050549451261758804,\n", + " 'test_BinarySpecificity': 0.9668141603469849,\n", + " 'test_BinaryAccuracy': 0.5071665048599243}]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": 
"execute_result" + } + ], + "source": [ + "trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path=\"best\")" + ] + }, + { + "cell_type": "markdown", + "id": "bb0a08bec91d2bd9", + "metadata": { + "id": "bb0a08bec91d2bd9" + }, + "source": [ + "### Performance Comparison\n", + "\n", + "The earlier example was a simple demonstration. To properly evaluate DrugBAN against baseline models, we train it for 100 epochs across multiple random seeds.\n", + "\n", + "We provide a checkpoint trained for 100 epochs in the `checkpoint` for your test after the tutorial. We will also use the provided checkpoint for the interpretation section for a better visualization.\n" + ] + }, + { + "cell_type": "markdown", + "id": "37dbe9f3", + "metadata": { + "id": "37dbe9f3" + }, + "source": [ + "The figure below shows the performance of different models on the BioSNAP and BindingDB datasets:\n", + "- Left plot: AUROC (Area Under the ROC Curve)\n", + "- Right plot: AUPRC (Area Under the Precision–Recall Curve)\n", + "\n", + "![](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs42256-022-00605-1/MediaObjects/42256_2022_605_Fig3_HTML.png?as=webp)\n", + "\n", + "The box plots show the median as the centre lines and the mean as green triangles. The minima and lower percentile represent the worst and second-worst scores. The maxima and upper percentile indicate the best and second-best scores. Supplementary Table 2 provides the data statistics of the BindingDB and BioSNAP datasets." + ] + }, + { + "cell_type": "markdown", + "id": "02e3c73e", + "metadata": { + "id": "02e3c73e" + }, + "source": [ + "## Step 5: Interpretation\n", + "\n", + "We interpret the trained models by analyzing the learned attention weights. 
In this step, we will use PyKale's API to\n", + "1) draw the attention maps of the Bilinear Attention Network (BAN) layer, and\n", + "2) generate molecule images with attention highlights.\n", + "\n", + "This helps us understand which parts of the drug contribute to the interaction with the target protein." + ] + }, + { + "cell_type": "markdown", + "id": "4a56f260141b7368", + "metadata": { + "id": "4a56f260141b7368" + }, + "source": [ + "### Extracting Attention Weights\n", + "First, we need to load the test dataset and create a DataLoader for it. This will allow us to process the test samples in batches. We define functions to create the test dataset and DataLoader." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "2c67553408592b2", + "metadata": { + "id": "2c67553408592b2" + }, + "outputs": [], + "source": [ + "def get_test_dataset(dataFolder):\n", + " df_test_target = pd.read_csv(dataFolder)\n", + " test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)\n", + " return test_target_dataset\n", + "\n", + "\n", + "def get_test_dataloader(dataset, batchsize, num_workers, collate_fn):\n", + " test_dataloader = DataLoader(\n", + " dataset,\n", + " batch_size=batchsize,\n", + " num_workers=num_workers,\n", + " collate_fn=collate_fn,\n", + " shuffle=False,\n", + " drop_last=True,\n", + " )\n", + " return test_dataloader" + ] + }, + { + "cell_type": "markdown", + "id": "ecdab66ee05da10c", + "metadata": { + "id": "ecdab66ee05da10c" + }, + "source": [ + "We load a small subset of samples for testing from the provided `.csv` file. You can create your own `.csv` file with the same format to test your drug–protein pairs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "7ef1867541d2577a", + "metadata": { + "id": "7ef1867541d2577a" + }, + "outputs": [], + "source": [ + "test_dataFolder = \"/content/drug-target-interaction/data/drug-target-interaction/bindingdb/interpretation_samples.csv\"" + ] + }, + { + "cell_type": "markdown", + "id": "7fec5dc00a7b4aa4", + "metadata": { + "id": "7fec5dc00a7b4aa4" + }, + "source": [ + "We then build the test dataset and DataLoader using the functions defined above. The `batchsize` is set to 1 to ensure we process one sample at a time for attention visualization later." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "c99a558c96a1ffd", + "metadata": { + "id": "c99a558c96a1ffd" + }, + "outputs": [], + "source": [ + "test_dataset = get_test_dataset(test_dataFolder)\n", + "test_dataloader = get_test_dataloader(\n", + " test_dataset,\n", + " batchsize=1,\n", + " num_workers=cfg.SOLVER.NUM_WORKERS,\n", + " collate_fn=graph_collate_func,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e1ff543d132abc42", + "metadata": { + "id": "e1ff543d132abc42" + }, + "source": [ + "Then, we use the following function to load the trained model with the PyKale API." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "3b7f12b12b139799", + "metadata": { + "id": "3b7f12b12b139799" + }, + "outputs": [], + "source": [ + "def get_model_from_ckpt(ckpt_path, config):\n", + " return DrugbanTrainer.load_from_checkpoint(\n", + " checkpoint_path=ckpt_path,\n", + " model=DrugBAN(**config),\n", + " solver_lr=config.SOLVER.LEARNING_RATE,\n", + " num_classes=config.DECODER.BINARY,\n", + " batch_size=config.SOLVER.BATCH_SIZE,\n", + " # --- domain adaptation parameters ---\n", + " is_da=config.DA.USE,\n", + " solver_da_lr=config.SOLVER.DA_LEARNING_RATE,\n", + " da_init_epoch=config.DA.INIT_EPOCH,\n", + " da_method=config.DA.METHOD,\n", + " original_random=config.DA.ORIGINAL_RANDOM,\n", + " use_da_entropy=config.DA.USE_ENTROPY,\n", + " da_random_layer=config.DA.RANDOM_LAYER,\n", + " # --- discriminator parameters ---\n", + " da_random_dim=config.DA.RANDOM_DIM,\n", + " decoder_in_dim=config.DECODER.IN_DIM,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "c0678dddcdf076fc", + "metadata": { + "id": "c0678dddcdf076fc" + }, + "source": [ + "With the test data prepared, we can extract attention maps from the trained model. We first set `checkpoint_path` to the provided checkpoint file, load the trained model from it, and switch the model to evaluation mode."
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "d2a8931099b73c01", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "d2a8931099b73c01", + "outputId": "1dbad556-912d-44f3-88ad-059cf7876a36" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DrugBAN(\n", + " (drug_extractor): MolecularGCN(\n", + " (init_transform): Linear(in_features=7, out_features=128, bias=False)\n", + " (gcn_layers): ModuleList(\n", + " (0-2): 3 x GCNConv(128, 128)\n", + " )\n", + " )\n", + " (protein_extractor): ProteinCNN(\n", + " (embedding): Embedding(26, 128, padding_idx=0)\n", + " (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,))\n", + " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv2): Conv1d(128, 128, kernel_size=(6,), stride=(1,))\n", + " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv3): Conv1d(128, 128, kernel_size=(9,), stride=(1,))\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (bcn): BANLayer(\n", + " (v_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (q_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (p_net): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))\n", + " (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (mlp_classifier): MLPDecoder(\n", + " (fc1): Linear(in_features=256, out_features=512, bias=True)\n", + " (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", + " (bn2): 
BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc3): Linear(in_features=512, out_features=128, bias=True)\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc4): Linear(in_features=128, out_features=2, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "checkpoint_path = \"/content/drug-target-interaction/checkpoint/best.ckpt\"\n", + "model = get_model_from_ckpt(checkpoint_path, cfg)\n", + "model.model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "159d3fa67b29c9e9", + "metadata": { + "id": "159d3fa67b29c9e9" + }, + "source": [ + "We then iterate through the test DataLoader, passing each batch of drug and protein pairs to the model. The model's forward method returns the attention weights. After processing all batches, we concatenate the attention tensors into a single tensor." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "781a7762c36c72be", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "781a7762c36c72be", + "outputId": "fade7fb6-bea8-438a-9688-8cffd8dc9c47" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 6/6 [00:00<00:00, 65.52it/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "torch.Size([6, 2, 290, 1185])" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from tqdm import tqdm\n", + "\n", + "all_attentions = []\n", + "for batch in tqdm(test_dataloader):\n", + " drug, protein, _ = batch\n", + " drug, protein = drug.to(model.device), protein.to(model.device)\n", + "\n", + " _, _, _, _, attention = model.model.forward(\n", + " drug, protein, mode=\"eval\"\n", + " ) # [B, H, V, Q]\n", + "\n", + " attention = attention.detach().cpu()\n", + " all_attentions.append(attention)\n", + "\n", + "# 
Concatenate into one tensor: [N, H, V, Q]\n", + "all_attentions = torch.cat(all_attentions, dim=0)\n", + "torch.save(all_attentions, \"attention_maps.pt\")\n", + "\n", + "all_attentions.shape" + ] + }, + { + "cell_type": "markdown", + "id": "78dc763b6c0eef0", + "metadata": { + "id": "78dc763b6c0eef0" + }, + "source": [ + "The concatenated attention tensor has shape `[N, H, V, Q]`: the number of drug–target pairs, the number of attention heads, the number of drug tokens, and the number of protein tokens." + ] + }, + { + "cell_type": "markdown", + "id": "8f72ea4d93f640cb", + "metadata": { + "id": "8f72ea4d93f640cb" + }, + "source": [ + "### Visualize Attention Maps and Molecule Images" + ] + }, + { + "cell_type": "markdown", + "id": "383c342a7c31d7ae", + "metadata": { + "id": "383c342a7c31d7ae" + }, + "source": [ + "Once the attention maps are saved, the remaining cells in this step visualize them." + ] + }, + { + "cell_type": "markdown", + "id": "d8a746169def8da5", + "metadata": { + "id": "d8a746169def8da5" + }, + "source": [ + "These cells will:\n", + "\n", + "1) Load the attention weights and the corresponding SMILES + protein data.\n", + "\n", + "2) Plot:\n", + "\n", + " a) A heatmap of attention over drug–protein tokens.\n", + "\n", + " b) Molecular structures with atoms highlighted by attention values.\n", + "\n", + "The output images are saved in the `visualization` directory. You can also modify the `data_file` to use your own input in the same format as `target_test.csv`.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "aac54bfc67ce32eb", + "metadata": { + "id": "aac54bfc67ce32eb" + }, + "source": [ + "We first import the necessary PyKale APIs and set the output directory."
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "d3c1d2e4cab69107", + "metadata": { + "id": "d3c1d2e4cab69107" + }, + "outputs": [], + "source": [ + "from kale.interpret.visualize import draw_attention_map, draw_mol_with_attention\n", + "from kale.prepdata.tensor_reshape import normalize_tensor\n", + "\n", + "out_dir = \"./visualization\"\n", + "os.makedirs(out_dir, exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "126b62034111d92a", + "metadata": { + "id": "126b62034111d92a" + }, + "source": [ + "We then load the attention maps, data, and SMILES strings from the test dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "7f70a6810c1c5e60", + "metadata": { + "id": "7f70a6810c1c5e60" + }, + "outputs": [], + "source": [ + "attention = torch.load(\"attention_maps.pt\", map_location=\"cpu\")\n", + "data_df = pd.read_csv(test_dataFolder)\n", + "smiles = data_df[\"SMILES\"]\n", + "proteins = data_df[\"Protein\"]" + ] + }, + { + "cell_type": "markdown", + "id": "d1a009bbb9f4a0f9", + "metadata": { + "id": "d1a009bbb9f4a0f9" + }, + "source": [ + "We select the first sample from the attention maps and corresponding SMILES and protein sequence for visualization." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "e808c255fe862925", + "metadata": { + "id": "e808c255fe862925" + }, + "outputs": [], + "source": [ + "index = 0\n", + "att_path = os.path.join(out_dir, f\"att_map_{index}.png\")\n", + "mol_path = os.path.join(out_dir, f\"mol_{index}.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "438e6aa218e6b51d", + "metadata": { + "id": "438e6aa218e6b51d" + }, + "source": [ + "We crop the attention map to the actual lengths of the drug and protein sequences. This is important because the attention map may include padding tokens." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "af15baa1c8caabc0", + "metadata": { + "id": "af15baa1c8caabc0" + }, + "outputs": [], + "source": [ + "from rdkit import Chem\n", + "\n", + "\n", + "def get_real_length(smile, protein_sequence):\n", + " \"\"\"Get the real length of the drug and protein sequences.\"\"\"\n", + " mol = Chem.MolFromSmiles(smile)\n", + " return mol.GetNumAtoms(), len(protein_sequence)\n", + "\n", + "\n", + "att = attention[index] # [H, V, Q]\n", + "smile = smiles[index]\n", + "protein = proteins[index]\n", + "real_drug_len, real_prot_len = get_real_length(smile, protein)\n", + "att = att[:, :real_drug_len, :real_prot_len].mean(0) # [V, Q]\n", + "\n", + "# Normalize\n", + "att = normalize_tensor(att)" + ] + }, + { + "cell_type": "markdown", + "id": "60a4ce71146a721e", + "metadata": { + "id": "60a4ce71146a721e" + }, + "source": [ + "Finally, we save the attention map and the molecule image with attention highlights." + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "403f77ada0ecc446", + "metadata": { + "id": "403f77ada0ecc446" + }, + "outputs": [], + "source": [ + "draw_attention_map(\n", + " att,\n", + " att_path,\n", + " title=f\"Drug {index} Attention\",\n", + " xlabel=\"Drug Tokens\",\n", + " ylabel=\"Protein Tokens\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "b1003372361a66d6", + "metadata": { + "id": "b1003372361a66d6" + }, + "outputs": [], + "source": [ + "draw_mol_with_attention(att.mean(dim=1), smile, mol_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "4mHWCbJmGMgG", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 921 + }, + "id": "4mHWCbJmGMgG", + "outputId": "5b3f1960-85a5-4ca3-bbbb-f48f6bc67401", + "tags": [ + "hide-input" + ] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + 
"image/svg+xml": [ + "\n", + "\n", + " \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import Image, SVG\n", + "\n", + "attention_plot = Image(att_path)\n", + "molecular_img = SVG(mol_path)\n", + "display(attention_plot)\n", + "display(molecular_img)" + ] + }, + { + "cell_type": "markdown", + "id": "1999d67d6d14b263", + "metadata": { + "id": "1999d67d6d14b263" + }, + "source": [ + "The output images are saved in the `visualization` directory. The attention map shows how much each drug token attends to each protein token, while the molecule image highlights the atoms based on their attention values." 
+ ] + }, + { + "cell_type": "markdown", + "id": "eeb308c3", + "metadata": { + "id": "eeb308c3" + }, + "source": [ + "## Extension Tasks" + ] + }, + { + "cell_type": "markdown", + "id": "aa2a83d8", + "metadata": { + "id": "aa2a83d8" + }, + "source": [ + "### Task 1\n", + "\n", + "To use the BindingDB dataset, modify the relevant line in the Configuration section of Step 0 as shown below.\n", + "\n", + "```python\n", + "cfg.DATA.DATASET = \"bindingdb\"\n", + "```\n", + "\n", + "Reload the dataset and re-run training and testing.\n", + "\n", + "> Tip: See if the model struggles more or less with the new dataset. It can reveal how generalisable DrugBAN is.\n" + ] + }, + { + "cell_type": "markdown", + "id": "c94f174c", + "metadata": { + "id": "c94f174c" + }, + "source": [ + "### Task 2\n", + "\n", + "Turn off domain adaptation by updating the config file and re-running training and testing.\n", + "\n", + "Replace `configs/DA_cross_domain.yaml` with `configs/non_DA_cross_domain.yaml` in the Configuration section of Step 0 as shown below.\n", + "\n", + "```python\n", + "cfg.merge_from_file(\"configs/non_DA_cross_domain.yaml\")\n", + "```\n", + "\n", + "> Tip: Compare the results with and without domain adaptation to see how it affects model performance."
+ ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "07d966a71f604f63b00707e6d3a0bfe6": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_37f06572693f4972b468c3997fd0687c", + "IPY_MODEL_a0cec4295099427bb8e97229bd49d620", + "IPY_MODEL_931700f44cb0491ca187cbc58dc67476" + ], + "layout": "IPY_MODEL_dd7ad2a05e22470ab00e2118ff994147" + } + }, + "12c8f4410cc74d4f822c779c724bce94": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_79368bfcf5cc4067a59a46c242a77e2d", + "placeholder": "​", + "style": "IPY_MODEL_8632ffdd1b5844f183cc28e824cd117e", + "value": " 2/2 [00:01<00:00,  1.07it/s]" + } + }, + "16cdd1f5e14e405490575177c52f4408": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": 
"1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "1aa1b891fe6b447586d3e87ee62768a2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1ed13da64943461ab42ab18495a6246b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_8f16c1ee593d406c88c9bfcf32fdd3df", + "placeholder": "​", + "style": "IPY_MODEL_16cdd1f5e14e405490575177c52f4408", + "value": " 305/305 [01:07<00:00,  4.52it/s, v_num=2]" + } 
+ }, + "2439bd152f3f40a190e80cb5156a8be7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "2f468dcdec6d4c6db229fcab061fd0d8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_301141b0856941cba6a0a1f496267c4e", + "placeholder": "​", + "style": "IPY_MODEL_b4f64adc9af34262bb35ea42f0f58899", + "value": "Validation DataLoader 0: 100%" + } + }, + "2fc478f35c3e48f2b0c13bd9a73f3dc6": { + "model_module": 
"@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "301141b0856941cba6a0a1f496267c4e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "304a1cadd7c048a2a34568301fb1c4dd": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "32de4bd89e034c35adc198247950d4bf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_941d8002127f499c9868abffea2a2429", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_2fc478f35c3e48f2b0c13bd9a73f3dc6", + "value": 2 + } + }, + "33a0e7cd7f7d49ffa57e962ca2bf0b66": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + 
"_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "37f06572693f4972b468c3997fd0687c": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_1aa1b891fe6b447586d3e87ee62768a2", + "placeholder": "​", + "style": "IPY_MODEL_8618e343c4f1435ebf099bd70605606b", + "value": "Validation DataLoader 0: 100%" + } + }, + "3c34cab162424e51b4dcf0008c39c7a0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": 
null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "4b0bcd88167b469da36e8433a4d47377": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a80aad59381c42edaa2adb89d781ddd6", + "max": 305, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b71c219972474c2c89ed04fa48ec637d", + "value": 305 + } + }, + "530bec1df42241a9be6a34c561be4a36": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": 
"ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3c34cab162424e51b4dcf0008c39c7a0", + "max": 29, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_b2e51980648f49f2963ba66745b82b57", + "value": 29 + } + }, + "5436ad6633f7491898361cb877e5c96a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "635b3ed7aa264fa6907187152359a63f": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + 
"64151477b0f4444b852363f72540296e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "65bb75f11b1f4064b715a166bed1e215": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "6692877437d14f56bdf0bc4e8793bb4b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "700bde0d67f64f9d8680c483d232fa5b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_b9b44b6940884c359ebb60834280b401", + "placeholder": "​", + "style": "IPY_MODEL_807b512710c848eabb6816cf8e9ecc75", + "value": " 29/29 [00:02<00:00, 12.12it/s]" + } + }, + "79368bfcf5cc4067a59a46c242a77e2d": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "7d4fd3d9c5ed4cc6af20be7384a758ac": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + 
"object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "807b512710c848eabb6816cf8e9ecc75": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8618e343c4f1435ebf099bd70605606b": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8632ffdd1b5844f183cc28e824cd117e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "8f16c1ee593d406c88c9bfcf32fdd3df": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": 
"LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "931700f44cb0491ca187cbc58dc67476": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_304a1cadd7c048a2a34568301fb1c4dd", + "placeholder": "​", + "style": "IPY_MODEL_e809f1b951f740ddabc3a3e56d4dd903", + "value": " 29/29 [00:02<00:00, 12.61it/s]" + } + }, + "941d8002127f499c9868abffea2a2429": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": 
null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "98992d70c9d74d3fa794edf0dc9333b1": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_65bb75f11b1f4064b715a166bed1e215", + "placeholder": "​", + "style": "IPY_MODEL_a893676474004fb0a24023403f5acf46", + "value": "Sanity Checking DataLoader 0: 100%" + } + }, + "9a8deecaaef543fbb18d0f50b83b5abc": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_b113d7ba3f50417fb93de1118b4e4dec", + 
"IPY_MODEL_4b0bcd88167b469da36e8433a4d47377", + "IPY_MODEL_1ed13da64943461ab42ab18495a6246b" + ], + "layout": "IPY_MODEL_7d4fd3d9c5ed4cc6af20be7384a758ac" + } + }, + "9d5ebfa060ac49d6bd7a8c4837b4fc29": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "a0cec4295099427bb8e97229bd49d620": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_33a0e7cd7f7d49ffa57e962ca2bf0b66", + "max": 29, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_f223529cf08b4235b8e74ad1049c5a60", + "value": 29 + } + }, + "a3aa520f98b042f29f1fca44f24b61e0": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + 
"grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a80aad59381c42edaa2adb89d781ddd6": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "a893676474004fb0a24023403f5acf46": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": 
"DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "abf0b363155c4b1fb8de818ae41606eb": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "af4b74bd9f13458c8dd9b74ad2004783": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + 
"_view_name": "StyleView", + "description_width": "" + } + }, + "b113d7ba3f50417fb93de1118b4e4dec": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_635b3ed7aa264fa6907187152359a63f", + "placeholder": "​", + "style": "IPY_MODEL_9d5ebfa060ac49d6bd7a8c4837b4fc29", + "value": "Epoch 1: 100%" + } + }, + "b2e51980648f49f2963ba66745b82b57": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b4f64adc9af34262bb35ea42f0f58899": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "b71c219972474c2c89ed04fa48ec637d": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + 
"_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "b9b44b6940884c359ebb60834280b401": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "bc415b5a5635482eb20a65601866febf": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_98992d70c9d74d3fa794edf0dc9333b1", + "IPY_MODEL_32de4bd89e034c35adc198247950d4bf", + 
"IPY_MODEL_12c8f4410cc74d4f822c779c724bce94" + ], + "layout": "IPY_MODEL_fe6bdb21df9c4415b9899aaa96969502" + } + }, + "be7949a38cfe4129abd0583b3b9c080e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_e9d78f7f6ffd4d22b7f85f65da6e54c3", + "placeholder": "​", + "style": "IPY_MODEL_5436ad6633f7491898361cb877e5c96a", + "value": " 29/29 [00:02<00:00, 10.86it/s]" + } + }, + "d10203b9631842348ce5ee323f56d8c3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_a3aa520f98b042f29f1fca44f24b61e0", + "max": 29, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_64151477b0f4444b852363f72540296e", + "value": 29 + } + }, + "d782f86576c9463eb171f866f007b9b3": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + 
"IPY_MODEL_f7b99f048d0b430b913f03382dcf6cae", + "IPY_MODEL_d10203b9631842348ce5ee323f56d8c3", + "IPY_MODEL_be7949a38cfe4129abd0583b3b9c080e" + ], + "layout": "IPY_MODEL_6692877437d14f56bdf0bc4e8793bb4b" + } + }, + "dd7ad2a05e22470ab00e2118ff994147": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + }, + "e809f1b951f740ddabc3a3e56d4dd903": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "e9d78f7f6ffd4d22b7f85f65da6e54c3": { + 
"model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f223529cf08b4235b8e74ad1049c5a60": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f60cfca35cab44e5801bbb2b04f9cc6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + 
"_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_2f468dcdec6d4c6db229fcab061fd0d8", + "IPY_MODEL_530bec1df42241a9be6a34c561be4a36", + "IPY_MODEL_700bde0d67f64f9d8680c483d232fa5b" + ], + "layout": "IPY_MODEL_abf0b363155c4b1fb8de818ae41606eb" + } + }, + "f7b99f048d0b430b913f03382dcf6cae": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_2439bd152f3f40a190e80cb5156a8be7", + "placeholder": "​", + "style": "IPY_MODEL_af4b74bd9f13458c8dd9b74ad2004783", + "value": "Testing DataLoader 0: 100%" + } + }, + "fe6bdb21df9c4415b9899aaa96969502": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + 
"max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": "hidden", + "width": "100%" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 } From cc3ecbf5ffdbd08010ff456721a9e90c79d2ad7a Mon Sep 17 00:00:00 2001 From: "L. M. Riza Rizky" <42672299+zaRizk7@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:13:50 +0100 Subject: [PATCH 2/3] pre-commit --- .../tutorial-drug.ipynb | 3364 ++++------------- 1 file changed, 663 insertions(+), 2701 deletions(-) diff --git a/tutorials/drug-target-interaction/tutorial-drug.ipynb b/tutorials/drug-target-interaction/tutorial-drug.ipynb index 9f22516..bce1839 100644 --- a/tutorials/drug-target-interaction/tutorial-drug.ipynb +++ b/tutorials/drug-target-interaction/tutorial-drug.ipynb @@ -1,34 +1,36 @@ { + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, "cells": [ { - "cell_type": "markdown", - "id": "8c1bf9c7", - "metadata": { - "id": "8c1bf9c7" - }, + "metadata": {}, "source": [ - "# Drug–Target Interaction Prediction\n", + "# Drug\u2013Target Interaction Prediction\n", "\n", "![](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/drug-target-interaction/images/drugban-pyakle-api.png?raw=1)\n", "\n", "\n", "In this tutorial, we will train models to predict the interaction between **two data modalities**: **molecules (drug)** and **proteins (target)** using `PyKale`. Drug-target interaction (DTI) plays a key role in drug discovery and identifying potential therapeutic targets. This example is based on the **DrugBAN** framework by [**Bai et al. 
(_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1).\n", "\n", - "The DTI prediction problem is formulated as a **binary classification task**, where the goal is to predict whether a given **drug–protein pair interacts or not**. The DrugBAN framework tackles this problem using two key ideas:\n", + "The DTI prediction problem is formulated as a **binary classification task**, where the goal is to predict whether a given **drug\u2013protein pair interacts or not**. The DrugBAN framework tackles this problem using two key ideas:\n", "\n", "- **Bilinear Attention Network (BAN)**, which learns detailed feature representations for both drugs and proteins and captures local interaction patterns between them.\n", "\n", - "- **Adversarial Domain Adaptation**, which helps the model generalise to out-of-distribution datasets, i.e., in clustering-based cross-validation instead of random splits, improving its ability to predict interactions on unseen drug–target pairs.\n", + "- **Adversarial Domain Adaptation**, which helps the model generalise to out-of-distribution datasets, i.e., in clustering-based cross-validation instead of random splits, improving its ability to predict interactions on unseen drug\u2013target pairs.\n", "\n", "With `PyKale`, implementing such a multimodal DTI prediction pipeline is straightforward. The library provides ready-to-use modules and configuration support, making it easy to apply advanced techniques with minimal custom coding." - ] + ], + "cell_type": "markdown", + "id": "8c1bf9c7" }, { - "cell_type": "markdown", - "id": "745ccdcf", - "metadata": { - "id": "745ccdcf" - }, + "metadata": {}, "source": [ "## Step 0: Environment Preparation\n", "\n", @@ -38,23 +40,27 @@ "\n", "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. 
The additional helper script is:\n", "- [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file." - ] + ], + "cell_type": "markdown", + "id": "745ccdcf" }, { + "metadata": {}, + "source": [ + "import os\n", + "\n", + "!rm -rf /content/mmai-tutorials\n", + "!git clone --single-branch -b main https://github.com/pykale/mmai-tutorials.git\n", + "%mv /content/mmai-tutorials/tutorials/drug-target-interaction /content/\n", + "%cd /content/drug-target-interaction\n", + "\n", + "print(\"Changed working directory to:\", os.getcwd())" + ], "cell_type": "code", - "execution_count": 1, - "id": "a6028209", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a6028209", - "outputId": "bdba9c0d-e5a6-4dba-981d-b12e21d2c463" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "Cloning into 'mmai-tutorials'...\n", "remote: Enumerating objects: 610, done.\u001b[K\n", @@ -69,23 +75,11 @@ ] } ], - "source": [ - "import os\n", - "\n", - "!rm -rf /content/mmai-tutorials\n", - "!git clone --single-branch -b main https://github.com/pykale/mmai-tutorials.git\n", - "%mv /content/mmai-tutorials/tutorials/drug-target-interaction /content/\n", - "%cd /content/drug-target-interaction\n", - "\n", - "print(\"Changed working directory to:\", os.getcwd())" - ] + "id": "a6028209", + "execution_count": null }, { - "cell_type": "markdown", - "id": "c52c6334", - "metadata": { - "id": "c52c6334" - }, + "metadata": {}, "source": [ "### Package Installation\n", "\n", @@ -96,134 +90,98 @@ "Then, we install `PyG` (PyTorch Geometric) and related packages.\n", "\n", "Please **do not** re-run this session after installation completed. Runing this installation multiple times will trigger issues related to `PyG`. 
If you want to re-run this installation, please click the `Runtime` on the top menu and choose `Disconnect and delete runtime` before installing." - ] + ], + "cell_type": "markdown", + "id": "c52c6334" }, { + "metadata": {}, + "source": [ + "%pip install --quiet \\\n", + " \"pykale[example]@git+https://github.com/pykale/pykale@main\" \\\n", + " gdown==5.2.0 torch-geometric==2.6.0 torch_sparse torch_scatter \\\n", + " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", + " && echo \"pykale, gdown, pyg and yacs installed successfully \u2705\" \\\n", + " || echo \"Failed to install pykale, gdown, pyg and yacs \u274c\"" + ], "cell_type": "code", - "execution_count": 2, - "id": "53e3b14e", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "53e3b14e", - "outputId": "7890c8c4-6c6c-4156-bdf4-55ffbf9ee2b5" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "pykale, gdown, pyg and yacs installed successfully ✅\n" + "pykale, gdown, pyg and yacs installed successfully \u2705\n" ] } ], - "source": [ - "%pip install --quiet \\\n", - " \"pykale[example]@git+https://github.com/pykale/pykale@main\" \\\n", - " gdown==5.2.0 torch-geometric==2.6.0 torch_sparse torch_scatter \\\n", - " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", - " && echo \"pykale, gdown, pyg and yacs installed successfully ✅\" \\\n", - " || echo \"Failed to install pykale, gdown, pyg and yacs ❌\"" - ] + "id": "53e3b14e", + "execution_count": null }, { - "cell_type": "markdown", - "id": "69f50b6a", - "metadata": { - "id": "69f50b6a" - }, + "metadata": {}, "source": [ "We then hide the warnings messages to get a clear output." 
- ] + ], + "cell_type": "markdown", + "id": "69f50b6a" }, { - "cell_type": "code", - "execution_count": 3, - "id": "6e871c63", - "metadata": { - "id": "6e871c63" - }, - "outputs": [], + "metadata": {}, "source": [ "import os\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "6e871c63", + "execution_count": null }, { - "cell_type": "markdown", - "id": "6606e3fb", - "metadata": { - "id": "6606e3fb" - }, + "metadata": {}, "source": [ "Exercise: Check NumPy Version" - ] + ], + "cell_type": "markdown", + "id": "6606e3fb" }, { + "metadata": {}, + "source": [ + "import numpy as np\n", + "\n", + "print(\"NumPy version:\", np.__version__) # numpy should be 2.0.0 or higher" + ], "cell_type": "code", - "execution_count": 4, - "id": "0d384020", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0d384020", - "outputId": "d1ecfffa-1567-4b1c-dc8e-cbbeda3728c0" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "NumPy version: 2.0.2\n" ] } ], - "source": [ - "import numpy as np\n", - "\n", - "print(\"NumPy version:\", np.__version__) # numpy should be 2.0.0 or higher" - ] + "id": "0d384020", + "execution_count": null }, { - "cell_type": "markdown", - "id": "cabd3406", - "metadata": { - "id": "cabd3406" - }, + "metadata": {}, "source": [ "### Configuration\n", "\n", "To minimize the footprint of the notebook when specifying configurations, we provide a [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`configs/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs/DA_cross_domain.yaml) as an example." 
- ] + ], + "cell_type": "markdown", + "id": "cabd3406" }, { - "cell_type": "code", - "execution_count": 5, - "id": "55c13b48", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "55c13b48", - "outputId": "dd3df032-9263-4bcf-d47c-3059d1f66830" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/content/drug-target-interaction\n" - ] - } - ], + "metadata": {}, "source": [ "from configs import get_cfg_defaults\n", "\n", @@ -233,81 +191,78 @@ "cfg.merge_from_file(\n", " \"configs/DA_cross_domain.yaml\"\n", ") # Update (or override) some of those settings using a custom YAML file" - ] + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "/content/drug-target-interaction\n" + ] + } + ], + "id": "55c13b48", + "execution_count": null }, { - "cell_type": "markdown", - "id": "74ffdbc2", - "metadata": { - "id": "74ffdbc2" - }, + "metadata": {}, "source": [ "In this tutorial, we list the hyperparameters we would like users to play with outside the `.yaml` file:\n", "- `cfg.SOLVER.MAX_EPOCH`: Number of epochs in training stage. You can reduce the number of training epochs to shorten runtime.\n", "- `cfg.DATA.DATASET`: The dataset used in the study. This can be `bindingdb` or `biosnap`.\n", "\n", "As a quick exercise, please take a moment to review and understand the parameters in [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)." 
- ] + ], + "cell_type": "markdown", + "id": "74ffdbc2" }, { - "cell_type": "code", - "execution_count": 6, - "id": "424c7286", - "metadata": { - "id": "424c7286" - }, - "outputs": [], + "metadata": {}, "source": [ "cfg.SOLVER.MAX_EPOCH = 2" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "424c7286", + "execution_count": null }, { - "cell_type": "markdown", - "id": "97c088fd", - "metadata": { - "id": "97c088fd" - }, + "metadata": {}, "source": [ "You can also switch to a different dataset." - ] + ], + "cell_type": "markdown", + "id": "97c088fd" }, { - "cell_type": "code", - "execution_count": 7, - "id": "c69376fa", - "metadata": { - "id": "c69376fa" - }, - "outputs": [], + "metadata": {}, "source": [ "cfg.DATA.DATASET = \"biosnap\"" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "c69376fa", + "execution_count": null }, { - "cell_type": "markdown", - "id": "d3d41633", - "metadata": { - "id": "d3d41633" - }, + "metadata": {}, "source": [ "Exercise: Now print the full configuration to check all current hyperparameter and dataset settings." - ] + ], + "cell_type": "markdown", + "id": "d3d41633" }, { + "metadata": {}, + "source": [ + "print(cfg)" + ], "cell_type": "code", - "execution_count": 8, - "id": "45874296", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "45874296", - "outputId": "5c4738ff-85e8-463f-dac7-6e183643a223" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "BCN:\n", " HEADS: 2\n", @@ -359,42 +314,31 @@ ] } ], - "source": [ - "print(cfg)" - ] + "id": "45874296", + "execution_count": null }, { - "cell_type": "markdown", - "id": "17558d0c", - "metadata": { - "id": "17558d0c" - }, + "metadata": {}, "source": [ "## Step 1: Data Loading and Preparation\n", "\n", "In this tutorial, we use the **Biosnap** dataset for the main demonstration and the **BindingDB** dataset for the exercise at the end." 
- ] + ], + "cell_type": "markdown", + "id": "17558d0c" }, { - "cell_type": "markdown", - "id": "6c6071b9", - "metadata": { - "id": "6c6071b9" - }, + "metadata": {}, "source": [ "### Data Downloading\n", "\n", "Please run the following cell to download necessary datasets." - ] + ], + "cell_type": "markdown", + "id": "6c6071b9" }, { - "cell_type": "code", - "execution_count": 9, - "id": "56f9f58e", - "metadata": { - "id": "56f9f58e" - }, - "outputs": [], + "metadata": {}, "source": [ "!rm -rf data\n", "!mkdir data\n", @@ -404,33 +348,35 @@ "!gdown --id 1ogOcxZn-1q418LOT-gQ94aHQV0Y1sOmk --output data/drug-target-interaction.zip\n", "!unzip data/drug-target-interaction.zip -d data/\n", "!mv data/drug-target-interaction/checkpoint ./" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "56f9f58e", + "execution_count": null }, { - "cell_type": "markdown", - "id": "c39b3e39", - "metadata": { - "id": "c39b3e39" - }, + "metadata": {}, "source": [ "Exercise: Check the data is ready" - ] + ], + "cell_type": "markdown", + "id": "c39b3e39" }, { + "metadata": {}, + "source": [ + "import os\n", + "import shutil\n", + "\n", + "print(\"Contents of the data folder:\")\n", + "for item in os.listdir(\"data/drug-target-interaction\"):\n", + " print(item)" + ], "cell_type": "code", - "execution_count": 10, - "id": "a6258d1f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a6258d1f", - "outputId": "2813ad1f-5971-44ca-f9b6-c90a496fe4b6" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "Contents of the data folder:\n", "biosnap\n", @@ -439,55 +385,42 @@ ] } ], - "source": [ - "import os\n", - "import shutil\n", - "\n", - "print(\"Contents of the data folder:\")\n", - "for item in os.listdir(\"data/drug-target-interaction\"):\n", - " print(item)" - ] + "id": "a6258d1f", + "execution_count": null }, { - "cell_type": "markdown", - "id": "9ab0b5f833dc40f8", - "metadata": { - "id": "9ab0b5f833dc40f8" - }, + 
"metadata": {}, "source": [ "The data content is structured as follows:\n", "```sh\n", - " ├───data\n", - " │ ├───checkpoint\n", - " │ ├───bindingdb\n", - " │ ├───biosnap" - ] + " \u251c\u2500\u2500\u2500data\n", + " \u2502 \u251c\u2500\u2500\u2500checkpoint\n", + " \u2502 \u251c\u2500\u2500\u2500bindingdb\n", + " \u2502 \u251c\u2500\u2500\u2500biosnap" + ], + "cell_type": "markdown", + "id": "9ab0b5f833dc40f8" }, { - "cell_type": "markdown", - "id": "5be1dcc62b7d5649", - "metadata": { - "id": "5be1dcc62b7d5649" - }, + "metadata": {}, "source": [ "The `data` folder contains two datasets: `bindingdb` and `biosnap`. Each dataset folder contains the following files. The `checkpoint` folder contains the saved model checkpoint, which are used later in the interpretation section." - ] + ], + "cell_type": "markdown", + "id": "5be1dcc62b7d5649" }, { + "metadata": {}, + "source": [ + "print(\"Contents of bindingdb folder:\")\n", + "for item in os.listdir(\"data/drug-target-interaction/bindingdb\"):\n", + " print(item)" + ], "cell_type": "code", - "execution_count": 11, - "id": "a93303c51c8b974e", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a93303c51c8b974e", - "outputId": "93e0c029-b489-4f4d-8889-8b99f868039c" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "Contents of bindingdb folder:\n", "random\n", @@ -497,77 +430,64 @@ ] } ], - "source": [ - "print(\"Contents of bindingdb folder:\")\n", - "for item in os.listdir(\"data/drug-target-interaction/bindingdb\"):\n", - " print(item)" - ] + "id": "a93303c51c8b974e", + "execution_count": null }, { - "cell_type": "markdown", - "id": "79cbc1c1", - "metadata": { - "id": "79cbc1c1" - }, + "metadata": {}, "source": [ "Each dataset folder follows the structure:\n", "\n", "```sh\n", - " ├───dataset_name\n", - " │ ├───cluster\n", - " │ │ ├───source_train.csv\n", - " │ │ ├───target_train.csv\n", - " │ │ ├───target_test.csv\n", - " │ ├───random\n", - " 
│ │ ├───test.csv\n", - " │ │ ├───train.csv\n", - " │ │ ├───val.csv\n", - " │ ├───full.csv\n", + " \u251c\u2500\u2500\u2500dataset_name\n", + " \u2502 \u251c\u2500\u2500\u2500cluster\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500source_train.csv\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500target_train.csv\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500target_test.csv\n", + " \u2502 \u251c\u2500\u2500\u2500random\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500test.csv\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500train.csv\n", + " \u2502 \u2502 \u251c\u2500\u2500\u2500val.csv\n", + " \u2502 \u251c\u2500\u2500\u2500full.csv\n", "```" - ] + ], + "cell_type": "markdown", + "id": "79cbc1c1" }, { - "cell_type": "markdown", - "id": "d35e04f9", - "metadata": { - "id": "d35e04f9" - }, + "metadata": {}, "source": [ "We use the cluster dataset folder for cross-domain prediction, containing three parts:\n", "\n", - "- Train samples from the source domain: Drug–protein pairs the model learns from.\n", + "- Train samples from the source domain: Drug\u2013protein pairs the model learns from.\n", "\n", "- Train samples from the target domain: Additional training data from a different distribution to improve generalisation.\n", "\n", - "- Test samples from the target domain: Unseen drug–protein pairs used to evaluate model performance on new data.\n", + "- Test samples from the target domain: Unseen drug\u2013protein pairs used to evaluate model performance on new data.\n", "\n", "The source and target sets are defined based on the clustering results." 
- ] + ], + "cell_type": "markdown", + "id": "d35e04f9" }, { - "cell_type": "markdown", - "id": "98acf744", - "metadata": { - "id": "98acf744" - }, + "metadata": {}, "source": [ "### Data Loading" - ] + ], + "cell_type": "markdown", + "id": "98acf744" }, { - "cell_type": "markdown", - "id": "1e5f4f44", - "metadata": { - "id": "1e5f4f44" - }, + "metadata": {}, "source": [ - "Here’s what each csv file looks like in a table format:\n", + "Here\u2019s what each csv file looks like in a table format:\n", "\n", "| SMILES | Protein Sequence | Y |\n", "|--------------------|--------------------------|---|\n", - "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n", - "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n", - "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n", + "| Fc1ccc(C2(COC\u2026) | MDNVLPVDSDLS\u2026 | 1 |\n", + "| O=c1oc2c(O)c(\u2026) | MMYSKLLTLTTL\u2026 | 0 |\n", + "| CC(C)Oc1cc(N\u2026) | MGMACLTMTEME\u2026 | 1 |\n", "\n", "Each row of the dataset contains three key pieces of information:\n", "\n", @@ -580,39 +500,44 @@ "\n", "\n", "**Y (Labels)**: \n", - "Each drug–protein pair is given a label:\n", + "Each drug\u2013protein pair is given a label:\n", "- `1` if they interact\n", "- `0` if they do not\n", "\n", "\n", - "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**) — whether or not the drug and protein interact." - ] + "Each row shows one drug\u2013protein pair. The goal of our machine learning model is to predict the last column (**Y**) \u2014 whether or not the drug and protein interact." + ], + "cell_type": "markdown", + "id": "1e5f4f44" }, { - "cell_type": "markdown", - "id": "b7590daf", - "metadata": { - "id": "b7590daf" - }, + "metadata": {}, "source": [ "You can load CSV files into Python using tools like `pandas`. The output shows a sample of the data, including the SMILES string for the drug, the protein sequence, the interaction label (Y) and the cluster ID." 
- ] + ], + "cell_type": "markdown", + "id": "b7590daf" }, { + "metadata": {}, + "source": [ + "import pandas as pd\n", + "\n", + "dataFolder = os.path.join(\n", + " f\"data/drug-target-interaction/{cfg.DATA.DATASET}\", str(cfg.DATA.SPLIT)\n", + ")\n", + "\n", + "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", + "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", + "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", + "\n", + "print(\"Sample example:\", df_train_source.iloc[0])" + ], "cell_type": "code", - "execution_count": 12, - "id": "0c709e31", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0c709e31", - "outputId": "5e663c15-f59f-4a0f-acc3-b545b6392a1e" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "Sample example: SMILES CC1=CN=C2N1C=CN=C2NCC1=CC=NC=C1\n", "Protein MARSLLLPLQILLLSLALETAGEEAQGDKIIDGAPCARGSHPWQVA...\n", @@ -623,26 +548,11 @@ ] } ], - "source": [ - "import pandas as pd\n", - "\n", - "dataFolder = os.path.join(\n", - " f\"data/drug-target-interaction/{cfg.DATA.DATASET}\", str(cfg.DATA.SPLIT)\n", - ")\n", - "\n", - "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", - "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", - "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", - "\n", - "print(\"Sample example:\", df_train_source.iloc[0])" - ] + "id": "0c709e31", + "execution_count": null }, { - "cell_type": "markdown", - "id": "542d4e69", - "metadata": { - "id": "542d4e69" - }, + "metadata": {}, "source": [ "### Data Preprocessing\n", "\n", @@ -652,26 +562,20 @@ "Protein sequences are transformed into fixed-length integer arrays using `kale.prepdata.chem_transform.integer_label_protein`, with each amino acid mapped to an integer and sequences padded or truncated to a uniform length.\n", 
"\n", "Finally, the `kale.loaddata.molecular_datasets.DTIDataset` class packages drugs, proteins, and labels into a PyTorch-ready dataset." - ] + ], + "cell_type": "markdown", + "id": "542d4e69" }, { - "cell_type": "markdown", - "id": "981d5520", - "metadata": { - "id": "981d5520" - }, + "metadata": {}, "source": [ "**Note:** If you encounter an error related to requiring numpy `<2.0`, simply ignore it and re-run this block until it completes successfully." - ] + ], + "cell_type": "markdown", + "id": "981d5520" }, { - "cell_type": "code", - "execution_count": 13, - "id": "ae5af8eb", - "metadata": { - "id": "ae5af8eb" - }, - "outputs": [], + "metadata": {}, "source": [ "from kale.loaddata.molecular_datasets import DTIDataset\n", "\n", @@ -679,24 +583,22 @@ "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "ae5af8eb", + "execution_count": null }, { - "cell_type": "markdown", - "id": "a0a510ce", - "metadata": { - "id": "a0a510ce" - }, + "metadata": {}, "source": [ "We load data in small, manageable pieces called batches to save memory and speed up training. We use `kale.loaddata.sampler.MultiDataLoader` from PyKale to load one batch from the source domain and one from the target domain at each training step." 
- ] + ], + "cell_type": "markdown", + "id": "a0a510ce" }, { - "cell_type": "markdown", - "id": "c09084c0", - "metadata": { - "id": "c09084c0" - }, + "metadata": {}, "source": [ "First, we specify a few DataLoader parameters:\n", "- Batch size: Number of samples per batch\n", @@ -704,21 +606,31 @@ "- Number of workers: Parallel data loading\n", "- Drop last: Discard the last incomplete batch for consistent batch sizes\n", "- Collate function: Use graph_collate_func to batch variable-sized molecular graphs" - ] + ], + "cell_type": "markdown", + "id": "c09084c0" }, { + "metadata": {}, + "source": [ + "from torch.utils.data import DataLoader\n", + "from kale.loaddata.molecular_datasets import graph_collate_func\n", + "from kale.loaddata.sampler import MultiDataLoader\n", + "\n", + "params = {\n", + " \"batch_size\": cfg.SOLVER.BATCH_SIZE,\n", + " \"shuffle\": True,\n", + " \"num_workers\": cfg.SOLVER.NUM_WORKERS,\n", + " \"drop_last\": True,\n", + " \"collate_fn\": graph_collate_func,\n", + "}\n", + "\n", + "params" + ], "cell_type": "code", - "execution_count": 14, - "id": "94a15868", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "94a15868", - "outputId": "a4c14890-db12-45b7-bcbe-1a94b4f846b2" - }, "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "{'batch_size': 32,\n", @@ -728,57 +640,23 @@ " 'collate_fn': }" ] }, - "execution_count": 14, "metadata": {}, - "output_type": "execute_result" + "execution_count": 14 } ], - "source": [ - "from torch.utils.data import DataLoader\n", - "from kale.loaddata.molecular_datasets import graph_collate_func\n", - "from kale.loaddata.sampler import MultiDataLoader\n", - "\n", - "params = {\n", - " \"batch_size\": cfg.SOLVER.BATCH_SIZE,\n", - " \"shuffle\": True,\n", - " \"num_workers\": cfg.SOLVER.NUM_WORKERS,\n", - " \"drop_last\": True,\n", - " \"collate_fn\": graph_collate_func,\n", - "}\n", - "\n", - "params" - ] + "id": "94a15868", + "execution_count": null }, { - 
"cell_type": "markdown", - "id": "e884ed07", - "metadata": { - "id": "e884ed07" - }, + "metadata": {}, "source": [ "Then, we create a DataLoader from both the source and target datasets for training." - ] + ], + "cell_type": "markdown", + "id": "e884ed07" }, { - "cell_type": "code", - "execution_count": 15, - "id": "24ba12b5", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "24ba12b5", - "outputId": "47a5c085-d2d0-48f8-cad2-4483e3fe2efa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using domain adaptation: True\n" - ] - } - ], + "metadata": {}, "source": [ "print(\"Using domain adaptation:\", cfg.DA.USE)\n", "\n", @@ -795,26 +673,30 @@ " training_generator = MultiDataLoader(\n", " dataloaders=[source_generator, target_generator], n_batches=n_batches\n", " )" - ] + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Using domain adaptation: True\n" + ] + } + ], + "id": "24ba12b5", + "execution_count": null }, { - "cell_type": "markdown", - "id": "649301de", - "metadata": { - "id": "649301de" - }, + "metadata": {}, "source": [ - "Lastly, we set up DataLoaders for validation and testing. Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly." - ] + "Lastly, we set up DataLoaders for validation and testing. Since we don\u2019t want to shuffle or drop any samples, we adjust the parameters accordingly." 
+ ], + "cell_type": "markdown", + "id": "649301de" }, { - "cell_type": "code", - "execution_count": 16, - "id": "b4cf543a", - "metadata": { - "id": "b4cf543a" - }, - "outputs": [], + "metadata": {}, "source": [ "# Update parameters for validation/testing (no shuffling, keep all data)\n", "params.update({\"shuffle\": False, \"drop_last\": False})\n", @@ -822,43 +704,24 @@ "# Create validation and test data loaders\n", "valid_generator = DataLoader(test_target_dataset, **params)\n", "test_generator = DataLoader(test_target_dataset, **params)" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "b4cf543a", + "execution_count": null }, { - "cell_type": "markdown", - "id": "e474eea2", - "metadata": { - "id": "e474eea2" - }, + "metadata": {}, "source": [ "### Exercise: Dataset Inspection\n", "\n", - "Once the dataset is ready, let’s inspect one sample from the training data to check the input graph, protein sequence, and label format." - ] + "Once the dataset is ready, let\u2019s inspect one sample from the training data to check the input graph, protein sequence, and label format." 
+ ], + "cell_type": "markdown", + "id": "e474eea2" }, { - "cell_type": "code", - "execution_count": 17, - "id": "31b8a93f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "31b8a93f", - "outputId": "74ae5660-3d74-4c5e-f5ed-065c1f099516" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "First sample from source batch:\n", - "Drug graph: Data(x=[290, 7], edge_index=[2, 106], edge_attr=[106, 1], num_nodes=290)\n", - "Protein sequence: tensor([11., 7., 18., ..., 0., 0., 0.], dtype=torch.float64)\n", - "Label: tensor(0., dtype=torch.float64)\n" - ] - } - ], + "metadata": {}, "source": [ "# Get the first batch (contains one batch from source and one from target)\n", "first_batch = next(iter(training_generator))\n", @@ -871,14 +734,25 @@ "print(\"Drug graph:\", source_batch[0][0])\n", "print(\"Protein sequence:\", source_batch[1][0])\n", "print(\"Label:\", source_batch[2][0])" - ] + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "First sample from source batch:\n", + "Drug graph: Data(x=[290, 7], edge_index=[2, 106], edge_attr=[106, 1], num_nodes=290)\n", + "Protein sequence: tensor([11., 7., 18., ..., 0., 0., 0.], dtype=torch.float64)\n", + "Label: tensor(0., dtype=torch.float64)\n" + ] + } + ], + "id": "31b8a93f", + "execution_count": null }, { - "cell_type": "markdown", - "id": "cb0b269b", - "metadata": { - "id": "cb0b269b" - }, + "metadata": {}, "source": [ "This sample is a tuple with three parts:\n", "\n", @@ -893,47 +767,43 @@ "\n", "3. **Label (float)**\n", "- `0.0`; The ground-truth interaction label indicating no interaction." 
- ] + ], + "cell_type": "markdown", + "id": "cb0b269b" }, { - "cell_type": "markdown", - "id": "8eaf5c8f", - "metadata": { - "id": "8eaf5c8f" - }, + "metadata": {}, "source": [ "## Step 2: Model Definition" - ] + ], + "cell_type": "markdown", + "id": "8eaf5c8f" }, { - "cell_type": "markdown", - "id": "b2819549", - "metadata": { - "id": "b2819549" - }, + "metadata": {}, "source": [ "### Embed\n", "\n", "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n", "\n", "We define the DrugBAN class in `kale.embed.ban`." - ] + ], + "cell_type": "markdown", + "id": "b2819549" }, { + "metadata": {}, + "source": [ + "from kale.embed.ban import DrugBAN\n", + "\n", + "model = DrugBAN(**cfg)\n", + "print(model)" + ], "cell_type": "code", - "execution_count": 18, - "id": "1c8f3acc", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1c8f3acc", - "outputId": "65f82225-7219-4391-abf2-2194a00c3af0" - }, "outputs": [ { - "name": "stdout", "output_type": "stream", + "name": "stdout", "text": [ "DrugBAN(\n", " (drug_extractor): MolecularGCN(\n", @@ -982,32 +852,20 @@ ] } ], - "source": [ - "from kale.embed.ban import DrugBAN\n", - "\n", - "model = DrugBAN(**cfg)\n", - "print(model)" - ] + "id": "1c8f3acc", + "execution_count": null }, { - "cell_type": "markdown", - "id": "32084f24", - "metadata": { - "id": "32084f24" - }, + "metadata": {}, "source": [ "### Predict\n", "We use the PyKale pipeline API `kale.pipeline.drugban_trainer` to connect dataloaders, encoders and outcoders for model training and evaluation." 
- ] + ], + "cell_type": "markdown", + "id": "32084f24" }, { - "cell_type": "code", - "execution_count": 19, - "id": "46e2b9b4", - "metadata": { - "id": "46e2b9b4" - }, - "outputs": [], + "metadata": {}, "source": [ "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", "\n", @@ -1026,26 +884,22 @@ " da_random_dim=cfg.DA.RANDOM_DIM,\n", " decoder_in_dim=cfg.DECODER.IN_DIM,\n", ")" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "46e2b9b4", + "execution_count": null }, { - "cell_type": "markdown", - "id": "a48c86b9", - "metadata": { - "id": "a48c86b9" - }, + "metadata": {}, "source": [ - "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score." - ] + "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning\u2019s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score." + ], + "cell_type": "markdown", + "id": "a48c86b9" }, { - "cell_type": "code", - "execution_count": 20, - "id": "7754bd38", - "metadata": { - "id": "7754bd38" - }, - "outputs": [], + "metadata": {}, "source": [ "import pytorch_lightning as pl\n", "from pytorch_lightning.callbacks import ModelCheckpoint\n", @@ -1055,33 +909,38 @@ " monitor=\"val_BinaryAUROC\",\n", " mode=\"max\",\n", ")" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "7754bd38", + "execution_count": null }, { - "cell_type": "markdown", - "id": "969beac0", - "metadata": { - "id": "969beac0" - }, + "metadata": {}, "source": [ "We now create the `Trainer`." 
- ] + ], + "cell_type": "markdown", + "id": "969beac0" }, { + "metadata": {}, + "source": [ + "import torch\n", + "\n", + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint_callback],\n", + " devices=\"auto\",\n", + " accelerator=\"auto\",\n", + " max_epochs=cfg.SOLVER.MAX_EPOCH,\n", + " deterministic=True,\n", + ")" + ], "cell_type": "code", - "execution_count": 21, - "id": "e68e07bc", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "e68e07bc", - "outputId": "08a9b744-3fb5-48c7-863f-dd64afc4dd80" - }, "outputs": [ { - "name": "stderr", "output_type": "stream", + "name": "stderr", "text": [ "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", @@ -1089,34 +948,19 @@ ] } ], - "source": [ - "import torch\n", - "\n", - "trainer = pl.Trainer(\n", - " callbacks=[checkpoint_callback],\n", - " devices=\"auto\",\n", - " accelerator=\"auto\",\n", - " max_epochs=cfg.SOLVER.MAX_EPOCH,\n", - " deterministic=True,\n", - ")" - ] + "id": "e68e07bc", + "execution_count": null }, { - "cell_type": "markdown", - "id": "1f9a4714", - "metadata": { - "id": "1f9a4714" - }, + "metadata": {}, "source": [ "## Step 3: Model Training" - ] + ], + "cell_type": "markdown", + "id": "1f9a4714" }, { - "cell_type": "markdown", - "id": "b72634ee", - "metadata": { - "id": "b72634ee" - }, + "metadata": {}, "source": [ "### Train\n", "\n", @@ -1131,70 +975,24 @@ "\n", "\n", "This code block takes approximately 5 minutes to complete." 
- ] + ], + "cell_type": "markdown", + "id": "b72634ee" }, { + "metadata": {}, + "source": [ + "trainer.fit(\n", + " drugban_trainer,\n", + " train_dataloaders=training_generator,\n", + " val_dataloaders=valid_generator,\n", + ")" + ], "cell_type": "code", - "execution_count": 22, - "id": "0624b0c6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 424, - "referenced_widgets": [ - "bc415b5a5635482eb20a65601866febf", - "98992d70c9d74d3fa794edf0dc9333b1", - "32de4bd89e034c35adc198247950d4bf", - "12c8f4410cc74d4f822c779c724bce94", - "fe6bdb21df9c4415b9899aaa96969502", - "65bb75f11b1f4064b715a166bed1e215", - "a893676474004fb0a24023403f5acf46", - "941d8002127f499c9868abffea2a2429", - "2fc478f35c3e48f2b0c13bd9a73f3dc6", - "79368bfcf5cc4067a59a46c242a77e2d", - "8632ffdd1b5844f183cc28e824cd117e", - "9a8deecaaef543fbb18d0f50b83b5abc", - "b113d7ba3f50417fb93de1118b4e4dec", - "4b0bcd88167b469da36e8433a4d47377", - "1ed13da64943461ab42ab18495a6246b", - "7d4fd3d9c5ed4cc6af20be7384a758ac", - "635b3ed7aa264fa6907187152359a63f", - "9d5ebfa060ac49d6bd7a8c4837b4fc29", - "a80aad59381c42edaa2adb89d781ddd6", - "b71c219972474c2c89ed04fa48ec637d", - "8f16c1ee593d406c88c9bfcf32fdd3df", - "16cdd1f5e14e405490575177c52f4408", - "07d966a71f604f63b00707e6d3a0bfe6", - "37f06572693f4972b468c3997fd0687c", - "a0cec4295099427bb8e97229bd49d620", - "931700f44cb0491ca187cbc58dc67476", - "dd7ad2a05e22470ab00e2118ff994147", - "1aa1b891fe6b447586d3e87ee62768a2", - "8618e343c4f1435ebf099bd70605606b", - "33a0e7cd7f7d49ffa57e962ca2bf0b66", - "f223529cf08b4235b8e74ad1049c5a60", - "304a1cadd7c048a2a34568301fb1c4dd", - "e809f1b951f740ddabc3a3e56d4dd903", - "f60cfca35cab44e5801bbb2b04f9cc6f", - "2f468dcdec6d4c6db229fcab061fd0d8", - "530bec1df42241a9be6a34c561be4a36", - "700bde0d67f64f9d8680c483d232fa5b", - "abf0b363155c4b1fb8de818ae41606eb", - "301141b0856941cba6a0a1f496267c4e", - "b4f64adc9af34262bb35ea42f0f58899", - "3c34cab162424e51b4dcf0008c39c7a0", - 
"b2e51980648f49f2963ba66745b82b57", - "b9b44b6940884c359ebb60834280b401", - "807b512710c848eabb6816cf8e9ecc75" - ] - }, - "id": "0624b0c6", - "outputId": "8a7fb16d-707f-4a4e-edd8-dec451f5fa5f" - }, "outputs": [ { - "name": "stderr", "output_type": "stream", + "name": "stderr", "text": [ "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", "INFO:pytorch_lightning.callbacks.model_summary:\n", @@ -1215,6 +1013,7 @@ ] }, { + "output_type": "display_data", "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "bc415b5a5635482eb20a65601866febf", @@ -1225,10 +1024,10 @@ "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_BinaryAUROC 0.48449382185935974 │\n", - "│ test_BinaryAccuracy 0.5071665048599243 │\n", - "│ test_BinaryF1Score 0.0933062881231308 │\n", - "│ test_BinaryRecall 0.050549451261758804 │\n", - "│ test_BinarySpecificity 0.9668141603469849 │\n", - "│ test_accuracy_sklearn 0.5038588643074036 │\n", - "│ test_auroc_sklearn 0.48449382185935974 │\n", - "│ test_f1_sklearn 0.6671618223190308 │\n", - "│ test_loss 0.8901852369308472 │\n", - "│ test_optim_threshold 0.07649494707584381 │\n", - "│ test_sensitivity 0.006637168116867542 │\n", - "│ test_specificity 0.997802197933197 │\n", - "└───────────────────────────┴───────────────────────────┘\n", + "
\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n",
+              "\u2503        Test metric        \u2503       DataLoader 0        \u2503\n",
+              "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n",
+              "\u2502     test_BinaryAUROC      \u2502    0.48449382185935974    \u2502\n",
+              "\u2502    test_BinaryAccuracy    \u2502    0.5071665048599243     \u2502\n",
+              "\u2502    test_BinaryF1Score     \u2502    0.0933062881231308     \u2502\n",
+              "\u2502     test_BinaryRecall     \u2502   0.050549451261758804    \u2502\n",
+              "\u2502  test_BinarySpecificity   \u2502    0.9668141603469849     \u2502\n",
+              "\u2502   test_accuracy_sklearn   \u2502    0.5038588643074036     \u2502\n",
+              "\u2502    test_auroc_sklearn     \u2502    0.48449382185935974    \u2502\n",
+              "\u2502      test_f1_sklearn      \u2502    0.6671618223190308     \u2502\n",
+              "\u2502         test_loss         \u2502    0.8901852369308472     \u2502\n",
+              "\u2502   test_optim_threshold    \u2502    0.07649494707584381    \u2502\n",
+              "\u2502     test_sensitivity      \u2502   0.006637168116867542    \u2502\n",
+              "\u2502     test_specificity      \u2502     0.997802197933197     \u2502\n",
+              "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n",
               "
\n" ], "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5071665048599243 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0933062881231308 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.050549451261758804 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9668141603469849 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5038588643074036 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6671618223190308 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8901852369308472 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.07649494707584381 \u001b[0m\u001b[35m \u001b[0m│\n", - 
"│\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.006637168116867542 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.997802197933197 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" + "\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", + "\u2503\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m\u2503\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m\u2503\n", + "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.5071665048599243 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.0933062881231308 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.050549451261758804 \u001b[0m\u001b[35m \u001b[0m\u2502\n", 
+ "\u2502\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.9668141603469849 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.5038588643074036 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.48449382185935974 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.6671618223190308 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.8901852369308472 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.07649494707584381 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.006637168116867542 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2502\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m\u2502\u001b[35m \u001b[0m\u001b[35m 0.997802197933197 \u001b[0m\u001b[35m \u001b[0m\u2502\n", + "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n" ] }, - "metadata": {}, - "output_type": "display_data" + "metadata": {} }, { + "output_type": "execute_result", "data": { "text/plain": [ "[{'test_loss': 0.8901852369308472,\n", @@ 
-1430,51 +1203,41 @@ " 'test_BinaryAccuracy': 0.5071665048599243}]" ] }, - "execution_count": 23, "metadata": {}, - "output_type": "execute_result" + "execution_count": 23 } ], - "source": [ - "trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path=\"best\")" - ] + "id": "c1415c02", + "execution_count": null }, { - "cell_type": "markdown", - "id": "bb0a08bec91d2bd9", - "metadata": { - "id": "bb0a08bec91d2bd9" - }, + "metadata": {}, "source": [ "### Performance Comparison\n", "\n", "The earlier example was a simple demonstration. To properly evaluate DrugBAN against baseline models, we train it for 100 epochs across multiple random seeds.\n", "\n", "We provide a checkpoint trained for 100 epochs in the `checkpoint` for your test after the tutorial. We will also use the provided checkpoint for the interpretation section for a better visualization.\n" - ] + ], + "cell_type": "markdown", + "id": "bb0a08bec91d2bd9" }, { - "cell_type": "markdown", - "id": "37dbe9f3", - "metadata": { - "id": "37dbe9f3" - }, + "metadata": {}, "source": [ "The figure below shows the performance of different models on the BioSNAP and BindingDB datasets:\n", "- Left plot: AUROC (Area Under the ROC Curve)\n", - "- Right plot: AUPRC (Area Under the Precision–Recall Curve)\n", + "- Right plot: AUPRC (Area Under the Precision\u2013Recall Curve)\n", "\n", "![](https://media.springernature.com/full/springer-static/image/art%3A10.1038%2Fs42256-022-00605-1/MediaObjects/42256_2022_605_Fig3_HTML.png?as=webp)\n", "\n", "The box plots show the median as the centre lines and the mean as green triangles. The minima and lower percentile represent the worst and second-worst scores. The maxima and upper percentile indicate the best and second-best scores. Supplementary Table 2 provides the data statistics of the BindingDB and BioSNAP datasets." 
- ] + ], + "cell_type": "markdown", + "id": "37dbe9f3" }, { - "cell_type": "markdown", - "id": "02e3c73e", - "metadata": { - "id": "02e3c73e" - }, + "metadata": {}, "source": [ "## Step 5: Interpretation\n", "\n", @@ -1483,27 +1246,21 @@ "2) generate molecule images with attention highlights.\n", "\n", "This helps us understand which parts of the drug contribute to the interaction with the target protein." - ] + ], + "cell_type": "markdown", + "id": "02e3c73e" }, { - "cell_type": "markdown", - "id": "4a56f260141b7368", - "metadata": { - "id": "4a56f260141b7368" - }, + "metadata": {}, "source": [ "### Extracting Attention Weights\n", "First, we need to load the test dataset and create a DataLoader for it. This will allow us to process the test samples in batches. We define functions to create the test dataset and DataLoader." - ] + ], + "cell_type": "markdown", + "id": "4a56f260141b7368" }, { - "cell_type": "code", - "execution_count": 24, - "id": "2c67553408592b2", - "metadata": { - "id": "2c67553408592b2" - }, - "outputs": [], + "metadata": {}, "source": [ "def get_test_dataset(dataFolder):\n", " df_test_target = pd.read_csv(dataFolder)\n", @@ -1521,48 +1278,40 @@ " drop_last=True,\n", " )\n", " return test_dataloader" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "2c67553408592b2", + "execution_count": null }, { - "cell_type": "markdown", - "id": "ecdab66ee05da10c", - "metadata": { - "id": "ecdab66ee05da10c" - }, + "metadata": {}, "source": [ - "We load a small subset of samples for testing from the provided `.csv` file. You can create your own `.csv` file with the same format to test your drug–protein pairs." - ] + "We load a small subset of samples for testing from the provided `.csv` file. You can create your own `.csv` file with the same format to test your drug\u2013protein pairs." 
+ ], + "cell_type": "markdown", + "id": "ecdab66ee05da10c" }, { - "cell_type": "code", - "execution_count": 25, - "id": "7ef1867541d2577a", - "metadata": { - "id": "7ef1867541d2577a" - }, - "outputs": [], + "metadata": {}, "source": [ "test_dataFolder = \"/content/drug-target-interaction/data/drug-target-interaction/bindingdb/interpretation_samples.csv\"" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "7ef1867541d2577a", + "execution_count": null }, { - "cell_type": "markdown", - "id": "7fec5dc00a7b4aa4", - "metadata": { - "id": "7fec5dc00a7b4aa4" - }, + "metadata": {}, "source": [ "We then build the test dataset and DataLoader using the functions defined above. The `batchsize` is set to 1 to ensure we process one sample at a time for attention visualization later." - ] + ], + "cell_type": "markdown", + "id": "7fec5dc00a7b4aa4" }, { - "cell_type": "code", - "execution_count": 26, - "id": "c99a558c96a1ffd", - "metadata": { - "id": "c99a558c96a1ffd" - }, - "outputs": [], + "metadata": {}, "source": [ "test_dataset = get_test_dataset(test_dataFolder)\n", "test_dataloader = get_test_dataloader(\n", @@ -1571,26 +1320,22 @@ " num_workers=cfg.SOLVER.NUM_WORKERS,\n", " collate_fn=graph_collate_func,\n", ")" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "c99a558c96a1ffd", + "execution_count": null }, { - "cell_type": "markdown", - "id": "e1ff543d132abc42", - "metadata": { - "id": "e1ff543d132abc42" - }, + "metadata": {}, "source": [ "Then, we use the following function to load the trained model with the PyKale API." 
- ] + ], + "cell_type": "markdown", + "id": "e1ff543d132abc42" }, { - "cell_type": "code", - "execution_count": 27, - "id": "3b7f12b12b139799", - "metadata": { - "id": "3b7f12b12b139799" - }, - "outputs": [], + "metadata": {}, "source": [ "def get_model_from_ckpt(ckpt_path, config):\n", " return DrugbanTrainer.load_from_checkpoint(\n", @@ -1611,31 +1356,31 @@ " da_random_dim=config.DA.RANDOM_DIM,\n", " decoder_in_dim=config.DECODER.IN_DIM,\n", " )" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "3b7f12b12b139799", + "execution_count": null }, { - "cell_type": "markdown", - "id": "c0678dddcdf076fc", - "metadata": { - "id": "c0678dddcdf076fc" - }, + "metadata": {}, "source": [ "Once the model and test data are prepared, we extract attention maps from the trained model. We set the directory to the provided checkpoint file, load the trained model, and set it to evaluation mode." - ] + ], + "cell_type": "markdown", + "id": "c0678dddcdf076fc" }, { + "metadata": {}, + "source": [ + "checkpoint_path = \"/content/drug-target-interaction/checkpoint/best.ckpt\"\n", + "model = get_model_from_ckpt(checkpoint_path, cfg)\n", + "model.model.eval()" + ], "cell_type": "code", - "execution_count": 28, - "id": "d2a8931099b73c01", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "d2a8931099b73c01", - "outputId": "1dbad556-912d-44f3-88ad-059cf7876a36" - }, "outputs": [ { + "output_type": "execute_result", "data": { "text/plain": [ "DrugBAN(\n", @@ -1684,57 +1429,23 @@ ")" ] }, - "execution_count": 28, "metadata": {}, - "output_type": "execute_result" + "execution_count": 28 } ], - "source": [ - "checkpoint_path = \"/content/drug-target-interaction/checkpoint/best.ckpt\"\n", - "model = get_model_from_ckpt(checkpoint_path, cfg)\n", - "model.model.eval()" - ] + "id": "d2a8931099b73c01", + "execution_count": null }, { - "cell_type": "markdown", - "id": "159d3fa67b29c9e9", - "metadata": { - "id": "159d3fa67b29c9e9" - }, + "metadata": {}, "source": [ 
"We then iterate through the test DataLoader, passing each batch of drug and protein pairs to the model. The model's forward method returns the attention weights. After processing all batches, we concatenate the attention tensors into a single tensor." - ] + ], + "cell_type": "markdown", + "id": "159d3fa67b29c9e9" }, { - "cell_type": "code", - "execution_count": 29, - "id": "781a7762c36c72be", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "781a7762c36c72be", - "outputId": "fade7fb6-bea8-438a-9688-8cffd8dc9c47" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 6/6 [00:00<00:00, 65.52it/s]\n" - ] - }, - { - "data": { - "text/plain": [ - "torch.Size([6, 2, 290, 1185])" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], + "metadata": {}, "source": [ "from tqdm import tqdm\n", "\n", @@ -1755,44 +1466,56 @@ "torch.save(all_attentions, \"attention_maps.pt\")\n", "\n", "all_attentions.shape" - ] + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 6/6 [00:00<00:00, 65.52it/s]\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "torch.Size([6, 2, 290, 1185])" + ] + }, + "metadata": {}, + "execution_count": 29 + } + ], + "id": "781a7762c36c72be", + "execution_count": null }, { - "cell_type": "markdown", - "id": "78dc763b6c0eef0", - "metadata": { - "id": "78dc763b6c0eef0" - }, + "metadata": {}, "source": [ "The attention has shape [B, H, V, Q] (Number of drug-target pairs, Heads of attentions, Drug tokens, Protein tokens)." 
- ] + ], + "cell_type": "markdown", + "id": "78dc763b6c0eef0" }, { - "cell_type": "markdown", - "id": "8f72ea4d93f640cb", - "metadata": { - "id": "8f72ea4d93f640cb" - }, + "metadata": {}, "source": [ "### Visualize Attention Maps and Molecule Images" - ] + ], + "cell_type": "markdown", + "id": "8f72ea4d93f640cb" }, { - "cell_type": "markdown", - "id": "383c342a7c31d7ae", - "metadata": { - "id": "383c342a7c31d7ae" - }, + "metadata": {}, "source": [ "Once attention maps are saved, run the visualization script:" - ] + ], + "cell_type": "markdown", + "id": "383c342a7c31d7ae" }, { - "cell_type": "markdown", - "id": "d8a746169def8da5", - "metadata": { - "id": "d8a746169def8da5" - }, + "metadata": {}, "source": [ "This script will:\n", "\n", @@ -1800,107 +1523,89 @@ "\n", "2) Plot:\n", "\n", - " a) A heatmap of attention over drug–protein tokens.\n", + " a) A heatmap of attention over drug\u2013protein tokens.\n", "\n", " b) Molecular structures with atoms highlighted by attention values.\n", "\n", "The output images are saved in the `visualization` directory. You can also modify the `data_file` to use your own input in the same format as `target_test.csv`.\n", "\n" - ] + ], + "cell_type": "markdown", + "id": "d8a746169def8da5" }, { - "cell_type": "markdown", - "id": "aac54bfc67ce32eb", - "metadata": { - "id": "aac54bfc67ce32eb" - }, + "metadata": {}, "source": [ "We first import the necessary PyKale APIs and set the output directory." 
- ] + ], + "cell_type": "markdown", + "id": "aac54bfc67ce32eb" }, { - "cell_type": "code", - "execution_count": 30, - "id": "d3c1d2e4cab69107", - "metadata": { - "id": "d3c1d2e4cab69107" - }, - "outputs": [], + "metadata": {}, "source": [ "from kale.interpret.visualize import draw_attention_map, draw_mol_with_attention\n", "from kale.prepdata.tensor_reshape import normalize_tensor\n", "\n", "out_dir = \"./visualization\"\n", "os.makedirs(out_dir, exist_ok=True)" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "d3c1d2e4cab69107", + "execution_count": null }, { - "cell_type": "markdown", - "id": "126b62034111d92a", - "metadata": { - "id": "126b62034111d92a" - }, + "metadata": {}, "source": [ "We then load the attention maps, data, and SMILES strings from the test dataset." - ] + ], + "cell_type": "markdown", + "id": "126b62034111d92a" }, { - "cell_type": "code", - "execution_count": 31, - "id": "7f70a6810c1c5e60", - "metadata": { - "id": "7f70a6810c1c5e60" - }, - "outputs": [], + "metadata": {}, "source": [ "attention = torch.load(\"attention_maps.pt\", map_location=\"cpu\")\n", "data_df = pd.read_csv(test_dataFolder)\n", "smiles = data_df[\"SMILES\"]\n", "proteins = data_df[\"Protein\"]" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "7f70a6810c1c5e60", + "execution_count": null }, { - "cell_type": "markdown", - "id": "d1a009bbb9f4a0f9", - "metadata": { - "id": "d1a009bbb9f4a0f9" - }, + "metadata": {}, "source": [ "We select the first sample from the attention maps and corresponding SMILES and protein sequence for visualization." 
- ] + ], + "cell_type": "markdown", + "id": "d1a009bbb9f4a0f9" }, { - "cell_type": "code", - "execution_count": 32, - "id": "e808c255fe862925", - "metadata": { - "id": "e808c255fe862925" - }, - "outputs": [], + "metadata": {}, "source": [ "index = 0\n", "att_path = os.path.join(out_dir, f\"att_map_{index}.png\")\n", "mol_path = os.path.join(out_dir, f\"mol_{index}.svg\")" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "e808c255fe862925", + "execution_count": null }, { - "cell_type": "markdown", - "id": "438e6aa218e6b51d", - "metadata": { - "id": "438e6aa218e6b51d" - }, + "metadata": {}, "source": [ "We crop the attention map to the actual lengths of the drug and protein sequences. This is important because the attention map may include padding tokens." - ] + ], + "cell_type": "markdown", + "id": "438e6aa218e6b51d" }, { - "cell_type": "code", - "execution_count": 33, - "id": "af15baa1c8caabc0", - "metadata": { - "id": "af15baa1c8caabc0" - }, - "outputs": [], + "metadata": {}, "source": [ "from rdkit import Chem\n", "\n", @@ -1919,26 +1624,22 @@ "\n", "# Normalize\n", "att = normalize_tensor(att)" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "af15baa1c8caabc0", + "execution_count": null }, { - "cell_type": "markdown", - "id": "60a4ce71146a721e", - "metadata": { - "id": "60a4ce71146a721e" - }, + "metadata": {}, "source": [ "Finally, we save the attention map and the molecule image with attention highlights." 
- ] + ], + "cell_type": "markdown", + "id": "60a4ce71146a721e" }, { - "cell_type": "code", - "execution_count": 34, - "id": "403f77ada0ecc446", - "metadata": { - "id": "403f77ada0ecc446" - }, - "outputs": [], + "metadata": {}, "source": [ "draw_attention_map(\n", " att,\n", @@ -1947,47 +1648,50 @@ " xlabel=\"Drug Tokens\",\n", " ylabel=\"Protein Tokens\",\n", ")" - ] - }, - { + ], "cell_type": "code", - "execution_count": 35, - "id": "b1003372361a66d6", - "metadata": { - "id": "b1003372361a66d6" - }, "outputs": [], + "id": "403f77ada0ecc446", + "execution_count": null + }, + { + "metadata": {}, "source": [ "draw_mol_with_attention(att.mean(dim=1), smile, mol_path)" - ] + ], + "cell_type": "code", + "outputs": [], + "id": "b1003372361a66d6", + "execution_count": null }, { - "cell_type": "code", - "execution_count": 36, - "id": "4mHWCbJmGMgG", "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 921 - }, - "id": "4mHWCbJmGMgG", - "outputId": "5b3f1960-85a5-4ca3-bbbb-f48f6bc67401", "tags": [ "hide-input" ] }, + "source": [ + "from IPython.display import Image, SVG\n", + "\n", + "attention_plot = Image(att_path)\n", + "molecular_img = SVG(mol_path)\n", + "display(attention_plot)\n", + "display(molecular_img)" + ], + "cell_type": "code", "outputs": [ { + "output_type": "display_data", "data": { "image/png": "", "text/plain": [ "" ] }, - "metadata": {}, - "output_type": "display_data" + "metadata": {} }, { + "output_type": "display_data", "data": { "image/svg+xml": [ "\n", @@ -2133,45 +1837,30 @@ "" ] }, - "metadata": {}, - "output_type": "display_data" + "metadata": {} } ], - "source": [ - "from IPython.display import Image, SVG\n", - "\n", - "attention_plot = Image(att_path)\n", - "molecular_img = SVG(mol_path)\n", - "display(attention_plot)\n", - "display(molecular_img)" - ] + "id": "4mHWCbJmGMgG", + "execution_count": null }, { - "cell_type": "markdown", - "id": "1999d67d6d14b263", - "metadata": { - "id": "1999d67d6d14b263" - }, + 
"metadata": {}, "source": [ "The output images are saved in the `visualization` directory. The attention map shows how much each drug token attends to each protein token, while the molecule image highlights the atoms based on their attention values." - ] + ], + "cell_type": "markdown", + "id": "1999d67d6d14b263" }, { - "cell_type": "markdown", - "id": "eeb308c3", - "metadata": { - "id": "eeb308c3" - }, + "metadata": {}, "source": [ "## Extension Tasks" - ] + ], + "cell_type": "markdown", + "id": "eeb308c3" }, { - "cell_type": "markdown", - "id": "aa2a83d8", - "metadata": { - "id": "aa2a83d8" - }, + "metadata": {}, "source": [ "### Task 1\n", "\n", @@ -2184,14 +1873,12 @@ "Reload the dataset and re-run training and testing.\n", "\n", "> Tip: See if the model struggles more or less with the new dataset. It can reveal how generalisable DrugBAN is.\n" - ] + ], + "cell_type": "markdown", + "id": "aa2a83d8" }, { - "cell_type": "markdown", - "id": "c94f174c", - "metadata": { - "id": "c94f174c" - }, + "metadata": {}, "source": [ "### Task 2\n", "\n", @@ -2203,1734 +1890,9 @@ "cfg.merge_from_file(\"configs/non_DA_cross_domain.yaml\")\n", "```\n", ">Tip: Compare the results with and without domain adaptation to see how it affects model performance." 
- ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "07d966a71f604f63b00707e6d3a0bfe6": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_37f06572693f4972b468c3997fd0687c", - "IPY_MODEL_a0cec4295099427bb8e97229bd49d620", - "IPY_MODEL_931700f44cb0491ca187cbc58dc67476" - ], - "layout": "IPY_MODEL_dd7ad2a05e22470ab00e2118ff994147" - } - }, - "12c8f4410cc74d4f822c779c724bce94": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_79368bfcf5cc4067a59a46c242a77e2d", - "placeholder": "​", - "style": "IPY_MODEL_8632ffdd1b5844f183cc28e824cd117e", - "value": " 2/2 [00:01<00:00,  1.07it/s]" - } - }, - "16cdd1f5e14e405490575177c52f4408": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": 
"1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "1aa1b891fe6b447586d3e87ee62768a2": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "1ed13da64943461ab42ab18495a6246b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_8f16c1ee593d406c88c9bfcf32fdd3df", - "placeholder": "​", - "style": "IPY_MODEL_16cdd1f5e14e405490575177c52f4408", - "value": " 305/305 [01:07<00:00,  4.52it/s, v_num=2]" - } 
- }, - "2439bd152f3f40a190e80cb5156a8be7": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "2f468dcdec6d4c6db229fcab061fd0d8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_301141b0856941cba6a0a1f496267c4e", - "placeholder": "​", - "style": "IPY_MODEL_b4f64adc9af34262bb35ea42f0f58899", - "value": "Validation DataLoader 0: 100%" - } - }, - "2fc478f35c3e48f2b0c13bd9a73f3dc6": { - "model_module": 
"@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "301141b0856941cba6a0a1f496267c4e": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "304a1cadd7c048a2a34568301fb1c4dd": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "32de4bd89e034c35adc198247950d4bf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_941d8002127f499c9868abffea2a2429", - "max": 2, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_2fc478f35c3e48f2b0c13bd9a73f3dc6", - "value": 2 - } - }, - "33a0e7cd7f7d49ffa57e962ca2bf0b66": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - 
"_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "37f06572693f4972b468c3997fd0687c": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_1aa1b891fe6b447586d3e87ee62768a2", - "placeholder": "​", - "style": "IPY_MODEL_8618e343c4f1435ebf099bd70605606b", - "value": "Validation DataLoader 0: 100%" - } - }, - "3c34cab162424e51b4dcf0008c39c7a0": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": 
null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "4b0bcd88167b469da36e8433a4d47377": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a80aad59381c42edaa2adb89d781ddd6", - "max": 305, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_b71c219972474c2c89ed04fa48ec637d", - "value": 305 - } - }, - "530bec1df42241a9be6a34c561be4a36": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": 
"ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_3c34cab162424e51b4dcf0008c39c7a0", - "max": 29, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_b2e51980648f49f2963ba66745b82b57", - "value": 29 - } - }, - "5436ad6633f7491898361cb877e5c96a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "635b3ed7aa264fa6907187152359a63f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - 
"64151477b0f4444b852363f72540296e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "65bb75f11b1f4064b715a166bed1e215": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "6692877437d14f56bdf0bc4e8793bb4b": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - 
"_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "700bde0d67f64f9d8680c483d232fa5b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_b9b44b6940884c359ebb60834280b401", - "placeholder": "​", - "style": "IPY_MODEL_807b512710c848eabb6816cf8e9ecc75", - "value": " 29/29 [00:02<00:00, 12.12it/s]" - } - }, - "79368bfcf5cc4067a59a46c242a77e2d": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": 
"@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "7d4fd3d9c5ed4cc6af20be7384a758ac": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - 
"object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": "100%" - } - }, - "807b512710c848eabb6816cf8e9ecc75": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8618e343c4f1435ebf099bd70605606b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8632ffdd1b5844f183cc28e824cd117e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "8f16c1ee593d406c88c9bfcf32fdd3df": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": 
"LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "931700f44cb0491ca187cbc58dc67476": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_304a1cadd7c048a2a34568301fb1c4dd", - "placeholder": "​", - "style": "IPY_MODEL_e809f1b951f740ddabc3a3e56d4dd903", - "value": " 29/29 [00:02<00:00, 12.61it/s]" - } - }, - "941d8002127f499c9868abffea2a2429": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": 
null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "98992d70c9d74d3fa794edf0dc9333b1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_65bb75f11b1f4064b715a166bed1e215", - "placeholder": "​", - "style": "IPY_MODEL_a893676474004fb0a24023403f5acf46", - "value": "Sanity Checking DataLoader 0: 100%" - } - }, - "9a8deecaaef543fbb18d0f50b83b5abc": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_b113d7ba3f50417fb93de1118b4e4dec", - 
"IPY_MODEL_4b0bcd88167b469da36e8433a4d47377", - "IPY_MODEL_1ed13da64943461ab42ab18495a6246b" - ], - "layout": "IPY_MODEL_7d4fd3d9c5ed4cc6af20be7384a758ac" - } - }, - "9d5ebfa060ac49d6bd7a8c4837b4fc29": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "a0cec4295099427bb8e97229bd49d620": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_33a0e7cd7f7d49ffa57e962ca2bf0b66", - "max": 29, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_f223529cf08b4235b8e74ad1049c5a60", - "value": 29 - } - }, - "a3aa520f98b042f29f1fca44f24b61e0": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - 
"grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a80aad59381c42edaa2adb89d781ddd6": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": "2", - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "a893676474004fb0a24023403f5acf46": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": 
"DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "abf0b363155c4b1fb8de818ae41606eb": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "af4b74bd9f13458c8dd9b74ad2004783": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - 
"_view_name": "StyleView", - "description_width": "" - } - }, - "b113d7ba3f50417fb93de1118b4e4dec": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_635b3ed7aa264fa6907187152359a63f", - "placeholder": "​", - "style": "IPY_MODEL_9d5ebfa060ac49d6bd7a8c4837b4fc29", - "value": "Epoch 1: 100%" - } - }, - "b2e51980648f49f2963ba66745b82b57": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "b4f64adc9af34262bb35ea42f0f58899": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "b71c219972474c2c89ed04fa48ec637d": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - 
"_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "b9b44b6940884c359ebb60834280b401": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "bc415b5a5635482eb20a65601866febf": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_98992d70c9d74d3fa794edf0dc9333b1", - "IPY_MODEL_32de4bd89e034c35adc198247950d4bf", - 
"IPY_MODEL_12c8f4410cc74d4f822c779c724bce94" - ], - "layout": "IPY_MODEL_fe6bdb21df9c4415b9899aaa96969502" - } - }, - "be7949a38cfe4129abd0583b3b9c080e": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_e9d78f7f6ffd4d22b7f85f65da6e54c3", - "placeholder": "​", - "style": "IPY_MODEL_5436ad6633f7491898361cb877e5c96a", - "value": " 29/29 [00:02<00:00, 10.86it/s]" - } - }, - "d10203b9631842348ce5ee323f56d8c3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_a3aa520f98b042f29f1fca44f24b61e0", - "max": 29, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_64151477b0f4444b852363f72540296e", - "value": 29 - } - }, - "d782f86576c9463eb171f866f007b9b3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - 
"IPY_MODEL_f7b99f048d0b430b913f03382dcf6cae", - "IPY_MODEL_d10203b9631842348ce5ee323f56d8c3", - "IPY_MODEL_be7949a38cfe4129abd0583b3b9c080e" - ], - "layout": "IPY_MODEL_6692877437d14f56bdf0bc4e8793bb4b" - } - }, - "dd7ad2a05e22470ab00e2118ff994147": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - }, - "e809f1b951f740ddabc3a3e56d4dd903": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "DescriptionStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "DescriptionStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "description_width": "" - } - }, - "e9d78f7f6ffd4d22b7f85f65da6e54c3": { - 
"model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "f223529cf08b4235b8e74ad1049c5a60": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "f60cfca35cab44e5801bbb2b04f9cc6f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HBoxModel", - 
"_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_2f468dcdec6d4c6db229fcab061fd0d8", - "IPY_MODEL_530bec1df42241a9be6a34c561be4a36", - "IPY_MODEL_700bde0d67f64f9d8680c483d232fa5b" - ], - "layout": "IPY_MODEL_abf0b363155c4b1fb8de818ae41606eb" - } - }, - "f7b99f048d0b430b913f03382dcf6cae": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "1.5.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "1.5.0", - "_view_name": "HTMLView", - "description": "", - "description_tooltip": null, - "layout": "IPY_MODEL_2439bd152f3f40a190e80cb5156a8be7", - "placeholder": "​", - "style": "IPY_MODEL_af4b74bd9f13458c8dd9b74ad2004783", - "value": "Testing DataLoader 0: 100%" - } - }, - "fe6bdb21df9c4415b9899aaa96969502": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "1.2.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "1.2.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "1.2.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border": null, - "bottom": null, - "display": "inline-flex", - "flex": null, - "flex_flow": "row wrap", - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - 
"max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "overflow_x": null, - "overflow_y": null, - "padding": null, - "right": null, - "top": null, - "visibility": "hidden", - "width": "100%" - } - } - } + ], + "cell_type": "markdown", + "id": "c94f174c" } - }, - "nbformat": 4, - "nbformat_minor": 5 + ] } From f00e029115910c33b5a4f74d8e7a8c0034945b3f Mon Sep 17 00:00:00 2001 From: "L. M. Riza Rizky" <42672299+zaRizk7@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:36:33 +0100 Subject: [PATCH 3/3] add outputs for cancer --- .../tutorial-cancer.ipynb | 1524 ++++++++++------- 1 file changed, 924 insertions(+), 600 deletions(-) diff --git a/tutorials/multiomics-cancer-classification/tutorial-cancer.ipynb b/tutorials/multiomics-cancer-classification/tutorial-cancer.ipynb index 39bf267..33613ce 100644 --- a/tutorials/multiomics-cancer-classification/tutorial-cancer.ipynb +++ b/tutorials/multiomics-cancer-classification/tutorial-cancer.ipynb @@ -1,602 +1,926 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "dbd2571f", - "metadata": {}, - "source": [ - "# Multiomics Cancer Classification\n", - "\n", - "![](images/mogonet-pykale-api.png)" - ] - }, - { - "cell_type": "markdown", - "id": "e0825580", - "metadata": {}, - "source": [ - "In this tutorial, we demonstrate how to use the standard pipeline in `PyKale` to integrate **patient multiomics data** in **cancer classification**.\n", - "We use **M**ulti-**O**mics **G**raph c**O**nvolutional **NET**works (MOGONET) by **Huang et al. (Nature Communication, 2021)** as an example." - ] - }, - { - "cell_type": "markdown", - "id": "15a944d7", - "metadata": {}, - "source": [ - "This tutorial is about cancer subtypes classification problem, which is a multi-class classification problem. 
The input is the multiomics data from patient, including mRNA expression data, DNA methylation data, and miRNA expression data. The output will be the subtype of cancers. We have two datasets to work with, **BRCA** and **ROSMAP**. BRCA has five subtypes and ROSMAP has only two." - ] - }, - { - "cell_type": "markdown", - "id": "b419011e", - "metadata": {}, - "source": [ - "## Step 0: Environment Preparation\n", - "\n", - "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.\n", - "\n", - "To prepare the helper functions and necessary materials, we download them from the GitHub repository." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "551867b5", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "\n", - "!rm -rf /content/mmai-tutorials\n", - "!git clone https://github.com/pykale/mmai-tutorials.git\n", - "\n", - "%cd /content/mmai-tutorials/tutorials/multiomics-cancer-classification\n", - "\n", - "print(\"Changed working directory to:\", os.getcwd())" - ] - }, - { - "cell_type": "markdown", - "id": "e014d91d", - "metadata": {}, - "source": [ - "### Package Installation" - ] - }, - { - "cell_type": "markdown", - "id": "41ce5cef", - "metadata": {}, - "source": [ - "The main package required for this tutorial is `PyKale`.\n", - "\n", - "`PyKale` is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.\n", - "\n", - "Then, we install `PyG` (PyTorch Geometric) and related packages.\n", - "\n", - "[**WARNING**] Please **do not** re-run this session after installation completed. Runing this installation multiple times will trigger issues related to `PyG`. 
If you want to re-run this installation, please click the `Runtime` on the top menu and choose `Disconnect and delete runtime` before installing.\n", - "\n", - "[Estimated running time] 3 mins" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6050d5b4", - "metadata": {}, - "outputs": [], - "source": [ - "!pip install --quiet \\\n", - " git+https://github.com/pykale/pykale@main \\\n", - " yacs==0.1.8 \\\n", - " torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \\\n", - " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", - " && echo \"pykale,yacs and wfdb installed successfully ✅\" \\\n", - " || echo \"Failed to install pykale,yacs ❌\"" - ] - }, - { - "cell_type": "markdown", - "id": "2027e726", - "metadata": {}, - "source": [ - "We then hide the warnings messages to get a clear output." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1c9c4856", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import warnings\n", - "\n", - "warnings.filterwarnings(\"ignore\")\n", - "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" - ] - }, - { - "cell_type": "markdown", - "id": "6b32af98", - "metadata": {}, - "source": [ - "### Configuration\n", - "\n", - "To minimize the footprint of the notebook when specifying configurations, we provide a [`config.py`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/config.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`configs/BRCA.yaml`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/configs/BRCA.yaml) as an example.\n", - "\n", - "First, we load the configuration from [`configs/BRCA.yaml`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/configs/BRCA.yaml)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20700eaf", - "metadata": {}, - "outputs": [], - "source": [ - "from config import get_cfg_defaults\n", - "\n", - "cfg = get_cfg_defaults()\n", - "cfg.merge_from_file(\"configs/BRCA.yaml\")" - ] - }, - { - "cell_type": "markdown", - "id": "71add965", - "metadata": {}, - "source": [ - "Besides, we also provide a configuration file for another dataset **ROSMAP**, named [`configs/ROSMAP.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/multiomics-cancer-classification/configs/ROSMAP.yaml). Users can try with this dataset later." - ] - }, - { - "cell_type": "markdown", - "id": "66a1eb4b", - "metadata": {}, - "source": [ - "In this tutorial, we list the hyperparameters we would like users to play with outside the `.yaml` file:\n", - "- `cfg.SOLVER.MAX_EPOCHS_PRETRAIN`: Number of epochs in pre-training stage.\n", - "- `cfg.SOLVER.MAX_EPOCHS`: Number of epochs in training stage.\n", - "- `cfg.DATASET.NUM_MODALITIES`: Number of modalities in the pipeline.\n", - " - `1`: mRNA expression.\n", - " - `2`: mRNA expression + DNA methylation.\n", - " - `3`: mRNA expression + DNA methylation + miRNA expression.\n", - "\n", - "[**NOTE**] Because this tutorial aims to demonmstrate `PyKale` pipeline, we only set `cfg.SOLVER.MAX_EPOCHS_PRETRAIN=100` and `cfg.SOLVER.MAX_EPOCHS=500` to reduce the training time.\n", - "If users are interested, please increase them to get more accurate predictions." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f1f8bb7c", - "metadata": {}, - "outputs": [], - "source": [ - "cfg.SOLVER.MAX_EPOCHS_PRETRAIN = 100\n", - "cfg.SOLVER.MAX_EPOCHS = 500\n", - "cfg.DATASET.NUM_MODALITIES = 3" - ] - }, - { - "cell_type": "markdown", - "id": "3bdf97a1", - "metadata": {}, - "source": [ - "Print hyperparameters:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f85914b1", - "metadata": {}, - "outputs": [], - "source": [ - "print(cfg)" - ] - }, - { - "cell_type": "markdown", - "id": "317fcb9b", - "metadata": {}, - "source": [ - "## Step 1: Data Loading and Preparation\n", - "\n", - "We use two multiomics benchmarks in this tutorial, BRCA and ROSMAP, which have been provided by the authors of MOGONET paper in [their repository](https://github.com/txWang/MOGONET).\n", - "\n", - "If users are interested in more details regarding **data organization, downloading, loading, and pre-processing**, please refer to the [Data page](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/data.html) of the tutorial." 
- ] - }, - { - "cell_type": "markdown", - "id": "8bf5c0c0", - "metadata": {}, - "source": [ - "Delete the potential existing data and download new version:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2ecd6082", - "metadata": {}, - "outputs": [], - "source": [ - "!rm -rf dataset/" - ] - }, - { - "cell_type": "markdown", - "id": "868bcf23", - "metadata": {}, - "source": [ - "To load data, we first define a list the names of data files:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1352ea41", - "metadata": {}, - "outputs": [], - "source": [ - "file_names = []\n", - "for modality in range(1, cfg.DATASET.NUM_MODALITIES + 1):\n", - " file_names.append(f\"{modality}_tr.csv\")\n", - " file_names.append(f\"{modality}_lbl_tr.csv\")\n", - " file_names.append(f\"{modality}_te.csv\")\n", - " file_names.append(f\"{modality}_lbl_te.csv\")\n", - " file_names.append(f\"{modality}_feat_name.csv\")" - ] - }, - { - "cell_type": "markdown", - "id": "ef417d0c", - "metadata": {}, - "source": [ - "Then, we download, load, and pre-process the data by `PyKale`.\n", - "\n", - "[Estimated running time] 20s" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9041fabd", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset\n", - "from kale.prepdata.tabular_transform import ToOneHotEncoding, ToTensor\n", - "\n", - "multiomics_data = SparseMultiomicsDataset(\n", - " root=cfg.DATASET.ROOT,\n", - " raw_file_names=file_names,\n", - " num_modalities=cfg.DATASET.NUM_MODALITIES,\n", - " num_classes=cfg.DATASET.NUM_CLASSES,\n", - " edge_per_node=cfg.MODEL.EDGE_PER_NODE,\n", - " url=cfg.DATASET.URL,\n", - " random_split=cfg.DATASET.RANDOM_SPLIT,\n", - " equal_weight=cfg.MODEL.EQUAL_WEIGHT,\n", - " pre_transform=ToTensor(dtype=torch.float),\n", - " target_pre_transform=ToOneHotEncoding(dtype=torch.float),\n", - ")" - ] - }, - { - "cell_type": 
"markdown", - "id": "c8819b69", - "metadata": {}, - "source": [ - "Inspect the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "676ebd93", - "metadata": {}, - "outputs": [], - "source": [ - "print(multiomics_data)" - ] - }, - { - "cell_type": "markdown", - "id": "910ca35a", - "metadata": {}, - "source": [ - "## Step 2: Model Definition" - ] - }, - { - "cell_type": "markdown", - "id": "007e4533", - "metadata": {}, - "source": [ - "If users are interested in more details regarding the model, please refer to the [Helper Function & Model Definition](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/helper-functions.html) of the tutorial.\n", - "\n", - "To initialize the model, we firstly call `MogonetModel` from [`model.py`](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/multiomics-cancer-classification/model.py)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1537ce26", - "metadata": {}, - "outputs": [], - "source": [ - "from model import MogonetModel\n", - "\n", - "mogonet_model = MogonetModel(cfg, dataset=multiomics_data)" - ] - }, - { - "cell_type": "markdown", - "id": "3bcb4126", - "metadata": {}, - "source": [ - "Visualize the model architecture:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da221bd6", - "metadata": {}, - "outputs": [], - "source": [ - "print(mogonet_model)" - ] - }, - { - "cell_type": "markdown", - "id": "38d9195c", - "metadata": {}, - "source": [ - "## Step 3: Model Training" - ] - }, - { - "cell_type": "markdown", - "id": "a7f6ad5c", - "metadata": {}, - "source": [ - "### Pretrain Unimodal Encoders\n", - "\n", - "Before training the multiomics model, we first pretrain encoders for each modality independently. 
This step helps each GCN encoder learn a good representation of its respective modality before integration.\n", - "\n", - "We can define the trainer of pretraining stage by:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7383c5c1", - "metadata": {}, - "outputs": [], - "source": [ - "import pytorch_lightning as pl\n", - "\n", - "network = mogonet_model.get_model(pretrain=True)\n", - "trainer_pretrain = pl.Trainer(\n", - " max_epochs=cfg.SOLVER.MAX_EPOCHS_PRETRAIN,\n", - " default_root_dir=cfg.OUTPUT.OUT_DIR,\n", - " accelerator=\"auto\",\n", - " devices=\"auto\",\n", - " enable_model_summary=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b0c71889", - "metadata": {}, - "source": [ - "We pretrain the model by:\n", - "\n", - "\n", - "[Estimated running time] 15s for 100 epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2b42b719", - "metadata": {}, - "outputs": [], - "source": [ - "trainer_pretrain.fit(network)" - ] - }, - { - "cell_type": "markdown", - "id": "0b03d93e", - "metadata": {}, - "source": [ - "### Train the Multimodal Model\n", - "After pretraining the unimodal pathways, we now train the full MOGONET model by enabling the VCDN. 
In this stage, all modality-specific encoders and VCDN are trained.\n", - "\n", - "We define the trainer of multimodal training by:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e94b710d", - "metadata": {}, - "outputs": [], - "source": [ - "network = mogonet_model.get_model(pretrain=False)\n", - "trainer = pl.Trainer(\n", - " max_epochs=cfg.SOLVER.MAX_EPOCHS,\n", - " default_root_dir=cfg.OUTPUT.OUT_DIR,\n", - " accelerator=\"auto\",\n", - " devices=\"auto\",\n", - " enable_model_summary=False,\n", - " log_every_n_steps=1,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "31b76385", - "metadata": {}, - "source": [ - "We start the multimodal training by:\n", - "\n", - "\n", - "[Estimated running time] 1 min for 500 epochs" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b3e66c8f", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.fit(network)" - ] - }, - { - "cell_type": "markdown", - "id": "d41dc02a", - "metadata": {}, - "source": [ - "## Step 4: Evaluation\n", - "Once training is complete, we evaluate the model on the test set using `trainer.test()`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "019e2e7b", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.test(network)" - ] - }, - { - "cell_type": "markdown", - "id": "719c655c", - "metadata": {}, - "source": [ - "## Step 5: Interpretation Study\n", - "We use `kale.interpret` to perform interpretation, where a function that systematically masks input features and observes the effect on performance—highlighting which features are most important for classification is provided. 
Please refer to [Interpret Study page](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/interpretation-study.html) for more details.\n", - "\n", - "Because the interpretation study needs us to mask one feature and observe the performance drop, we firstly define the trainer for the interpretation experiments.\n", - "\n", - "[**NOTE**] The final results may be different from what they should be because we only train the model for a few epochs to reduce waiting time in this tutorial." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f061dd93", - "metadata": {}, - "outputs": [], - "source": [ - "from kale.interpret.model_weights import select_top_features_by_masking\n", - "import pytorch_lightning as pl\n", - "\n", - "trainer_biomarker = pl.Trainer(\n", - " max_epochs=cfg.SOLVER.MAX_EPOCHS,\n", - " accelerator=\"auto\",\n", - " devices=\"auto\",\n", - " enable_progress_bar=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "38a31ccf", - "metadata": {}, - "source": [ - "Then, we start the experiment." - ] - }, - { - "cell_type": "markdown", - "id": "4a754a08", - "metadata": {}, - "source": [ - "To supress the verbose messages in the following experiments:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e428229c", - "metadata": {}, - "outputs": [], - "source": [ - "import logging\n", - "\n", - "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)" - ] - }, - { - "cell_type": "markdown", - "id": "8565e576", - "metadata": {}, - "source": [ - "Run the interpretation experiments:\n", - "\n", - "[Estimated running time] Because the following block will train the model for 2,503 times for BRCA dataset, the following block may take about 6 minutes." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2dd9e5e3", - "metadata": {}, - "outputs": [], - "source": [ - "f1_key = \"F1\" if multiomics_data.num_classes == 2 else \"F1 macro\"\n", - "df_featimp_top = select_top_features_by_masking(\n", - " trainer=trainer_biomarker,\n", - " model=network,\n", - " dataset=multiomics_data,\n", - " metric=f1_key,\n", - " num_top_feats=30,\n", - " verbose=False,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "964300ae", - "metadata": {}, - "source": [ - "Print the most important features:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c984bdb1", - "metadata": {}, - "outputs": [], - "source": [ - "print(\"{:>4}\\t{:<20}\\t{:>5}\\t{}\".format(\"Rank\", \"Feature name\", \"Omics\", \"Importance\"))\n", - "for rank, row in enumerate(df_featimp_top.itertuples(index=False), 1):\n", - " print(f\"{rank:>4}\\t{row.feat_name:<20}\\t{row.omics:>5}\\t{row.imp:.4f}\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "mmai-cancer-tutorial", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "metadata": {}, + "source": [ + "# Multiomics Cancer Classification\n", + "\n", + "![](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/multiomics-cancer-classification/images/mogonet-pykale-api.png?raw=1)" + ], + "cell_type": "markdown", + "id": "dbd2571f" + }, + { + "metadata": {}, + "source": [ + "In this tutorial, we will use a [**M**ulti-**O**mics **G**raph c**O**nvolutional **NET**works (MOGONET) by **Wang et al. 
(Nature Communications, 2021)**](https://www.nature.com/articles/s41467-021-23774-w) [1] pipeline implemented in `PyKale` [2] to integrate **patient multiomics data** for **cancer subtype classification**.\n", + "\n", + "We will work with multiomics data from two datasets: [**BRCA** of TCGA](https://www.cancerimagingarchive.net/collection/tcga-brca/) [3] and [**ROSMAP**](https://www.synapse.org/Synapse:syn3219045) [4,5]. The BRCA dataset has five subtypes, while the ROSMAP dataset has only two. Three omics modalities will be used: mRNA expression, DNA methylation, and miRNA expression.\n", + "\n", + "The multimodal approach used in this tutorial involves **interaction**, where a cross-omics tensor is constructed to capture interactions among the class probabilities predicted from the individual omics modalities.\n", + "\n", + "The main tasks of this tutorial are:\n", + "\n", + "- Load the BRCA or ROSMAP dataset.\n", + "- Define a MOGONET model.\n", + "- Train and evaluate the MOGONET model on the multiomics data.\n", + "- Obtain the feature importance and visualize the interpretation of the model." + ], + "cell_type": "markdown", + "id": "e0825580" + }, + { + "metadata": {}, + "source": [ + "## Step 0: Environment Preparation\n", + "\n", + "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.\n", + "\n", + "To prepare the helper functions and necessary materials, we download them from the GitHub repository." 
+ ], + "cell_type": "markdown", + "id": "b419011e" + }, + { + "metadata": {}, + "source": [ + "import os\n", + "\n", + "!rm -rf /content/mmai-tutorials\n", + "!git clone https://github.com/pykale/mmai-tutorials.git\n", + "\n", + "%cd /content/mmai-tutorials/tutorials/multiomics-cancer-classification\n", + "\n", + "print(\"Changed working directory to:\", os.getcwd())" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "fatal: destination path 'mmai-tutorials' already exists and is not an empty directory.\n", + "/content/mmai-tutorials/tutorials/multiomics-cancer-classification\n", + "Changed working directory to: /content/mmai-tutorials/tutorials/multiomics-cancer-classification\n" + ] + } + ], + "id": "551867b5", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### Package Installation" + ], + "cell_type": "markdown", + "id": "e014d91d" + }, + { + "metadata": {}, + "source": [ + "The main package required for this tutorial is `PyKale`.\n", + "\n", + "`PyKale` is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains.\n", + "\n", + "Then, we install `PyG` (PyTorch Geometric) and related packages.\n", + "\n", + "[**WARNING**] Please **do not** re-run this session after the installation has completed. Running this installation multiple times will trigger issues related to `PyG`. 
If you want to re-run this installation, please click the `Runtime` on the top menu and choose `Disconnect and delete runtime` before installing.\n", + "\n", + "[Estimated running time] 3 mins" + ], + "cell_type": "markdown", + "id": "41ce5cef" + }, + { + "metadata": {}, + "source": [ + "%pip install --quiet \\\n", + " \"pykale[example]@git+https://github.com/pykale/pykale@main\" \\\n", + " gdown==5.2.0 torch-geometric==2.6.0 torch_sparse torch_scatter \\\n", + " -f https://data.pyg.org/whl/torch-2.6.0+cu124.html \\\n", + " && echo \"pykale, gdown, nilearn, and yacs installed successfully \u2705\" \\\n", + " || echo \"Failed to install pykale, gdown, nilearn, and yacs \u274c\"" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "pykale, gdown, nilearn, and yacs installed successfully \u2705\n" + ] + } + ], + "id": "6050d5b4", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "We then hide the warnings messages to get a clear output." + ], + "cell_type": "markdown", + "id": "2027e726" + }, + { + "metadata": {}, + "source": [ + "import os\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" + ], + "cell_type": "code", + "outputs": [], + "id": "1c9c4856", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### Configuration\n", + "\n", + "To minimize the footprint of the notebook when specifying configurations, we provide a [`config.py`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/config.py) file that defines default parameters. 
These can be customized by supplying a `.yaml` configuration file, such as [`configs/BRCA.yaml`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/configs/BRCA.yaml).\n", + "\n", + "First, we load the configuration from [`configs/BRCA.yaml`](https://github.com/pykale/mmai-tutorial/blob/main/tutorials/multiomics-cancer-classification/configs/BRCA.yaml)." + ], + "cell_type": "markdown", + "id": "6b32af98" + }, + { + "metadata": {}, + "source": [ + "from config import get_cfg_defaults\n", + "\n", + "cfg = get_cfg_defaults()\n", + "cfg.merge_from_file(\"configs/BRCA.yaml\")" + ], + "cell_type": "code", + "outputs": [], + "id": "20700eaf", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "We also provide a configuration file for the **ROSMAP** dataset, named [`configs/ROSMAP.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/multiomics-cancer-classification/configs/ROSMAP.yaml). Users can try this dataset later." + ], + "cell_type": "markdown", + "id": "71add965" + }, + { + "metadata": {}, + "source": [ + "In this tutorial, we list the hyperparameters we would like users to play with outside the `.yaml` file:\n", + "- `cfg.SOLVER.MAX_EPOCHS_PRETRAIN`: Number of epochs in the pre-training stage.\n", + "- `cfg.SOLVER.MAX_EPOCHS`: Number of epochs in the training stage.\n", + "- `cfg.DATASET.NUM_MODALITIES`: Number of modalities in the pipeline.\n", + " - `1`: mRNA expression.\n", + " - `2`: mRNA expression + DNA methylation.\n", + " - `3`: mRNA expression + DNA methylation + miRNA expression.\n", + "\n", + "[**NOTE**] Because this tutorial aims to demonstrate the `PyKale` pipeline, we only set `cfg.SOLVER.MAX_EPOCHS_PRETRAIN=100` and `cfg.SOLVER.MAX_EPOCHS=500` to reduce the training time.\n", + "If users are interested, please increase them to get more accurate predictions." 
+ ], + "cell_type": "markdown", + "id": "66a1eb4b" + }, + { + "metadata": {}, + "source": [ + "cfg.SOLVER.MAX_EPOCHS_PRETRAIN = 100\n", + "cfg.SOLVER.MAX_EPOCHS = 500\n", + "cfg.DATASET.NUM_MODALITIES = 3" + ], + "cell_type": "code", + "outputs": [], + "id": "f1f8bb7c", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Print hyperparameters:" + ], + "cell_type": "markdown", + "id": "3bdf97a1" + }, + { + "metadata": {}, + "source": [ + "print(cfg)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DATASET:\n", + " NAME: TCGA_BRCA\n", + " NUM_CLASSES: 5\n", + " NUM_MODALITIES: 3\n", + " RANDOM_SPLIT: False\n", + " ROOT: dataset/\n", + " URL: https://github.com/pykale/data/raw/main/multiomics/TCGA_BRCA.zip\n", + "MODEL:\n", + " EDGE_PER_NODE: 10\n", + " EQUAL_WEIGHT: False\n", + " GCN_DROPOUT_RATE: 0.5\n", + " GCN_HIDDEN_DIM: [400, 400, 200]\n", + " GCN_LR: 0.0005\n", + " GCN_LR_PRETRAIN: 0.001\n", + " VCDN_LR: 0.001\n", + "OUTPUT:\n", + " OUT_DIR: ./outputs\n", + "SOLVER:\n", + " MAX_EPOCHS: 500\n", + " MAX_EPOCHS_PRETRAIN: 100\n", + " SEED: 2023\n" + ] + } + ], + "id": "f85914b1", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Step 1: Data Loading and Preparation\n", + "\n", + "We use two multiomics benchmarks in this tutorial, BRCA and ROSMAP, which have been provided by the authors of MOGONET paper in [their repository](https://github.com/txWang/MOGONET).\n", + "\n", + "If users are interested in more details regarding **data organization, downloading, loading, and pre-processing**, please refer to the [Data page](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/data.html) of the tutorial." 
+ ], + "cell_type": "markdown", + "id": "317fcb9b" + }, + { + "metadata": {}, + "source": [ + "Delete any existing data so that a fresh copy is downloaded:" + ], + "cell_type": "markdown", + "id": "8bf5c0c0" + }, + { + "metadata": {}, + "source": [ + "!rm -rf dataset/" + ], + "cell_type": "code", + "outputs": [], + "id": "2ecd6082", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "To load data, we first define a list of the names of the data files:" + ], + "cell_type": "markdown", + "id": "868bcf23" + }, + { + "metadata": {}, + "source": [ + "file_names = []\n", + "for modality in range(1, cfg.DATASET.NUM_MODALITIES + 1):\n", + " file_names.append(f\"{modality}_tr.csv\")\n", + " file_names.append(f\"{modality}_lbl_tr.csv\")\n", + " file_names.append(f\"{modality}_te.csv\")\n", + " file_names.append(f\"{modality}_lbl_te.csv\")\n", + " file_names.append(f\"{modality}_feat_name.csv\")" + ], + "cell_type": "code", + "outputs": [], + "id": "1352ea41", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Then, we download, load, and pre-process the data with `PyKale`.\n", + "\n", + "[Estimated running time] 20s" + ], + "cell_type": "markdown", + "id": "ef417d0c" + }, + { + "metadata": {}, + "source": [ + "import torch\n", + "from kale.loaddata.multiomics_datasets import SparseMultiomicsDataset\n", + "from kale.prepdata.tabular_transform import ToOneHotEncoding, ToTensor\n", + "\n", + "multiomics_data = SparseMultiomicsDataset(\n", + " root=cfg.DATASET.ROOT,\n", + " raw_file_names=file_names,\n", + " num_modalities=cfg.DATASET.NUM_MODALITIES,\n", + " num_classes=cfg.DATASET.NUM_CLASSES,\n", + " edge_per_node=cfg.MODEL.EDGE_PER_NODE,\n", + " url=cfg.DATASET.URL,\n", + " random_split=cfg.DATASET.RANDOM_SPLIT,\n", + " equal_weight=cfg.MODEL.EQUAL_WEIGHT,\n", + " pre_transform=ToTensor(dtype=torch.float),\n", + " target_pre_transform=ToOneHotEncoding(dtype=torch.float),\n", + ")" + ], + "cell_type": "code", + "outputs": [], + "id": "9041fabd", 
+ 
"execution_count": null + }, + { + "metadata": {}, + "source": [ + "Inspect the dataset:" + ], + "cell_type": "markdown", + "id": "c8819b69" + }, + { + "metadata": {}, + "source": [ + "print(multiomics_data)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Dataset info:\n", + " number of modalities: 3\n", + " number of classes: 5\n", + "\n", + " modality | total samples | num train | num test | num features\n", + " -----------------------------------------------------------------\n", + " 1 | 875 | 612 | 263 | 1000 \n", + " 2 | 875 | 612 | 263 | 1000 \n", + " 3 | 875 | 612 | 263 | 503 \n", + " -----------------------------------------------------------------\n", + "\n", + "\n" + ] + } + ], + "id": "676ebd93", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Step 2: Model Definition" + ], + "cell_type": "markdown", + "id": "910ca35a" + }, + { + "metadata": {}, + "source": [ + "If users are interested in more details regarding the model, please refer to the [Helper Function & Model Definition](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/helper-functions.html) of the tutorial.\n", + "\n", + "To initialize the model, we firstly call `MogonetModel` from [`model.py`](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/multiomics-cancer-classification/model.py)." 
+ ], + "cell_type": "markdown", + "id": "007e4533" + }, + { + "metadata": {}, + "source": [ + "from model import MogonetModel\n", + "\n", + "mogonet_model = MogonetModel(cfg, dataset=multiomics_data)" + ], + "cell_type": "code", + "outputs": [], + "id": "1537ce26", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Visualize the model architecture:" + ], + "cell_type": "markdown", + "id": "3bcb4126" + }, + { + "metadata": {}, + "source": [ + "print(mogonet_model)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\n", + "Model info:\n", + " Unimodal encoder:\n", + " (1) MogonetGCN(\n", + " (conv1): MogonetGCNConv(1000, 400)\n", + " (conv2): MogonetGCNConv(400, 400)\n", + " (conv3): MogonetGCNConv(400, 200)\n", + ") (2) MogonetGCN(\n", + " (conv1): MogonetGCNConv(1000, 400)\n", + " (conv2): MogonetGCNConv(400, 400)\n", + " (conv3): MogonetGCNConv(400, 200)\n", + ") (3) MogonetGCN(\n", + " (conv1): MogonetGCNConv(503, 400)\n", + " (conv2): MogonetGCNConv(400, 400)\n", + " (conv3): MogonetGCNConv(400, 200)\n", + ")\n", + "\n", + " Unimodal decoder:\n", + " (1) LinearClassifier(\n", + " (fc): Linear(in_features=200, out_features=5, bias=True)\n", + ") (2) LinearClassifier(\n", + " (fc): Linear(in_features=200, out_features=5, bias=True)\n", + ") (3) LinearClassifier(\n", + " (fc): Linear(in_features=200, out_features=5, bias=True)\n", + ")\n", + "\n", + " Multimodal decoder:\n", + " VCDN(\n", + " (model): Sequential(\n", + " (0): Linear(in_features=125, out_features=125, bias=True)\n", + " (1): LeakyReLU(negative_slope=0.25)\n", + " (2): Linear(in_features=125, out_features=5, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "id": "da221bd6", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Step 3: Model Training" + ], + "cell_type": "markdown", + "id": "38d9195c" + }, + { + "metadata": {}, + "source": [ + "### Pretrain Unimodal Encoders\n", + "\n", + "Before training 
the multiomics model, we first pretrain encoders for each modality independently. This step helps each GCN encoder learn a good representation of its respective modality before integration.\n", + "\n", + "We can define the trainer of pretraining stage by:" + ], + "cell_type": "markdown", + "id": "a7f6ad5c" + }, + { + "metadata": {}, + "source": [ + "import pytorch_lightning as pl\n", + "\n", + "network = mogonet_model.get_model(pretrain=True)\n", + "trainer_pretrain = pl.Trainer(\n", + " max_epochs=cfg.SOLVER.MAX_EPOCHS_PRETRAIN,\n", + " default_root_dir=cfg.OUTPUT.OUT_DIR,\n", + " accelerator=\"auto\",\n", + " devices=\"auto\",\n", + " enable_model_summary=False,\n", + ")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pytorch_lightning.utilities.rank_zero:\ud83d\udca1 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + } + ], + "id": "7383c5c1", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "We pretrain the model by:\n", + "\n", + "\n", + "[Estimated running time] 15s for 100 epochs" + ], + "cell_type": "markdown", + "id": "b0c71889" + }, + { + "metadata": {}, + "source": [ + "trainer_pretrain.fit(network)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Training: | | 0/? 
[00:00\u250f\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2533\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2513\n", + "\u2503 Test metric \u2503 DataLoader 0 \u2503\n", + "\u2521\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2547\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2501\u2529\n", + "\u2502 Accuracy \u2502 0.8019999861717224 \u2502\n", + "\u2502 F1 macro \u2502 0.6880000233650208 \u2502\n", + "\u2502 F1 weighted \u2502 0.7699999809265137 \u2502\n", + "\u2514\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2534\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2518\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'Accuracy': 0.8019999861717224,\n", + " 'F1 weighted': 0.7699999809265137,\n", + " 'F1 macro': 0.6880000233650208}]" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ], + "id": "019e2e7b", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Step 5: Interpretation Study\n", + "We use `kale.interpret` to perform interpretation, where a function that systematically masks input features and observes the effect on performance\u2014highlighting which features are most important for classification is provided. 
Please refer to the [Interpretation Study page](https://pykale.github.io/mmai-tutorials/tutorials/multiomics-cancer-classification/extend-reading/interpretation-study.html) for more details.\n", + "\n", + "Because the interpretation study masks one feature at a time and observes the resulting performance drop, we first define a dedicated trainer for the interpretation experiments.\n", + "\n", + "[**NOTE**] The final results may differ from those of a fully trained model because we train for only a limited number of epochs to reduce waiting time in this tutorial." + ], + "cell_type": "markdown", + "id": "719c655c" + }, + { + "metadata": {}, + "source": [ + "from kale.interpret.model_weights import select_top_features_by_masking\n", + "import pytorch_lightning as pl\n", + "\n", + "trainer_biomarker = pl.Trainer(\n", + " max_epochs=cfg.SOLVER.MAX_EPOCHS,\n", + " accelerator=\"auto\",\n", + " devices=\"auto\",\n", + " enable_progress_bar=False,\n", + ")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pytorch_lightning.utilities.rank_zero:\ud83d\udca1 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + } + ], + "id": "f061dd93", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Then, we start the experiment." 
+ ], + "cell_type": "markdown", + "id": "38a31ccf" + }, + { + "metadata": {}, + "source": [ + "To suppress the verbose messages in the following experiments:" + ], + "cell_type": "markdown", + "id": "4a754a08" + }, + { + "metadata": {}, + "source": [ + "import logging\n", + "\n", + "logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)" + ], + "cell_type": "code", + "outputs": [], + "id": "e428229c", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Run the interpretation experiments:\n", + "\n", + "[Estimated running time] Because the following block will train the model 2,503 times for the BRCA dataset, it may take about 6 minutes." + ], + "cell_type": "markdown", + "id": "8565e576" + }, + { + "metadata": {}, + "source": [ + "f1_key = \"F1\" if multiomics_data.num_classes == 2 else \"F1 macro\"\n", + "df_featimp_top = select_top_features_by_masking(\n", + " trainer=trainer_biomarker,\n", + " model=network,\n", + " dataset=multiomics_data,\n", + " metric=f1_key,\n", + " num_top_feats=30,\n", + " verbose=False,\n", + ")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + } + ], + "id": "2dd9e5e3", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Print the most important features:" + ], + "cell_type": "markdown", + "id": "964300ae" + }, + { + "metadata": {}, + "source": [ + "print(\"{:>4}\\t{:<20}\\t{:>5}\\t{}\".format(\"Rank\", \"Feature name\", \"Omics\", \"Importance\"))\n", + "for rank, row in enumerate(df_featimp_top.itertuples(index=False), 1):\n", + " print(f\"{rank:>4}\\t{row.feat_name:<20}\\t{row.omics:>5}\\t{row.imp:.4f}\")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Rank\tFeature name \tOmics\tImportance\n", + " 1\tMSLN|10232 \t 0\t21.0000\n", + " 2\thsa-mir-9-2 \t 2\t17.6050\n", + " 3\thsa-mir-9-1 \t 2\t16.0960\n", + " 4\thsa-mir-203 \t 2\t15.0900\n", + " 
5\tABCC11|85320 \t 0\t13.0000\n", + " 6\tTMEM207 \t 1\t13.0000\n", + " 7\tHOXD11 \t 1\t13.0000\n", + " 8\tKRTAP3-1 \t 1\t13.0000\n", + " 9\tOR1J4 \t 1\t13.0000\n", + " 10\tGPR37L1 \t 1\t13.0000\n", + " 11\thsa-mir-2115 \t 2\t11.5690\n", + " 12\thsa-mir-187 \t 2\t11.5690\n", + " 13\thsa-let-7a-3 \t 2\t9.5570\n", + " 14\thsa-let-7f-2 \t 2\t9.0540\n", + " 15\thsa-mir-205 \t 2\t8.5510\n", + " 16\thsa-mir-551b \t 2\t8.5510\n", + " 17\tANKRD45|339416 \t 0\t8.0000\n", + " 18\tNOTCH1|4851 \t 0\t8.0000\n", + " 19\tMDGA2|161357 \t 0\t8.0000\n", + " 20\tARHGEF4|50649 \t 0\t8.0000\n", + " 21\tCRHR1|1394 \t 0\t8.0000\n", + " 22\tCXCL3|2921 \t 0\t8.0000\n", + " 23\tCSDA|8531 \t 0\t8.0000\n", + " 24\tPI3|5266 \t 0\t8.0000\n", + " 25\tSLC43A3|29015 \t 0\t8.0000\n", + " 26\tTRIML2|205860 \t 0\t8.0000\n", + " 27\tRDH10|157506 \t 0\t8.0000\n", + " 28\tIFFO2|126917 \t 0\t8.0000\n", + " 29\tISL2|64843 \t 0\t8.0000\n", + " 30\tFGFBP1|9982 \t 0\t8.0000\n" + ] + } + ], + "id": "c984bdb1", + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## References\n", + "\n", + "[1] Wang, T., Shao, W., Huang, Z., Tang, H., Zhang, J., Ding, Z., & Huang, K. (2021). MOGONET integrates multi-omics data using graph convolutional networks allowing patient classification and biomarker identification. Nature communications, 12(1), 3445.\n", + "\n", + "[2] Lu, H., Liu, X., Zhou, S., Turner, R., Bai, P., Koot, R. E., ... & Xu, H. (2022, October). PyKale: Knowledge-aware machine learning from multiple sources in Python. In _Proceedings of the 31st ACM International Conference on Information & Knowledge Management_ (pp. 4274-4278).\n", + "\n", + "[3] Lingle, W., Erickson, B. J., Zuley, M. L., Jarosz, R., Bonaccio, E., Filippini, J., Net, J. M., Levi, L., Morris, E. A., Figler, G. G., Elnajjar, P., Kirk, S., Lee, Y., Giger, M., & Gruszauskas, N. (2016). The Cancer Genome Atlas Breast Invasive Carcinoma Collection (TCGA-BRCA) (Version 3) [Data set]. 
The Cancer Imaging Archive.\n", + "\n", + "\n", + "\n", + "[4] Bennett, D. A., Buchman, A. S., Boyle, P. A., Barnes, L. L., Wilson, R. S., & Schneider, J. A. (2018). Religious orders study and rush memory and aging project. Journal of Alzheimer\u2019s disease, 64(s1), S161-S189.\n", + "\n", + "[5] De Jager, P.L.; Ma, Y.; McCabe, C.; Xu, J.; Vardarajan, B.N.; Felsky, D.; Klein, H.U.; White, C.C.; Peters, M.A.; Lodgson, B.; et al. (2018). A multi-omic atlas of the human frontal cortex for aging and Alzheimer\u2019s disease research. Scientific Data 5, 1-13" + ], + "cell_type": "markdown", + "id": "1da8fd92" + } + ] }