diff --git a/_toc.yml b/_toc.yml index 104102d..7fd2bf6 100644 --- a/_toc.yml +++ b/_toc.yml @@ -23,3 +23,4 @@ parts: - caption: Tutorials chapters: - file: tutorials/brain-disorder-diagnosis/notebook + - file: tutorials/drug-target-interaction/notebook-cross-domain diff --git a/tutorials/drug-target-interaction/configs.py b/tutorials/drug-target-interaction/configs.py new file mode 100644 index 0000000..dd4ba45 --- /dev/null +++ b/tutorials/drug-target-interaction/configs.py @@ -0,0 +1,109 @@ +from yacs.config import CfgNode + +_C = CfgNode() + +# ---------------------------------------------------------------------------- # +# DATA setting +# ---------------------------------------------------------------------------- # +_C.DATA = CfgNode() +_C.DATA.DATASET = None # Name of the dataset to use +_C.DATA.SPLIT = None # Data splitting strategy + +# ---------------------------------------------------------------------------- # +# Drug feature extractor +# ---------------------------------------------------------------------------- # +_C.DRUG = CfgNode() +_C.DRUG.NODE_IN_FEATS = 7 # Number of input node features +_C.DRUG.NODE_IN_EMBEDDING = ( + 128 # Dimensionality of input node features after linear transformation +) +_C.DRUG.PADDING = True # Whether to apply padding +_C.DRUG.HIDDEN_LAYERS = [ + 128, + 128, + 128, +] # Sizes of hidden layers in the GCN feature extractor +_C.DRUG.MAX_NODES = 290 # Max number of nodes to pad to (used when PADDING=True) + +# ---------------------------------------------------------------------------- # +# Protein feature extractor +# ---------------------------------------------------------------------------- # +_C.PROTEIN = CfgNode() +_C.PROTEIN.NUM_FILTERS = [ + 128, + 128, + 128, +] # Number of filters in each convolutional layer +_C.PROTEIN.KERNEL_SIZE = [3, 6, 9] # Kernel size for each convolutional layer +_C.PROTEIN.EMBEDDING_DIM = 128 # Dimension of character embedding for amino acids +_C.PROTEIN.PADDING = True # Whether to apply zero-padding to the embedding + +# ---------------------------------------------------------------------------- # +# BCN setting +# ---------------------------------------------------------------------------- # +_C.BCN = CfgNode() +_C.BCN.HEADS = 2 # Number of attention heads in the Bilinear Attention Network + +# ---------------------------------------------------------------------------- # +# MLP decoder +# ---------------------------------------------------------------------------- # +_C.DECODER = CfgNode() +_C.DECODER.NAME = "MLP" # Decoder type +_C.DECODER.IN_DIM = 256 # Input dimension to the MLP (typically fused BAN feature size) +_C.DECODER.HIDDEN_DIM = 512 # Hidden layer size in the MLP +_C.DECODER.OUT_DIM = 128 # Output dimension before the final classification layer +_C.DECODER.BINARY = 1 # Number of output classes + +# ---------------------------------------------------------------------------- # +# SOLVER +# ---------------------------------------------------------------------------- # +_C.SOLVER = CfgNode() +_C.SOLVER.MAX_EPOCH = 100 # Total number of training epochs +_C.SOLVER.BATCH_SIZE = 64 # Batch size for training and evaluation +_C.SOLVER.NUM_WORKERS = 0 # Number of subprocesses for data loading +_C.SOLVER.LEARNING_RATE = 5e-5 # Learning rate for the main model +_C.SOLVER.DA_LEARNING_RATE = ( + 1e-3 # Learning rate for the domain adaptation (if DA is enabled) +) +_C.SOLVER.SEED = 2048 # Random seed for reproducibility + +# ---------------------------------------------------------------------------- # +# RESULT +# ---------------------------------------------------------------------------- # +_C.RESULT = CfgNode() +_C.RESULT.SAVE_MODEL = True # Whether to save model checkpoints during training + +# ---------------------------------------------------------------------------- # +# Domain adaptation +# ---------------------------------------------------------------------------- # +_C.DA = CfgNode() +_C.DA.TASK = ( + False # False = in-domain splitting task, True = cross-domain splitting task +) +_C.DA.METHOD = "CDAN" # Domain adaptation method to use +_C.DA.USE = False # Whether to enable domain adaptation +_C.DA.INIT_EPOCH = 10 # Number of epochs to wait before applying domain adaptation +_C.DA.LAMB_DA = 1 # Initial value of λ (lambda) used to weight the domain adaptation loss in the total loss # Total loss = model loss + λ * domain loss +_C.DA.RANDOM_LAYER = False # Whether to use a random projection layer in CDAN +_C.DA.ORIGINAL_RANDOM = False # If True, uses the original RandomLayer from the CDAN paper (multi-input form) # If False, uses a simplified linear layer implementation. +_C.DA.RANDOM_DIM = None # Output dimensionality of the random layer (only used if RANDOM_LAYER is True) +_C.DA.USE_ENTROPY = True # Whether to use entropy-based weighting when computing domain adversarial loss + +# ---------------------------------------------------------------------------- # +# Comet config, ignore it If not installed. +# ---------------------------------------------------------------------------- # +_C.COMET = CfgNode() +_C.COMET.USE = ( + True # Enable Comet logging (set True if Comet is installed and configured) +) +_C.COMET.PROJECT_NAME = "drugban-23-May" # Comet project name (if applicable) +_C.COMET.EXPERIMENT_NAME = None # Optional experiment name (e.g., 'drugban-run-1') +_C.COMET.TAG = None # Comet tags (optional) +_C.COMET.API_KEY = "" # Comet API key (leave blank if unused) + + +# ---------------------------------------------------------------------------- # +# Function to return a clone of the default config +# ---------------------------------------------------------------------------- # +def get_cfg_defaults(): + return _C.clone() diff --git a/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml b/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml new file mode 100644 index 0000000..87a1c65 --- /dev/null +++ b/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml @@ -0,0 +1,30 @@ +# This is for cross-domain experiments using DrugBAN with domain adaptation. + +DATA: + DATASET: "bindingdb" # bindingdb, biosnap + SPLIT: "cluster" + +SOLVER: + BATCH_SIZE: 32 + MAX_EPOCH: 100 + LEARNING_RATE: 1e-4 + DA_LEARNING_RATE: 5e-5 + SEED: 20 + +DA: + TASK: True + USE: True + METHOD: "CDAN" + USE_ENTROPY: False + RANDOM_LAYER: True + ORIGINAL_RANDOM: True + RANDOM_DIM: 256 + INIT_EPOCH: 10 + +DECODER: + BINARY: 2 + +# Config below only when you use comet +COMET: + EXPERIMENT_NAME: "DA_cross_domain" + TAG: "DrugBAN_CDAN" diff --git a/tutorials/drug-target-interaction/experiments/non_DA_cross_domain.yaml b/tutorials/drug-target-interaction/experiments/non_DA_cross_domain.yaml new file mode 100644 index 0000000..17ef009 --- /dev/null +++ b/tutorials/drug-target-interaction/experiments/non_DA_cross_domain.yaml @@ -0,0 +1,24 @@ +# This is for cross-domain experiments using DrugBAN without domain adaptation. + + +DATA: + DATASET: "bindingdb" # bindingdb, biosnap + SPLIT: "cluster" + +SOLVER: + BATCH_SIZE: 32 + MAX_EPOCH: 100 + LEARNING_RATE: 5e-5 + SEED: 20 + +DA: + TASK: True + USE: False + +DECODER: + BINARY: 2 + +# Config below only when you use comet +COMET: + EXPERIMENT_NAME: "Non_DA_cross_domain" + TAG: "DrugBAN_Vanilla" diff --git a/tutorials/drug-target-interaction/experiments/non_DA_in_domain.yaml b/tutorials/drug-target-interaction/experiments/non_DA_in_domain.yaml new file mode 100644 index 0000000..1ebb455 --- /dev/null +++ b/tutorials/drug-target-interaction/experiments/non_DA_in_domain.yaml @@ -0,0 +1,23 @@ +# This is for in-domain experiments using DrugBAN without domain adaptation. + +DATA: + DATASET: "bindingdb" # bindingdb, biosnap + SPLIT: "random" # random + +SOLVER: + BATCH_SIZE: 64 + MAX_EPOCH: 100 + LEARNING_RATE: 5e-5 + SEED: 20 + +DA: + TASK: False + USE: False + +DECODER: + BINARY: 1 + +# Config below only when you use comet +COMET: + EXPERIMENT_NAME: "Non_DA_in_domain" + TAG: "DrugBAN_Vanilla" diff --git a/tutorials/drug-target-interaction/notebook-cross-domain.ipynb b/tutorials/drug-target-interaction/notebook-cross-domain.ipynb new file mode 100644 index 0000000..e6b988c --- /dev/null +++ b/tutorials/drug-target-interaction/notebook-cross-domain.ipynb @@ -0,0 +1,1587 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "cells": [ + { + "metadata": {}, + "source": [ + "!pip install \"numpy<2.0\" \"transformers==4.30.2\" --force-reinstall --quiet" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting numpy<2.0\n", + " Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting transformers==4.30.2\n", + " Downloading transformers-4.30.2-py3-none-any.whl.metadata (113 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m113.6/113.6 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting filelock (from transformers==4.30.2)\n", + " Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)\n", + "Collecting huggingface-hub<1.0,>=0.14.1 (from transformers==4.30.2)\n", + " Downloading huggingface_hub-0.33.0-py3-none-any.whl.metadata (14 kB)\n", + "Collecting packaging>=20.0 (from transformers==4.30.2)\n", + " Downloading packaging-25.0-py3-none-any.whl.metadata (3.3 kB)\n", + "Collecting pyyaml>=5.1 (from transformers==4.30.2)\n", + " Downloading PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)\n", + "Collecting regex!=2019.12.17 (from transformers==4.30.2)\n", + " Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting requests (from transformers==4.30.2)\n", + " Downloading requests-2.32.4-py3-none-any.whl.metadata (4.9 kB)\n", + "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.30.2)\n", + " Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", + "Collecting safetensors>=0.3.1 (from transformers==4.30.2)\n", + " Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", + "Collecting tqdm>=4.27 (from transformers==4.30.2)\n", + " Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting fsspec>=2023.5.0 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", + " Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)\n", + "Collecting typing-extensions>=3.7.4.3 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", + " Downloading typing_extensions-4.14.0-py3-none-any.whl.metadata (3.0 kB)\n", + "Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", + " Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)\n", + "Collecting charset_normalizer<4,>=2 (from requests->transformers==4.30.2)\n", + " Downloading charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (35 kB)\n", + "Collecting idna<4,>=2.5 (from requests->transformers==4.30.2)\n", + " Downloading idna-3.10-py3-none-any.whl.metadata (10 kB)\n", + "Collecting urllib3<3,>=1.21.1 (from requests->transformers==4.30.2)\n", + " Downloading urllib3-2.5.0-py3-none-any.whl.metadata (6.5 kB)\n", + "Collecting certifi>=2017.4.17 (from requests->transformers==4.30.2)\n", + " Downloading certifi-2025.6.15-py3-none-any.whl.metadata (2.4 kB)\n", + "Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m99.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.3/18.3 MB\u001b[0m \u001b[31m110.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading huggingface_hub-0.33.0-py3-none-any.whl (514 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m514.8/514.8 kB\u001b[0m \u001b[31m36.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading packaging-25.0-py3-none-any.whl (66 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.5/66.5 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (762 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m763.0/763.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (792 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m792.7/792.7 kB\u001b[0m \u001b[31m47.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m471.6/471.6 kB\u001b[0m \u001b[31m38.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m75.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading filelock-3.18.0-py3-none-any.whl (16 kB)\n", + "Downloading requests-2.32.4-py3-none-any.whl (64 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.8/64.8 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading certifi-2025.6.15-py3-none-any.whl (157 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m157.7/157.7 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (147 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m147.3/147.3 kB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading fsspec-2025.5.1-py3-none-any.whl (199 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.1/199.1 kB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m97.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading idna-3.10-py3-none-any.whl (70 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.4/70.4 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading typing_extensions-4.14.0-py3-none-any.whl (43 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.8/43.8 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading urllib3-2.5.0-py3-none-any.whl (129 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.8/129.8 kB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: tokenizers, urllib3, typing-extensions, tqdm, safetensors, regex, pyyaml, packaging, numpy, idna, hf-xet, fsspec, filelock, charset_normalizer, certifi, requests, huggingface-hub, transformers\n", + " Attempting uninstall: tokenizers\n", + " Found existing installation: tokenizers 0.21.1\n", + " Uninstalling tokenizers-0.21.1:\n", + " Successfully uninstalled tokenizers-0.21.1\n", + " Attempting uninstall: urllib3\n", + " Found existing installation: urllib3 2.4.0\n", + " Uninstalling urllib3-2.4.0:\n", + " Successfully uninstalled urllib3-2.4.0\n", + " Attempting uninstall: typing-extensions\n", + " Found existing installation: typing_extensions 4.14.0\n", + " Uninstalling typing_extensions-4.14.0:\n", + " Successfully uninstalled typing_extensions-4.14.0\n", + " Attempting uninstall: tqdm\n", + " Found existing installation: tqdm 4.67.1\n", + " Uninstalling tqdm-4.67.1:\n", + " Successfully uninstalled tqdm-4.67.1\n", + " Attempting uninstall: safetensors\n", + " Found existing installation: safetensors 0.5.3\n", + " Uninstalling safetensors-0.5.3:\n", + " Successfully uninstalled safetensors-0.5.3\n", + " Attempting uninstall: regex\n", + " Found existing installation: regex 2024.11.6\n", + " Uninstalling regex-2024.11.6:\n", + " Successfully uninstalled regex-2024.11.6\n", + " Attempting uninstall: pyyaml\n", + " Found existing installation: PyYAML 6.0.2\n", + " Uninstalling PyYAML-6.0.2:\n", + " Successfully uninstalled PyYAML-6.0.2\n", + " Attempting uninstall: packaging\n", + " Found existing installation: packaging 24.2\n", + " Uninstalling packaging-24.2:\n", + " Successfully uninstalled packaging-24.2\n", + " Attempting uninstall: numpy\n", + " Found existing installation: numpy 2.0.2\n", + " Uninstalling numpy-2.0.2:\n", + " Successfully uninstalled numpy-2.0.2\n", + " Attempting uninstall: idna\n", + " Found existing installation: idna 3.10\n", + " Uninstalling idna-3.10:\n", + " Successfully uninstalled idna-3.10\n", + " Attempting uninstall: hf-xet\n", + " Found existing installation: hf-xet 1.1.3\n", + " Uninstalling hf-xet-1.1.3:\n", + " Successfully uninstalled hf-xet-1.1.3\n", + " Attempting uninstall: fsspec\n", + " Found existing installation: fsspec 2025.3.2\n", + " Uninstalling fsspec-2025.3.2:\n", + " Successfully uninstalled fsspec-2025.3.2\n", + " Attempting uninstall: filelock\n", + " Found existing installation: filelock 3.18.0\n", + " Uninstalling filelock-3.18.0:\n", + " Successfully uninstalled filelock-3.18.0\n", + " Attempting uninstall: charset_normalizer\n", + " Found existing installation: charset-normalizer 3.4.2\n", + " Uninstalling charset-normalizer-3.4.2:\n", + " Successfully uninstalled charset-normalizer-3.4.2\n", + " Attempting uninstall: certifi\n", + " Found existing installation: certifi 2025.6.15\n", + " Uninstalling certifi-2025.6.15:\n", + " Successfully uninstalled certifi-2025.6.15\n", + " Attempting uninstall: requests\n", + " Found existing installation: requests 2.32.3\n", + " Uninstalling requests-2.32.3:\n", + " Successfully uninstalled requests-2.32.3\n", + " Attempting uninstall: huggingface-hub\n", + " Found existing installation: huggingface-hub 0.33.0\n", + " Uninstalling huggingface-hub-0.33.0:\n", + " Successfully uninstalled huggingface-hub-0.33.0\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.52.4\n", + " Uninstalling transformers-4.52.4:\n", + " Successfully uninstalled transformers-4.52.4\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n", + "gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.5.1 which is incompatible.\n", + "langchain-core 0.3.65 requires packaging<25,>=23.2, but you have packaging 25.0 which is incompatible.\n", + "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-nvrtc-cu12 12.5.82 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cuda-runtime-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-runtime-cu12 12.5.82 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cudnn-cu12 9.3.0.75 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cufft-cu12==11.2.1.3; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cufft-cu12 11.2.3.61 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-curand-cu12==10.3.5.147; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-curand-cu12 10.3.6.82 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cusolver-cu12==11.6.1.9; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusolver-cu12 11.6.3.83 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-cusparse-cu12==12.3.1.170; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusparse-cu12 12.5.1.3 which is incompatible.\n", + "torch 2.6.0+cu124 requires nvidia-nvjitlink-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-nvjitlink-cu12 12.5.82 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed certifi-2025.6.15 charset_normalizer-3.4.2 filelock-3.18.0 fsspec-2025.5.1 hf-xet-1.1.5 huggingface-hub-0.33.0 idna-3.10 numpy-1.26.4 packaging-25.0 pyyaml-6.0.2 regex-2024.11.6 requests-2.32.4 safetensors-0.5.3 tokenizers-0.13.3 tqdm-4.67.1 transformers-4.30.2 typing-extensions-4.14.0 urllib3-2.5.0\n" + ] + }, + { + "output_type": "display_data", + "data": { + "application/vnd.colab-display-data+json": { + "pip_warning": { + "packages": [ + "certifi", + "numpy", + "packaging" + ] + }, + "id": "45157e696c714aad87be51c98f340caf" + } + }, + "metadata": {} + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "# ✅ Now run this AFTER restarting runtime\n", + "import numpy as np\n", + "\n", + "print(\"NumPy version:\", np.__version__) # should be <2.0" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "NumPy version: 1.26.4\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "# **Drug–Target Interaction Prediction**\n", + "\n", + "Welcome to this tutorial on drug–target interaction (DTI) prediction using the **PyKale library**.\n", + "\n", + "**PyKale** is a Python toolkit that helps make machine learning more approachable, especially for researchers working in interdisciplinary fields. It is particularly useful when dealing with **multimodal data**, which simply means combining different types of data — for example, information about drugs and proteins — to learn patterns from them together.\n", + "\n", + "Even if you’re new to Python or machine learning, don’t worry — we’ll explain key concepts as we go.\n", + "\n", + " \n", + "\n", + "---\n", + "\n", + " \n", + "\n", + "This tutorial builds on the work of [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1), which introduced the **DrugBAN** framework. The DrugBAN includes two key ideas:\n", + "\n", + "- A **bilinear attention network (BAN)**. This is a model that learns the features of both the drug and the protein, and how these features interact locally.\n", + "\n", + "\n", + "- **Adversarial domain adaptation**. This is a method that helps the model generalise to data that is different from what it was trained on (also known as out-of-distribution data), improving its performance on unseen drug–target pairs.\n", + "\n", + " \n", + "\n", + "---\n", + "\n", + "\n", + " \n", + "\n", + "## 🔍 What You'll Learn\n", + "\n", + "In the sections that follow, we’ll guide you through the PyKale development pipeline. Specifically, you will learn how to use PyKale to:\n", + "\n", + "- Load and preprocess the data\n", + "\n", + "- Set up the model and the training process\n", + "\n", + "- Train and test the model\n", + "\n", + "Finally, we will compare the results from DrugBAN with those from other established models.\n", + "\n", + "Let’s get started!\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "To begin, we will install the necessary packages required for this tutorial. To maintain clarity and focus on interpretation, we will also suppress any warnings." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "import os\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "[Optional] If you are using Google Colab, please using the following codes to load necessary demo data and code files." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "!git clone --branch drug-target-interaction https://github.com/pykale/embc-mmai25.git\n", + "%cd /content/embc-mmai25/tutorials/drug-target-interaction" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Cloning into 'embc-mmai25'...\n", + "remote: Enumerating objects: 1165, done.\u001b[K\n", + "remote: Counting objects: 100% (186/186), done.\u001b[K\n", + "remote: Compressing objects: 100% (128/128), done.\u001b[K\n", + "remote: Total 1165 (delta 87), reused 119 (delta 54), pack-reused 979 (from 1)\u001b[K\n", + "Receiving objects: 100% (1165/1165), 128.89 MiB | 34.22 MiB/s, done.\n", + "Resolving deltas: 100% (543/543), done.\n", + "/content/embc-mmai25/tutorials/drug-target-interaction\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "from google.colab import drive\n", + "\n", + "drive.mount(\"/content/drive\")\n", + "\n", + "shared_drives_path = (\n", + " \"/content/drive/Shared drives/EMBC-MMAI 25 Workshop/data/drug-target-interaction\"\n", + ")\n", + "\n", + "import os\n", + "import shutil\n", + "\n", + "print(\"Contents of the folder:\")\n", + "for item in os.listdir(shared_drives_path):\n", + " print(item)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n", + "Contents of the folder:\n", + "bindingdb\n", + "biosnap\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### 📦 Packages\n", + "\n", + "The main packages required for this tutorial are **PyKale**, **PyTorch Geometric**, and **RDKit**.\n", + "\n", + "- **PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, designed for applications in biomedical and scientific domains.\n", + "- **PyG** (PyTorch Geometric) is a library built on top of PyTorch for building and training Graph Neural Networks (GNNs) on structured data.\n", + "- **RDKit** is a cheminformatics toolkit for handling and processing molecular structures, particularly useful for working with SMILES strings and molecular graphs.\n", + "\n", + "Other dependencies are listed in [`embc-mmai25/requirements.txt`](https://github.com/pykale/embc-mmai25/blob/main/requirements.txt).\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "!pip install --quiet git+https://github.com/pykale/pykale@main\\\n", + " && echo \"PyKale installed successfully ✅\" \\\n", + " || echo \"Failed to install PyKale ❌\"\n", + "\n", + "!pip install --quiet -r /content/embc-mmai25/requirements.txt \\\n", + " && echo \"Required packages installed successfully ✅\" \\\n", + " || echo \"Failed to install required packages ❌\"\n", + "\n", + "import torch\n", + "os.environ['TORCH'] = torch.__version__\n", + "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n", + "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n", + "\n", + "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \\\n", + " && echo \"PyG installed successfully ✅\" \\\n", + " || echo \"Failed to install PyG ❌\"\n", + "\n", + "!pip install rdkit-pypi \\\n", + " && echo \"PyG installed successfully ✅\" \\\n", + " || echo \"Failed to install PyG ❌\"" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m812.3/812.3 kB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m779.2/779.2 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m85.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m64.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m42.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.2/176.2 MB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.1/168.1 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m71.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m67.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m962.6/962.6 kB\u001b[0m \u001b[31m41.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for pykale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.3.0 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mPyKale installed successfully ✅\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.0/45.0 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m103.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.5/10.5 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m80.3/80.3 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.2/83.2 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m93.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.0/434.0 kB\u001b[0m \u001b[31m28.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m103.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m89.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.4/40.4 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.4/127.4 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m68.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequired packages installed successfully ✅\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.9/10.9 MB\u001b[0m \u001b[31m74.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.1/5.1 MB\u001b[0m \u001b[31m41.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for torch-geometric (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "PyG installed successfully ✅\n", + "Collecting rdkit-pypi\n", + " Downloading rdkit_pypi-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from rdkit-pypi) (1.26.4)\n", + "Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from rdkit-pypi) (11.2.1)\n", + "Downloading rdkit_pypi-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m29.4/29.4 MB\u001b[0m \u001b[31m71.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: rdkit-pypi\n", + "Successfully installed rdkit-pypi-2022.9.5\n", + "PyG installed successfully ✅\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### ⚙️ Configuration\n", + "\n", + "Before running any model or data processing, we need to tell the code **what settings to use**. To make this easier, we provide a file called `config.py`. Think of this file as a **menu of default settings** that the rest of the code can refer to — for example, where to find the data, which model to use, and how many times to train it.\n", + "\n", + "You can find `config.py` in the same folder as this notebook. You don’t need to change it directly. Instead, we use a **YAML file** to customise the settings.\n", + "\n", + "> **What is a YAML file?** \n", + "> YAML is a simple text file format often used for configuration. It lets you list out settings in a way that’s easier to read than raw Python code.\n", + "\n", + "For example, we have a YAML file called `experiments/non_da_in_domain.yaml`. You can change this file to adjust things like:\n", + "\n", + "- Which dataset to use \n", + "- How long to train the model \n", + "- Which model settings to apply \n", + "\n", + "This helps keep your work organised and flexible. You don’t need to modify the original Python files — just change the YAML file instead.\n", + "\n", + "Now let’s see how we actually load and apply the settings from a YAML file in Python." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from configs import get_cfg_defaults\n", + "\n", + "# Load the default settings from config.py\n", + "cfg = get_cfg_defaults()\n", + "\n", + "# Update (or override) some of those settings using a custom YAML file\n", + "cfg.merge_from_file(\"experiments/DA_cross_domain.yaml\")\n", + "\n", + "# Example: temporarily shorten the training time by setting fewer training rounds\n", + "cfg.SOLVER.MAX_EPOCH = 2\n", + "\n", + "# Example: switch the dataset to Biosnap\n", + "cfg.DATA.DATASET = \"biosnap\"\n", + "\n", + "# Print the current settings to check what’s being used\n", + "print(cfg)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "BCN:\n", + " HEADS: 2\n", + "COMET:\n", + " API_KEY: InDQ1UsqJt7QMiANWg55Ulebe\n", + " EXPERIMENT_NAME: DA_cross_domain\n", + " PROJECT_NAME: drugban-23-May\n", + " TAG: DrugBAN_CDAN\n", + " USE: True\n", + "DA:\n", + " INIT_EPOCH: 10\n", + " LAMB_DA: 1\n", + " METHOD: CDAN\n", + " ORIGINAL_RANDOM: True\n", + " RANDOM_DIM: 256\n", + " RANDOM_LAYER: True\n", + " TASK: True\n", + " USE: True\n", + " USE_ENTROPY: False\n", + "DATA:\n", + " DATASET: biosnap\n", + " SPLIT: cluster\n", + "DECODER:\n", + " BINARY: 2\n", + " HIDDEN_DIM: 512\n", + " IN_DIM: 256\n", + " NAME: MLP\n", + " OUT_DIM: 128\n", + "DRUG:\n", + " HIDDEN_LAYERS: [128, 128, 128]\n", + " MAX_NODES: 290\n", + " NODE_IN_EMBEDDING: 128\n", + " NODE_IN_FEATS: 7\n", + " PADDING: True\n", + "PROTEIN:\n", + " EMBEDDING_DIM: 128\n", + " KERNEL_SIZE: [3, 6, 9]\n", + " NUM_FILTERS: [128, 128, 128]\n", + " PADDING: True\n", + "RESULT:\n", + " SAVE_MODEL: True\n", + "SOLVER:\n", + " BATCH_SIZE: 32\n", + " DA_LEARNING_RATE: 5e-05\n", + " LEARNING_RATE: 0.0001\n", + " MAX_EPOCH: 2\n", + " NUM_WORKERS: 0\n", + " SEED: 20\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Data Overview\n", + "\n", + "In this tutorial, we use a benchmark dataset called **Biosnap**, which contains information about how well different drugs interact with specific proteins. This dataset has been preprocessed and provided by the authors of the **DrugBAN** paper. You can also find it in their [GitHub repository](https://github.com/peizhenbai/DrugBAN/tree/main).\n", + "\n", + "### 📁 Folder Structure\n", + "\n", + "The dataset is stored in a folder called `biosnap`, which contains a few subfolders. Each subfolder corresponds to a different experimental setting for training and testing machine learning models.\n", + "\n", + "Here is a simplified view of the folder structure:\n", + "\n", + "\n", + "\n", + "```sh\n", + " ├───biosnap\n", + " │ ├───cluster\n", + " │ │ ├───source_train.csv\n", + " │ │ ├───target_train.csv\n", + " │ │ ├───target_test.csv\n", + " │ ├───random\n", + " │ │ ├───test.csv\n", + " │ │ ├───train.csv\n", + " │ │ ├───val.csv\n", + " │ ├───full.csv\n", + "\n", + "```\n", + "\n", + "\n", + "Each file listed here is in **CSV format**, which you can open using spreadsheet software (like Excel) or load into Python using tools like `pandas`. These files contain rows of data, with each row representing one drug–protein pair.\n", + "\n", + "### 🧬 What’s Inside Each File?\n", + "\n", + "Each row of the dataset contains three key pieces of information:\n", + "\n", + "**SMILES** \n", + "This is a way to describe the structure of a drug molecule using a short string of letters and symbols. It’s a compact format called *Simplified Molecular Input Line Entry System*. You don’t need to understand chemistry to use this, but just know that this string uniquely represents a drug.\n", + "\n", + "**Protein Sequence** \n", + "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n", + "\n", + "**Y** \n", + "This is the label or answer. It tells us whether the drug and the protein interact. \n", + "`1` means yes, they interact. \n", + "`0` means no, they do not interact.\n", + "\n", + "### 📊 Sample of the Data\n", + "\n", + "Here’s what the data looks like in a table format:\n", + "\n", + "| SMILES | Protein Sequence | Y |\n", + "|--------------------|--------------------------|---|\n", + "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n", + "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n", + "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n", + "\n", + "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**) — whether or not the drug and protein interact." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🧪 Preprocessing\n", + "\n", + "Before we train a model, we need to prepare the data in a format that the model can understand. This process is called **preprocessing**. In this task, we work with two types of biological data: **drugs** and **proteins**.\n", + "\n", + "#### How is the data represented?\n", + "\n", + "**Drugs**: \n", + "Drugs are often written as SMILES strings, which are like chemical formulas in text format (for example, `\"CC(=O)OC1=CC=CC=C1C(=O)O\"` is aspirin). \n", + "\n", + "To make this information useful for machine learning, we convert each SMILES string into a **molecular graph**. In a molecular graph:\n", + "- Each **atom** is a node\n", + "- Each **bond** is an edge between nodes \n", + "\n", + " \n", + "\n", + "---\n", + "\n", + " \n", + "\n", + "**Proteins**: \n", + "Proteins are sequences of amino acids. We convert each sequence into numbers using:\n", + "\n", + "- **One-hot encoding**, which assigns each amino acid a unique numerical representation. A full sequence is then turned into an embedding-like vector, similar to how sentences are represented in natural language processing.\n", + "\n", + "\n", + " \n", + "\n", + "---\n", + "\n", + " \n", + "\n", + "\n", + "**Labels**: \n", + "Each drug–protein pair is given a label:\n", + "- `1` if they interact (i.e. the drug affects the protein)\n", + "- `0` if they do not\n", + "\n", + " \n", + "\n", + "---\n", + "\n", + " \n", + "\n", + "#### How is preprocessing handled in the code?\n", + "\n", + "We use a class called `DTIDataset`, provided by **PyKale**, to handle all of this preprocessing for us. It takes care of:\n", + "- Reading the data\n", + "- Converting drugs to molecular graphs\n", + "- Encoding protein sequences\n", + "- Assigning labels to each pair\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from kale.loaddata.molecular_datasets import DTIDataset\n", + "import pandas as pd\n", + "\n", + "# Path to the dataset folder\n", + "dataFolder = os.path.join(\n", + " f\"/content/drive/Shared drives/EMBC-MMAI 25 Workshop/data/drug-target-interaction/{cfg.DATA.DATASET}\",\n", + " str(cfg.DATA.SPLIT),\n", + ")\n", + "\n", + "# Path to the dataset folder\n", + "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", + "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", + "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", + "\n", + "# Create preprocessed datasets\n", + "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", + "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", + "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### 🗂️ Dataset Inspection\n", + "\n", + "Once we’ve loaded the dataset, it's useful to take a quick look at what it contains. This helps us understand the data format and what kind of information we’ll be working with in the rest of the tutorial.\n", + "\n", + "In this project, our dataset has been split into three parts:\n", + "\n", + "**Train samples from the source domain** \n", + "These are drug–protein pairs the model will learn from. The \"source domain\" typically refers to a distribution of data that the model is familiar with.\n", + "\n", + "**Train samples from the target domain** \n", + "These are additional training samples, but from a different distribution (a \"target domain\") the model should generalise to. This helps simulate real-world scenarios where new data may come from different conditions.\n", + "\n", + "**Test samples from the target domain** \n", + "These are drug–protein pairs that the model has never seen before, and they’re used to evaluate how well the model generalises to new, unseen cases.\n", + "\n", + "Let’s print out the number of samples in each set, and take a peek at one example from the training data.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "print(\n", + " f\"Train samples from source domain: {len(train_dataset)}, Train samples from target domain: {len(train_target_dataset)}, Test samples from target domain: {len(test_target_dataset)}\"\n", + ")\n", + "\n", + "print(\"\\nAn example sample from source domain:\")\n", + "print(train_dataset[0])" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train samples from source domain: 9766, Train samples from target domain: 3628, Test samples from target domain: 907\n", + "\n", + "An example sample from source domain:\n", + "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), 0.0)\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### 🧾 Example Sample Explained\n", + "\n", + "Let’s break down what this example from the **source domain** means:\n", + "\n", + "```\n", + "An example sample from source domain:\n", + "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), 0.0)\n", + "```\n", + "\n", + "This sample is a tuple with **three parts**:\n", + "\n", + "---\n", + "\n", + "#### 1. **Drug Graph (Data object)**\n", + "\n", + "This part is a graph-based representation of the **drug**, built using the PyTorch Geometric `Data` object:\n", + "\n", + "- `x=[290, 7]` \n", + " This is a table (matrix) with **290 nodes** (atoms) and **7 features** per atom. \n", + " Each row represents an atom, and each column describes one feature of the atom.\n", + "\n", + "- `edge_index=[2, 58]` \n", + " This shows how the atoms are connected (like chemical bonds). \n", + " There are **58 edges**, and the matrix has 2 rows — the first row lists source atoms, the second lists the target atoms.\n", + "\n", + "- `edge_attr=[58, 1]` \n", + " Each edge (bond) has **1 feature**, such as bond type. \n", + " So there are 58 rows (one for each edge), and 1 column.\n", + "\n", + "- `num_nodes=290` \n", + " This confirms that the graph has **290 atoms (nodes)**.\n", + "\n", + "---\n", + "\n", + "#### 2. **Protein Features (array)**\n", + "\n", + "- This is a **1D array** (or vector) representing the **protein**. \n", + "- It contains numerical features extracted from the protein sequence or structure. \n", + "- Example values: `[11., 1., 18., ..., 0., 0., 0.]` \n", + " These could represent biochemical or structural properties, with padding at the end (zeros) to ensure a consistent input size.\n", + "\n", + "---\n", + "\n", + "#### 3. **Label (float)**\n", + "\n", + "- `0.0` \n", + " This is the **label**, which tells us the ground truth: \n", + " The drug and protein **do not interact** in this sample. \n", + "\n", + " If the label were `1.0`, it would mean they **do interact**.\n", + "\n", + "---\n", + "\n", + "This format allows the model to learn from both structured graph data (the drug) and feature-based data (the protein), and predict whether they interact based on the label.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🧱 Batching\n", + "\n", + "When training machine learning models, especially on large datasets like molecular graphs, it’s inefficient and memory-intensive to load everything at once. Instead, we split the data into **mini-batches** and feed them into the model one at a time. This process is called **batching**, i.e, loading data in manageable pieces.\n", + "\n", + "In this tutorial, we use PyTorch’s `DataLoader` to help us do this. A `DataLoader` handles the process of batching, shuffling, and loading data efficiently during training and evaluation.\n", + "\n", + "However, because molecular data involves **graphs of different sizes and shapes**, we can't just stack them like regular tables or images. That’s where a custom helper function called `graph_collate_func` comes in. This function tells the `DataLoader` how to correctly combine graphs of different structures into a batch.\n", + "\n", + "#### 🔄 Training vs Testing\n", + "\n", + "- During **training**, we shuffle the data randomly. This helps the model generalise better and prevents it from learning the order of the data.\n", + "- During **validation and testing**, we **don’t** shuffle the data. This ensures consistent and reproducible evaluation.\n", + "\n", + "Now let’s see how this looks in code." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from torch.utils.data import DataLoader\n", + "from kale.loaddata.molecular_datasets import graph_collate_func\n", + "from kale.loaddata.sampler import MultiDataLoader\n", + "\n", + "# Define the parameters used by the data loaders\n", + "params = {\n", + " \"batch_size\": cfg.SOLVER.BATCH_SIZE, # Number of samples per batch\n", + " \"shuffle\": True, # Shuffle data during training\n", + " \"num_workers\": cfg.SOLVER.NUM_WORKERS, # Number of workers loading the data\n", + " \"drop_last\": True, # Drop the last batch if it's smaller than batch_size\n", + " \"collate_fn\": graph_collate_func, # Custom function to batch graphs correctly\n", + "}\n", + "\n", + "\n", + "# Create data loaders for source and target training datasets\n", + "source_generator = DataLoader(train_dataset, **params)\n", + "target_generator = DataLoader(train_target_dataset, **params)\n", + "\n", + "# Get the number of batches in the longer dataset to align both\n", + "n_batches = max(len(source_generator), len(target_generator))\n", + "\n", + "# Combine the source and target data loaders using MultiDataLoader\n", + "training_generator = MultiDataLoader(\n", + " dataloaders=[source_generator, target_generator], n_batches=n_batches\n", + ") # used to be named as multi_generator\n", + "\n", + "\n", + "# Now we set up data loaders for validation and testing. Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly.\n", + "\n", + "# Update parameters for validation/testing (no shuffling, keep all data)\n", + "params.update({\"shuffle\": False, \"drop_last\": False})\n", + "\n", + "# Create validation and test data loaders\n", + "valid_generator = DataLoader(test_target_dataset, **params)\n", + "test_generator = DataLoader(test_target_dataset, **params)" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Model and Trainer Overview\n", + "\n", + "In this section, we'll look at the model and trainer we're using and how to set it up in your code. Don’t worry if the names sound technical — we’ll break them down for you." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🏗️ Setting Up the DrugBAN Model\n", + "\n", + "The **DrugBAN** model is designed to predict whether a drug and protein interact. It brings together different parts of the data using specialised tools from deep learning.\n", + "\n", + "Here’s what DrugBAN is made of:\n", + "\n", + "**1. GCN (Graph Convolutional Network)** \n", + "This handles the structure of drug molecules. It treats each molecule as a graph — where atoms are nodes and bonds are edges — and learns useful patterns from it.\n", + "\n", + "**2. CNN (Convolutional Neural Network)** \n", + "This works with the protein sequences. Think of it like scanning the sequence for patterns, just like image recognition scans for edges or shapes.\n", + "\n", + "**3. BAN (Bilinear Attention Network)** \n", + "This connects the drug and protein features and helps the model learn **how parts of the drug interact with parts of the protein**.\n", + "\n", + "**4. MLP (Multilayer Perceptron)** \n", + "This is the final decision-maker. It takes all the features the model has learned and makes the final prediction: will this drug bind to this protein?\n", + "\n", + "Here’s how you can create the model in your code:" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from kale.embed.ban import DrugBAN\n", + "\n", + "# Create the model using settings from your config\n", + "model = DrugBAN(**cfg)\n", + "\n", + "# Print the model structure to see what's inside\n", + "print(model)" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DrugBAN(\n", + " (drug_extractor): MolecularGCN(\n", + " (init_transform): Linear(in_features=7, out_features=128, bias=False)\n", + " (gcn_layers): ModuleList(\n", + " (0-2): 3 x GCNConv(128, 128)\n", + " )\n", + " )\n", + " (protein_extractor): ProteinCNN(\n", + " (embedding): Embedding(26, 128, padding_idx=0)\n", + " (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,))\n", + " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv2): Conv1d(128, 128, kernel_size=(6,), stride=(1,))\n", + " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (conv3): Conv1d(128, 128, kernel_size=(9,), stride=(1,))\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (bcn): BANLayer(\n", + " (v_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (q_net): FCNet(\n", + " (main): Sequential(\n", + " (0): Dropout(p=0.2, inplace=False)\n", + " (1): Linear(in_features=128, out_features=768, bias=True)\n", + " (2): ReLU()\n", + " )\n", + " )\n", + " (p_net): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))\n", + " (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (mlp_classifier): MLPDecoder(\n", + " (fc1): Linear(in_features=256, out_features=512, bias=True)\n", + " (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc2): Linear(in_features=512, out_features=512, bias=True)\n", + " (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc3): Linear(in_features=512, out_features=128, bias=True)\n", + " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (fc4): Linear(in_features=128, out_features=2, bias=True)\n", + " )\n", + ")\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### 🏋️‍♀️ Setup Trainer\n", + "\n", + "In this section, we will set up the **training process** using **PyTorch Lightning**, a high-level library that simplifies training loops and experiment tracking in deep learning. Think of it as a way to organise all the messy training code into something tidy and reusable.\n", + "\n", + "We will use a training class called `DrugbanTrainer`, which is part of **PyKale**, and handles model training, domain adaptation, and evaluation.\n", + "\n", + "The values for the trainer's setup come from a configuration file written in YAML. If you are curious about what each setting means, check the YAML file. We've added comments there to explain each parameter.\n", + "\n", + "---\n", + "\n", + "Step 1: Initialise the Trainer\n", + "```python\n", + "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", + "\n", + "drugban_trainer = DrugbanTrainer(\n", + " model=DrugBAN(**cfg),\n", + " solver_lr=cfg[\"SOLVER\"][\"LEARNING_RATE\"], # learning rate for training the model\n", + " num_classes=cfg[\"DECODER\"][\"BINARY\"], # number of output classes (1 for binary classification)\n", + " batch_size=cfg[\"SOLVER\"][\"BATCH_SIZE\"], # how many samples the model sees at once\n", + "\n", + " # Domain adaptation settings (you can think of this as helping the model\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", + "\n", + "# Create an instance of the trainer with your model and configuration\n", + "drugban_trainer = DrugbanTrainer(\n", + " model=DrugBAN(**cfg),\n", + " solver_lr=cfg[\"SOLVER\"][\"LEARNING_RATE\"],\n", + " num_classes=cfg[\"DECODER\"][\"BINARY\"],\n", + " batch_size=cfg[\"SOLVER\"][\"BATCH_SIZE\"],\n", + " # Domain adaptation settings\n", + " is_da=cfg[\"DA\"][\"USE\"],\n", + " solver_da_lr=cfg[\"SOLVER\"][\"DA_LEARNING_RATE\"],\n", + " da_init_epoch=cfg[\"DA\"][\"INIT_EPOCH\"],\n", + " da_method=cfg[\"DA\"][\"METHOD\"],\n", + " original_random=cfg[\"DA\"][\"ORIGINAL_RANDOM\"],\n", + " use_da_entropy=cfg[\"DA\"][\"USE_ENTROPY\"],\n", + " da_random_layer=cfg[\"DA\"][\"RANDOM_LAYER\"],\n", + " # --- discriminator parameters ---\n", + " da_random_dim=cfg[\"DA\"][\"RANDOM_DIM\"],\n", + " decoder_in_dim=cfg[\"DECODER\"][\"IN_DIM\"],\n", + ")" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Step 2: Setup Checkpointing\n", + "\n", + "We want to save the best model during training. This helps you avoid rerunning training if you want to reuse the model later." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks import ModelCheckpoint\n", + "\n", + "# Save the model when it achieves the best AUROC score on the validation set\n", + "checkpoint_callback = ModelCheckpoint(\n", + " filename=\"{epoch}-{step}-{val_BinaryAUROC:.4f}\", # how to name saved files\n", + " monitor=\"val_BinaryAUROC\", # which metric to monitor\n", + " mode=\"max\", # we want to maximise this score\n", + ")" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Step 3: Launch the Trainer\n", + "\n", + "We now create the actual PyTorch Lightning trainer, which handles the training loop." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "trainer = pl.Trainer(\n", + " callbacks=[checkpoint_callback], # automatically save best model\n", + " devices=\"auto\", # use all available GPUs\n", + " accelerator=(\n", + " \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", + " ), # decide training hardware\n", + " max_epochs=cfg[\"SOLVER\"][\"MAX_EPOCH\"], # how many passes over the training set\n", + " deterministic=True, # makes results reproducible every time you run\n", + ")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "## Training and Testing Overview\n", + "\n", + "Before we can make any predictions, we need to **train** the model using known examples of drug–protein interactions. This step helps the model learn patterns, so that it can later predict whether a new drug and protein might interact.\n", + "\n", + "### What is training?\n", + "\n", + "Training is the process where the model adjusts itself to improve its guesses. Imagine giving it many examples of drug–protein pairs along with the correct answers (whether they interact or not). The model learns from these examples by updating its internal settings to reduce mistakes.\n", + "\n", + "### What is validation?\n", + "\n", + "Validation happens *during* training. We use a separate set of data (different from the training data) to check how well the model is doing as it learns. This helps us tune the model without accidentally letting it memorise all the training examples. It’s like checking your understanding by doing practice questions while revising.\n", + "\n", + "### What is testing?\n", + "\n", + "Testing is the final step, done *after* training is complete. We give the model new examples it has never seen before. This tells us how well it might perform in the real world when predicting new drug–protein interactions.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🏋️‍♀️ Training\n", + "\n", + "The following code starts the training process. It uses a function called `fit` which is part of PyTorch Lightning's training system. You do not need to change anything here unless you're experimenting." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "trainer.fit(\n", + " drugban_trainer,\n", + " train_dataloaders=training_generator,\n", + " val_dataloaders=valid_generator,\n", + ")" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs\n", + "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO:pytorch_lightning.callbacks.model_summary:\n", + " | Name | Type | Params | Mode \n", + "---------------------------------------------------------------------\n", + "0 | model | DrugBAN | 1.0 M | train\n", + "1 | domain_discriminator | DomainNetSmallImage | 133 K | train\n", + "2 | random_layer | RandomLayer | 66.0 K | train\n", + "3 | valid_metrics | MetricCollection | 0 | train\n", + "4 | test_metrics | MetricCollection | 0 | train\n", + "---------------------------------------------------------------------\n", + "1.2 M Trainable params\n", + "0 Non-trainable params\n", + "1.2 M Total params\n", + "4.847 Total estimated model params size (MB)\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test_BinaryAUROC 0.5569921731948853 │\n", + "│ test_BinaryAccuracy 0.5126791596412659 │\n", + "│ test_BinaryF1Score 0.21071428060531616 │\n", + "│ test_BinaryRecall 0.12967033684253693 │\n", + "│ test_BinarySpecificity 0.8982300758361816 │\n", + "│ test_accuracy_sklearn 0.532524824142456 │\n", + "│ test_auroc_sklearn 0.5569921135902405 │\n", + "│ test_f1_sklearn 0.6724442839622498 │\n", + "│ test_loss 0.7800763249397278 │\n", + "│ test_optim_threshold 0.15476277470588684 │\n", + "│ test_sensitivity 0.09955751895904541 │\n", + "│ test_specificity 0.9626373648643494 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[{'test_loss': 0.7800763249397278,\n", + " 'test_auroc_sklearn': 0.5569921135902405,\n", + " 'test_accuracy_sklearn': 0.532524824142456,\n", + " 'test_f1_sklearn': 0.6724442839622498,\n", + " 'test_sensitivity': 0.09955751895904541,\n", + " 'test_specificity': 0.9626373648643494,\n", + " 'test_optim_threshold': 0.15476277470588684,\n", + " 'test_BinaryAUROC': 0.5569921731948853,\n", + " 'test_BinaryF1Score': 0.21071428060531616,\n", + " 'test_BinaryRecall': 0.12967033684253693,\n", + " 'test_BinarySpecificity': 0.8982300758361816,\n", + " 'test_BinaryAccuracy': 0.5126791596412659}]" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### 📊 Understanding the Evaluation Metrics\n", + "\n", + "After testing the model, several performance metrics are displayed. \n", + "These help you understand how well the model is making predictions. \n", + "Below is a brief explanation of each metric:\n", + "\n", + "---\n", + "\n", + "#### **AUROC (Area Under the Receiver Operating Characteristic Curve)** \n", + "Measures the model’s ability to distinguish between positive (interacting) and negative (non-interacting) pairs. \n", + "A value close to 1.0 means excellent distinction. \n", + "A value around 0.5 means the model is guessing randomly.\n", + "\n", + "---\n", + "\n", + "#### **Accuracy** \n", + "Represents the percentage of correct predictions (both positive and negative) out of all predictions. \n", + "While easy to understand, accuracy can be misleading if the classes are imbalanced.\n", + "\n", + "---\n", + "\n", + "#### **F1 Score** \n", + "A balanced measure that combines **precision** and **recall**. \n", + "It is especially useful when you care equally about false positives and false negatives.\n", + "\n", + "---\n", + "\n", + "#### **Recall (also called Sensitivity or True Positive Rate)** \n", + "Shows the proportion of actual positives (interacting pairs) that the model correctly identified. \n", + "High recall means the model is good at finding interactions.\n", + "\n", + "---\n", + "\n", + "#### **Specificity (also called True Negative Rate)** \n", + "Shows the proportion of actual negatives (non-interacting pairs) that the model correctly identified. \n", + "High specificity means the model is good at ruling out non-interactions.\n", + "\n", + "---\n", + "\n", + "#### **Optimised Threshold** \n", + "During evaluation, the model can choose a threshold value for classification that maximises certain metrics like F1 score. \n", + "This threshold is what the model uses to decide between \"interaction\" and \"no interaction\".\n", + "\n", + "---\n", + "\n", + "#### **Loss** \n", + "This is a number the model tries to minimise during training. \n", + "Lower loss generally means better performance, but it should always be considered alongside the other metrics.\n", + "\n", + "---\n", + "\n", + "> These metrics provide different perspectives on the model’s behaviour. Together, they help you judge how well the model performs on your task.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 📊 Compare with Baselines\n", + "\n", + "To evaluate the robustness and generalisability of different models, the DrugBAN model was run for **100 epochs** across multiple random seeds and dataset splits. \n", + "\n", + "The figure below compares the performance of these models on the **BioSNAP** and **BindingDB** datasets.\n", + "\n", + "- The **left plot** shows results based on **AUROC** (Area Under the Receiver Operating Characteristic Curve).\n", + "- The **right plot** shows results based on **AUPRC** (Area Under the Precision–Recall Curve).\n", + "\n", + "---\n", + "\n", + "### 🧪 Experimental Setup\n", + "\n", + "- Each model was trained and evaluated multiple times with different random seeds to capture performance variability.\n", + "- Each box plot summarises results from these runs.\n", + "- **DrugBAN** and **DrugBANDA** were trained for 100 epochs per run.\n", + "- Performance on the **BioSNAP** dataset is shown in **blue**, and **BindingDB** results are shown in **orange**.\n", + "\n", + "---\n", + "\n", + "### 📈 How to Read the Box Plots\n", + "\n", + "- The **centre line** of each box represents the **median** performance.\n", + "- The **green triangle** shows the **mean** performance.\n", + "- The **lower and upper edges** of the box indicate the **first and third quartiles**.\n", + "- The **whiskers** show the full range (excluding outliers).\n", + "\n", + "---\n", + "\n", + "### 🔍 Key Insights\n", + "\n", + "- **DrugBANDA** consistently achieves top performance across both metrics and datasets.\n", + "- On the **BioSNAP** dataset (blue), performance varies more across models, highlighting its challenging nature.\n", + "- Simpler models such as **SVM** and **Random Forest (RF)** show limited ability to generalise.\n", + "- Deep learning models such as **GraphDTA** and **MolTrans** show competitive AUROC but less stability in AUPRC.\n", + "- **Domain adaptation** improves the model's ability to generalise from BindingDB to BioSNAP, as seen in DrugBANDA's superior scores.\n", + "\n", + "---" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "![Screenshot from 2025-06-23 16-13-23.png]()" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "In this tutorial, you learned how to use the **PyKale** library to build and evaluate a deep learning model for drug–target interaction (DTI) prediction.\n", + "\n", + "We walked through the pipeline in three key steps:\n", + "\n", + "### 1. Data Overview \n", + "You explored how to load and prepare drug and protein data using PyKale’s data handling tools.\n", + "\n", + "### 2. Model and Trainer Overview \n", + "You saw how to configure and use the **DrugBAN** model with PyKale. You also learned how PyKale’s trainer simplifies the training process, including logging and model saving.\n", + "\n", + "### 3. Training and Testing Overview \n", + "You trained the model and evaluated its performance using commonly used metrics such as **AUROC**, **F1 score**, **recall**, and **specificity**.\n", + "\n", + "---\n", + "\n", + "This notebook is designed as an accessible entry point for researchers who are new to Python or machine learning, and who want to explore how graph-based deep learning can be applied to biomedical problems.\n", + "\n", + "> 🧠 Tip: Try experimenting with the dataset, changing model settings, or applying PyKale to a new task. That’s the best way to learn!\n", + "\n", + "For more information, check out the [original DrugBAN codebase](https://github.com/peizhenbai/DrugBAN) and the full paper in *Nature Machine Intelligence*.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### Explore More: 3 Tasks to Try (~1 Hour Total)\n", + "\n", + "These tasks are designed to help you go beyond the tutorial and gain deeper, hands-on experience with model development, interpretation, and dataset handling. You don’t need prior experience in machine learning — just curiosity!" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🔁 Task 1: Use the BindingDB Dataset (20 minutes)\n", + "\n", + "Swap the current dataset with **BindingDB**, a real-world dataset containing experimentally measured drug–target interactions.\n", + "\n", + "Steps:\n", + "1. Download or prepare the BindingDB dataset (if needed).\n", + "2. Update the `data` fields in the YAML config file.\n", + "3. Reload the dataset and re-run training and testing.\n", + "\n", + "**What to explore:**\n", + "- Does model performance change?\n", + "- Is the dataset more imbalanced?\n", + "- How does training time compare?\n", + "\n", + "> Tip: See if the model struggles more or less with the new dataset. It can reveal how generalisable DrugBAN is.\n", + "\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🧪 Task 2: Inspect Misclassified Samples (20 minutes)\n", + "\n", + "Dive into the test results and check where the model made incorrect predictions. \n", + "This helps you understand where the model struggles and whether those mistakes make sense.\n", + "\n", + "---\n", + "\n", + "#### ✅ Steps\n", + "\n", + "1. After testing, collect predicted probabilities and the true labels.\n", + "2. Print out the predictions that the model got wrong.\n", + "3. (Optional) Visualise the drug or protein graph for those samples.\n", + "\n", + "---\n", + "\n", + "#### 🔍 What to Explore\n", + "\n", + "- **Are the wrong predictions close to 0.5?** \n", + " If so, the model was unsure. This can help you identify borderline cases.\n", + "\n", + "- **Are there more false positives or false negatives?** \n", + " This tells you whether the model is more likely to over-predict interactions or miss them.\n", + "\n", + "- **Do certain types of drug–protein pairs seem harder to classify?** \n", + " For example, some proteins or drugs may always appear in misclassified samples. This could point to noisy or hard-to-learn examples.\n", + "\n", + "---\n", + "\n", + "#### 🧬 (Optional) Visualise the Graph of a Misclassified Sample\n", + "\n", + "You can plot the structure of a drug or protein graph that the model got wrong. \n", + "This helps you interpret what the model was seeing when it made the incorrect prediction.\n" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "### 🧠 Task 3: Change the Protein Encoder (20 minutes)\n", + "\n", + "Explore how the model behaves when you change the way proteins are represented. \n", + "This task helps you understand whether the model relies more on protein structure or other features when making predictions.\n", + "\n", + "---\n", + "\n", + "💡 Ideas to Try\n", + "Use one hot embedding instead of integer encoding.\n", + "\n", + "Replace the protein graph with a flat vector input (e.g. sequence length, molecular weight, or hydrophobicity).\n", + "\n", + "\n", + "🔍 What to Observe\n", + "- Does model performance improve, stay the same, or get worse?\n", + "\n", + "- Which metric changes the most — AUROC, F1 score, or recall?\n", + "\n", + "- Does training take more time or less time with the new encoder?\n", + "\n" + ], + "cell_type": "markdown" + } + ] +}