diff --git a/tutorials/drug-target-interaction/configs.py b/tutorials/drug-target-interaction/configs.py
index dd4ba45..a64a522 100644
--- a/tutorials/drug-target-interaction/configs.py
+++ b/tutorials/drug-target-interaction/configs.py
@@ -94,7 +94,7 @@
 # ---------------------------------------------------------------------------- #
 _C.COMET = CfgNode()
 _C.COMET.USE = (
-    True  # Enable Comet logging (set True if Comet is installed and configured)
+    False  # Enable Comet logging (set True if Comet is installed and configured)
 )
 _C.COMET.PROJECT_NAME = "drugban-23-May"  # Comet project name (if applicable)
 _C.COMET.EXPERIMENT_NAME = None  # Optional experiment name (e.g., 'drugban-run-1')
diff --git a/tutorials/drug-target-interaction/notebook-cross-domain.ipynb b/tutorials/drug-target-interaction/notebook-cross-domain.ipynb
deleted file mode 100644
index 69b4754..0000000
--- a/tutorials/drug-target-interaction/notebook-cross-domain.ipynb
+++ /dev/null
@@ -1,1161 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "# Drug–Target Interaction Prediction\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Introduction\n",
-    "\n",
-    "In this tutorial, we demonstrate the standard pipeline in `PyKale` and show how to integrate multimodal data from **drugs** and **proteins** to perform **drug-target interaction (DTI) prediction**.\n",
-    "\n",
-    "This tutorial builds on the work of [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1), which introduced the **DrugBAN** framework. DrugBAN combines two key ideas:\n",
-    "\n",
-    "- A **bilinear attention network (BAN)**. This is a model that learns the features of both the drug and the protein, and how these features interact locally.\n",
-    "\n",
-    "- **Adversarial domain adaptation**. This is a method that helps the model generalise to data that is different from what it was trained on (also known as out-of-distribution data), improving its performance on unseen drug–target pairs.\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Problem Formulation\n",
-    "\n",
-    "This tutorial focuses on the drug–target interaction (DTI) prediction problem, which is framed as a binary classification task. The inputs are drug SMILES strings and protein amino acid sequences, and the output is a binary label (1 or 0) indicating whether an interaction occurs.\n",
-    "\n",
-    "We will work with two datasets: **BioSNAP** and **BindingDB**. The main tutorial will use the BioSNAP dataset, while BindingDB is provided as an additional dataset for you to explore and reproduce results in your own time after completing the tutorial."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Objective\n",
-    "- Understand the standard pipeline of the `PyKale` library.\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Environment Preparation\n",
-    "\n",
-    "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.\n",
-    "\n",
-    "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. The additional helper script is:\n",
-    "- [`config.py`](https://github.com/pykale/mmai-tutorials/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Package Installation\n",
-    "\n",
-    "The main packages required for this tutorial are **PyKale**, **PyTorch Geometric**, and **RDKit**.\n",
-    "\n",
-    "- **PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, designed for applications in biomedical and scientific domains.\n",
-    "- **PyG** (PyTorch Geometric) is a library built on top of PyTorch for building and training Graph Neural Networks (GNNs) on structured data.\n",
-    "- **RDKit** is a cheminformatics toolkit for handling and processing molecular structures, particularly useful for working with SMILES strings and molecular graphs.\n",
-    "\n",
-    "Other required packages can be found in [`mmai-tutorials/requirements.txt`](https://github.com/pykale/mmai-tutorials/blob/main/requirements.txt).\n",
-    "\n",
-    "--------------------------------------------------------------------------------\n",
-    "\n",
-    "#### **WARNINGS**\n",
-    "Please do not re-run this cell after the installation has completed. Running this installation multiple times will trigger issues related to `PyG`. If you want to re-run this installation, please click `Runtime` on the top menu and choose `Disconnect and delete runtime` before installing.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Cloning into 'embc-mmai25'...\n",
-      "remote: Enumerating objects: 1444, done.\u001b[K\n",
-      "remote: Counting objects: 100% (193/193), done.\u001b[K\n",
-      "remote: Compressing objects: 100% (108/108), done.\u001b[K\n",
-      "remote: Total 1444 (delta 111), reused 136 (delta 85), pack-reused 1251 (from 1)\u001b[K\n",
-      "Receiving objects: 100% (1444/1444), 15.80 MiB | 16.94 MiB/s, done.\n",
-      "Resolving deltas: 100% (752/752), done.\n",
-      "/content/embc-mmai25/tutorials/drug-target-interaction\n",
-      "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
-      "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
-      "  Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.0/45.0 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m88.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m69.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m80.3/80.3 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.2/83.2 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m104.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.0/434.0 kB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m105.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m86.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.4/40.4 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.4/127.4 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m78.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m63.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - 
"\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mRequired packages installed successfully ✅\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", - "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mPyKale installed successfully ✅\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", - "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - " Building wheel for torch-geometric (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n", - "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", - "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mPyG installed successfully ✅\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: rdkit-pypi in /usr/local/lib/python3.11/dist-packages (2022.9.5)\n", - "Collecting numpy (from rdkit-pypi)\n", - " Using cached numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (62 kB)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from rdkit-pypi) (11.2.1)\n", - "Using cached numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl (16.9 MB)\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0mInstalling collected packages: numpy\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n", - "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", - "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed numpy\n", - "PyG installed successfully ✅\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0mCollecting numpy==2.0.0\n", - " Using cached numpy-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)\n", - "Using cached numpy-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.3 MB)\n", - "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", - "\u001b[0mInstalling collected packages: numpy\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n",
-      "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n",
-      "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n",
-      "\u001b[0mSuccessfully installed numpy-2.0.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "import os\n",
-    "import warnings\n",
-    "\n",
-    "warnings.filterwarnings(\"ignore\")\n",
-    "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n",
-    "\n",
-    "!git clone https://github.com/pykale/mmai-tutorials\n",
-    "%cd /content/mmai-tutorials/tutorials/drug-target-interaction\n",
-    "\n",
-    "!pip install --quiet -r /content/mmai-tutorials/requirements.txt \\\n",
-    " && echo \"Required packages installed successfully ✅\" \\\n",
-    " || echo \"Failed to install required packages ❌\"\n",
-    "\n",
-    "!pip install --quiet git+https://github.com/pykale/pykale@main\\\n",
-    " && echo \"PyKale installed successfully ✅\" \\\n",
-    " || echo \"Failed to install PyKale ❌\"\n",
-    "\n",
-    "import torch\n",
-    "os.environ['TORCH'] = torch.__version__\n",
-    "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
-    "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
-    "\n",
-    "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \\\n",
-    " && echo \"PyG installed successfully ✅\" \\\n",
-    " || echo \"Failed to install PyG ❌\"\n",
-    "\n",
-    "!pip install rdkit-pypi \\\n",
-    " && echo \"RDKit installed successfully ✅\" \\\n",
-    " || echo \"Failed to install RDKit ❌\"\n",
-    "\n",
-    "# !pip install \"numpy<2.0\" \"transformers==4.30.2\" --force-reinstall --quiet\n",
-    "!pip install --upgrade --force-reinstall numpy==2.0.0\n",
-    "os.kill(os.getpid(), 9)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "NumPy version: 2.0.0\n"
-     ]
-    }
-   ],
-   "source": [
-    "# (Optional: NumPy version check) ✅ Now run this AFTER restarting the runtime\n",
-    "import numpy as np\n",
-    "\n",
-    "print(\"NumPy version:\", np.__version__)  # should be 2.0.0\n"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Mount Data (Optional)\n",
-    "\n",
-    "If you are using Google Colab, please use the following code to load the necessary datasets."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Mounted at /content/drive\n",
-      "Contents of the folder:\n",
-      "bindingdb\n",
-      "biosnap\n"
-     ]
-    }
-   ],
-   "source": [
-    "from google.colab import drive\n",
-    "\n",
-    "drive.mount(\"/content/drive\")\n",
-    "\n",
-    "shared_drives_path = (\n",
-    "    \"/content/drive/Shared drives/EMBC-MMAI 25 Workshop/data/drug-target-interaction\"\n",
-    ")\n",
-    "\n",
-    "import os\n",
-    "import shutil\n",
-    "\n",
-    "print(\"Contents of the folder:\")\n",
-    "for item in os.listdir(shared_drives_path):\n",
-    "    print(item)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Configuration\n",
-    "\n",
-    "To minimize the footprint of the notebook when specifying configurations, we provide a [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. 
These can be customized by supplying a `.yaml` configuration file, such as [`experiments/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml) as an example.\n", - "\n", - "In this tutorial, we list the hyperparameters we would like users to play with outside the `.yaml` file:\n", - "- `cfg.SOLVER.MAX_EPOCH`: Number of epochs in training stage.\n", - "- `cfg.DATA.DATASET`: The dataset used in the study. This can be `bindingdb` or `biosnap`.\n", - "\n", - "As a quick exercise, please take a moment to review and understand the parameters in [`config.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/content/embc-mmai25/tutorials/drug-target-interaction\n", - "BCN:\n", - " HEADS: 2\n", - "COMET:\n", - " API_KEY: \n", - " EXPERIMENT_NAME: DA_cross_domain\n", - " PROJECT_NAME: drugban-23-May\n", - " TAG: DrugBAN_CDAN\n", - " USE: True\n", - "DA:\n", - " INIT_EPOCH: 10\n", - " LAMB_DA: 1\n", - " METHOD: CDAN\n", - " ORIGINAL_RANDOM: True\n", - " RANDOM_DIM: 256\n", - " RANDOM_LAYER: True\n", - " TASK: True\n", - " USE: True\n", - " USE_ENTROPY: False\n", - "DATA:\n", - " DATASET: biosnap\n", - " SPLIT: cluster\n", - "DECODER:\n", - " BINARY: 2\n", - " HIDDEN_DIM: 512\n", - " IN_DIM: 256\n", - " NAME: MLP\n", - " OUT_DIM: 128\n", - "DRUG:\n", - " HIDDEN_LAYERS: [128, 128, 128]\n", - " MAX_NODES: 290\n", - " NODE_IN_EMBEDDING: 128\n", - " NODE_IN_FEATS: 7\n", - " PADDING: True\n", - "PROTEIN:\n", - " EMBEDDING_DIM: 128\n", - " KERNEL_SIZE: [3, 6, 9]\n", - " NUM_FILTERS: [128, 128, 128]\n", - " PADDING: True\n", - "RESULT:\n", - " SAVE_MODEL: True\n", - "SOLVER:\n", - " BATCH_SIZE: 32\n", - " DA_LEARNING_RATE: 5e-05\n", - " LEARNING_RATE: 0.0001\n", - " MAX_EPOCH: 2\n", - " NUM_WORKERS: 0\n", - " SEED: 20\n" - ] - } - ], - "source": [ - "%cd /content/mmai-tutorials/tutorials/drug-target-interaction\n", - "\n", - "from configs import get_cfg_defaults\n", - "\n", - "# Load the default settings from config.py\n", - "cfg = get_cfg_defaults()\n", - "\n", - "# Update (or override) some of those settings using a custom YAML file\n", - "cfg.merge_from_file(\"experiments/DA_cross_domain.yaml\")\n", - "\n", - "# ------ Hyperparameters to play with -----\n", - "# User can reduce the training epochs to decrease training time if necessary\n", - "cfg.SOLVER.MAX_EPOCH = 2\n", - "\n", - "# User can change to a different dataset\n", - "cfg.DATA.DATASET = \"biosnap\"\n", - "\n", - "# -----------------------------------------\n", - "print(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Data Loading and Pre-processing\n", - "\n", - "In this tutorial, we use a benchmark dataset called **Biosnap**, which contains information about how well different drugs interact with specific proteins. This dataset has been preprocessed and provided by the authors of the **DrugBAN** paper. You can also find it in their [GitHub repository](https://github.com/peizhenbai/DrugBAN/tree/main)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The Biosnap dataset is stored in a folder called `biosnap`, which contains a few subfolders. 
Each subfolder corresponds to a different experimental setting for training and testing machine learning models.\n",
-    "\n",
-    "Here is a simplified view of the folder structure:\n",
-    "\n",
-    "```sh\n",
-    "├───biosnap\n",
-    "│   ├───cluster\n",
-    "│   │   ├───source_train.csv\n",
-    "│   │   ├───target_train.csv\n",
-    "│   │   ├───target_test.csv\n",
-    "│   ├───random\n",
-    "│   │   ├───test.csv\n",
-    "│   │   ├───train.csv\n",
-    "│   │   ├───val.csv\n",
-    "│   ├───full.csv\n",
-    "```\n",
-    "\n",
-    "Each file listed here is in **CSV format**, which you can open using spreadsheet software (like Excel) or load into Python using tools like `pandas`. These files contain rows of data, with each row representing one drug–protein pair.\n",
-    "\n",
-    "Here’s what each CSV file looks like in a table format:\n",
-    "\n",
-    "| SMILES | Protein Sequence | Y |\n",
-    "|--------------------|--------------------------|---|\n",
-    "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n",
-    "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n",
-    "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n",
-    "\n",
-    "Each row of the dataset contains three key pieces of information:\n",
-    "\n",
-    "**Drugs**:\n",
-    "Drugs are often written as SMILES strings, which are like chemical formulas in text format (for example, `\"CC(=O)OC1=CC=CC=C1C(=O)O\"` is aspirin).\n",
-    "\n",
-    "**Protein Sequence**:\n",
-    "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n",
-    "\n",
-    "**Y (Labels)**:\n",
-    "Each drug–protein pair is given a label:\n",
-    "- `1` if they interact\n",
-    "- `0` if they do not\n",
-    "\n",
-    "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**) — whether or not the drug and protein interact."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "To generate the drug graphs, we use the `kale.loaddata.molecular_datasets.smiles_to_graph` function, which converts SMILES strings into graph structures. For each molecule, atom-level features such as atomic number, degree, valence, and aromaticity are encoded as node features. Bond information is represented through edge indices and edge attributes. The function automatically adds self-loops to all nodes to ensure that each node has at least one connection. For molecules with fewer atoms than the maximum allowed, the function applies node padding by adding virtual nodes with zero features.\n",
-    "\n",
-    "We use the `kale.prepdata.chem_transform.integer_label_protein` function to convert protein sequences into fixed-length integer arrays. Each amino acid is mapped to a unique integer based on a predefined dictionary (`CHARPROTSET`). Sequences longer than the maximum length (default: 1200) are truncated, while shorter sequences are zero-padded. Unknown characters are treated as padding, ensuring all protein inputs have a consistent numerical format.\n",
-    "\n",
-    "We then use the `kale.loaddata.molecular_datasets.DTIDataset` class to integrate these steps by organising the drug–protein–label triplets into a dataset format compatible with PyTorch. During training and evaluation, the `DataLoader` calls `graph_collate_func` to batch the molecular graphs, protein sequences, and labels into a single batch. The output is a batched drug graph, a stacked protein sequence tensor, and a label tensor, ready for input into the DrugBAN model. A minimal sketch of these utilities is shown below."
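To make the featurization concrete, here is a minimal sketch of these utilities in action. The `integer_label_protein` call and the `DTIDataset` constructor mirror usage shown in this notebook, but the toy column names (`SMILES`, `Protein`, `Y`) and the default padding behaviour are assumptions to verify against the PyKale source.

```python
import pandas as pd
from kale.loaddata.molecular_datasets import DTIDataset
from kale.prepdata.chem_transform import integer_label_protein

# Protein featurization: each amino-acid letter is mapped to an integer via
# CHARPROTSET, then the array is truncated/zero-padded to the default max length.
protein_vec = integer_label_protein("MGYTSLLT")

# A toy one-row table mirroring the CSV layout above; the exact column names
# DTIDataset expects are an assumption here.
df = pd.DataFrame(
    {
        "SMILES": ["CC(=O)OC1=CC=CC=C1C(=O)O"],  # aspirin
        "Protein": ["MGYTSLLT"],
        "Y": [1],
    }
)
toy_dataset = DTIDataset(df.index.values, df)

# Each item is the (drug graph, protein integer array, label) triple shown later.
drug_graph, protein_feats, label = toy_dataset[0]
```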
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: /usr/local/lib/python3.11/dist-packages/torch_scatter/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb\n", - " import torch_geometric.typing\n", - "/usr/local/lib/python3.11/dist-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /usr/local/lib/python3.11/dist-packages/torch_sparse/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb\n", - " import torch_geometric.typing\n" - ] - } - ], - "source": [ - "from kale.loaddata.molecular_datasets import DTIDataset\n", - "import pandas as pd\n", - "\n", - "# Path to the dataset folder\n", - "dataFolder = os.path.join(\n", - " f\"/content/drive/Shared drives/EMBC-MMAI 25 Workshop/data/drug-target-interaction/{cfg.DATA.DATASET}\",\n", - " str(cfg.DATA.SPLIT),\n", - ")\n", - "\n", - "# Path to the dataset folder\n", - "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", - "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", - "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))\n", - "\n", - "# Create preprocessed datasets\n", - "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", - "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", - "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Dataset Inspection\n", - "\n", - "Once we’ve loaded the dataset, it's useful to take a quick look at what it contains. This helps us understand the data format and what kind of information we’ll be working with in the rest of the tutorial.\n", - "\n", - "In this project, our dataset has been split into three parts:\n", - "\n", - "**Train samples from the source domain** \n", - "These are drug–protein pairs the model will learn from. The \"source domain\" typically refers to a distribution of data that the model is familiar with.\n", - "\n", - "**Train samples from the target domain** \n", - "These are additional training samples, but from a different distribution (a \"target domain\") the model should generalise to. 
This helps simulate real-world scenarios where new data may come from different conditions.\n", - "\n", - "**Test samples from the target domain** \n", - "These are drug–protein pairs that the model has never seen before, and they’re used to evaluate how well the model generalises to new, unseen cases.\n", - "\n", - "Let’s print out the number of samples in each set, and take a peek at one example from the training data.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train samples from source domain: 9766, Train samples from target domain: 3628, Test samples from target domain: 907\n", - "\n", - "An example sample from source domain:\n", - "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), np.float64(0.0))\n" - ] - } - ], - "source": [ - "print(\n", - " f\"Train samples from source domain: {len(train_dataset)}, Train samples from target domain: {len(train_target_dataset)}, Test samples from target domain: {len(test_target_dataset)}\"\n", - ")\n", - "\n", - "print(\"\\nAn example sample from source domain:\")\n", - "print(train_dataset[0])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let’s break down what this example from the **source domain** means:\n", - "\n", - "```\n", - "An example sample from source domain:\n", - "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), 0.0)\n", - "```\n", - "\n", - "This sample is a tuple with **three parts**:\n", - "\n", - "\n", - "\n", - "1. **Drug Graph (Data object)**\n", - "\n", - "This part is a graph-based representation of the **drug**, built using the PyTorch Geometric `Data` object:\n", - "\n", - "- `x=[290, 7]` \n", - " This is a table (matrix) with **290 nodes** (atoms) and **7 features** per atom. \n", - " Each row represents an atom, and each column describes one feature of the atom.\n", - "\n", - "- `edge_index=[2, 58]` \n", - " This shows how the atoms are connected (like chemical bonds). \n", - " There are **58 edges**, and the matrix has 2 rows — the first row lists source atoms, the second lists the target atoms.\n", - "\n", - "- `edge_attr=[58, 1]` \n", - " Each edge (bond) has **1 feature**, such as bond type. \n", - " So there are 58 rows (one for each edge), and 1 column.\n", - "\n", - "- `num_nodes=290` \n", - " This confirms that the graph has **290 atoms (nodes)**.\n", - "\n", - "\n", - "\n", - "2. **Protein Features (array)**\n", - "\n", - "- This is a **1D array** (or vector) representing the **protein**. \n", - "- It contains numerical features extracted from the protein sequence or structure. \n", - "- Example values: `[11., 1., 18., ..., 0., 0., 0.]` \n", - " These could represent biochemical or structural properties, with padding at the end (zeros) to ensure a consistent input size.\n", - "\n", - "\n", - "3. **Label (float)**\n", - "\n", - "- `0.0` \n", - " This is the **label**, which tells us the ground truth: \n", - " The drug and protein **do not interact** in this sample. 
\n", - "\n", - " If the label were `1.0`, it would mean they **do interact**.\n", - "\n", - "\n", - "\n", - "This format allows the model to learn from both structured graph data (the drug) and feature-based data (the protein), and predict whether they interact based on the label.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Batching\n", - "\n", - "When training machine learning models, especially on large datasets like molecular graphs, it’s inefficient and memory-intensive to load everything at once. Instead, we split the data into **mini-batches** and feed them into the model one at a time. This process is called **batching**, i.e, loading data in manageable pieces.\n", - "\n", - "We use PyTorch’s `DataLoader` to efficiently batch and load samples during training and evaluation. For training, we create two separate data loaders: one for the source domain and one for the target domain. To enable domain adaptation, we combine them using `kale.loaddata.sampler.MultiDataLoader`, which yields one batch from each domain at every training step and ensures a consistent number of batches per epoch by automatically restarting smaller datasets when needed.\n", - "\n", - "However, because molecular data involves graphs of varying sizes and structures, we cannot stack them like regular tensors or images. To handle this, we use a custom collate function called `kale.loaddata.molecular_datasets.graph_collate_func`, which tells the `DataLoader` how to correctly combine multiple graphs into a single batch that the model can process." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.utils.data import DataLoader\n", - "from kale.loaddata.molecular_datasets import graph_collate_func\n", - "from kale.loaddata.sampler import MultiDataLoader\n", - "\n", - "# Define the parameters used by the data loaders\n", - "params = {\n", - " \"batch_size\": cfg.SOLVER.BATCH_SIZE, # Number of samples per batch\n", - " \"shuffle\": True, # Shuffle data during training\n", - " \"num_workers\": cfg.SOLVER.NUM_WORKERS, # Number of workers loading the data\n", - " \"drop_last\": True, # Drop the last batch if it's smaller than batch_size\n", - " \"collate_fn\": graph_collate_func, # Custom function to batch graphs correctly\n", - "}\n", - "\n", - "\n", - "# Create data loaders for source and target training datasets\n", - "source_generator = DataLoader(train_dataset, **params)\n", - "target_generator = DataLoader(train_target_dataset, **params)\n", - "\n", - "# Get the number of batches in the longer dataset to align both\n", - "n_batches = max(len(source_generator), len(target_generator))\n", - "\n", - "# Combine the source and target data loaders using MultiDataLoader\n", - "training_generator = MultiDataLoader(\n", - " dataloaders=[source_generator, target_generator], n_batches=n_batches\n", - ") # used to be named as multi_generator\n", - "\n", - "\n", - "# Now we set up data loaders for validation and testing. 
Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly.\n", - "\n", - "# Update parameters for validation/testing (no shuffling, keep all data)\n", - "params.update({\"shuffle\": False, \"drop_last\": False})\n", - "\n", - "# Create validation and test data loaders\n", - "valid_generator = DataLoader(test_target_dataset, **params)\n", - "test_generator = DataLoader(test_target_dataset, **params)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Definition\n", - "\n", - "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n", - "\n", - "We define the DrugBAN class in `kale.embed.ban`, which wraps all key modules of the DrugBAN pipeline based on the configuration.\n", - "This wrapper handles:\n", - "\n", - "- Initialising the GCN-based drug feature extractor (MolecularGCN).\n", - "\n", - "- Building the CNN-based protein sequence encoder (ProteinCNN).\n", - "\n", - "- Integrating the BAN layer for drug-protein feature fusion (BANLayer).\n", - "\n", - "- Creating the MLP classifier for final prediction (MLPDecoder)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from kale.embed.ban import DrugBAN\n", - "\n", - "# Create the model using settings from your config\n", - "model = DrugBAN(**cfg)\n", - "\n", - "# Print the model structure to see what's inside\n", - "print(model)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model Training\n", - "We use the training class `kale.pipeline.drugban_trainer`, which handles model training, domain adaptation, and evaluation for DrugBAN." 
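Before building the trainer, it can be useful to sanity-check one step of the combined loader. The sketch below assumes, per the Batching description above, that `MultiDataLoader` yields one batch per domain at each step and that `graph_collate_func` packs each batch as a (batched drug graph, protein tensor, labels) triple; treat the unpacking as an assumption to verify.

```python
# Sketch: inspect a single step from the domain-paired training loader.
batch_source, batch_target = next(iter(training_generator))

# graph_collate_func batches each triple into (drug graphs, proteins, labels).
drug_graphs, protein_feats, labels = batch_source
print(type(drug_graphs), protein_feats.shape, labels.shape)
```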
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n", - " warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n" - ] - } - ], - "source": [ - "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", - "\n", - "# Create an instance of the trainer with your model and configuration\n", - "drugban_trainer = DrugbanTrainer(\n", - " model=DrugBAN(**cfg),\n", - " solver_lr=cfg[\"SOLVER\"][\"LEARNING_RATE\"],\n", - " num_classes=cfg[\"DECODER\"][\"BINARY\"],\n", - " batch_size=cfg[\"SOLVER\"][\"BATCH_SIZE\"],\n", - " # Domain adaptation settings\n", - " is_da=cfg[\"DA\"][\"USE\"],\n", - " solver_da_lr=cfg[\"SOLVER\"][\"DA_LEARNING_RATE\"],\n", - " da_init_epoch=cfg[\"DA\"][\"INIT_EPOCH\"],\n", - " da_method=cfg[\"DA\"][\"METHOD\"],\n", - " original_random=cfg[\"DA\"][\"ORIGINAL_RANDOM\"],\n", - " use_da_entropy=cfg[\"DA\"][\"USE_ENTROPY\"],\n", - " da_random_layer=cfg[\"DA\"][\"RANDOM_LAYER\"],\n", - " # --- discriminator parameters ---\n", - " da_random_dim=cfg[\"DA\"][\"RANDOM_DIM\"],\n", - " decoder_in_dim=cfg[\"DECODER\"][\"IN_DIM\"],\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import pytorch_lightning as pl\n", - "from pytorch_lightning.callbacks import ModelCheckpoint\n", - "\n", - "# Save the model when it achieves the best AUROC score on the validation set\n", - "checkpoint_callback = ModelCheckpoint(\n", - " filename=\"{epoch}-{step}-{val_BinaryAUROC:.4f}\", # how to name saved files\n", - " monitor=\"val_BinaryAUROC\", # which metric to monitor\n", - " mode=\"max\", # we want to maximise this score\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We now create the `Trainer`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True\n", - "INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "trainer = pl.Trainer(\n", - " callbacks=[checkpoint_callback], # automatically save best model\n", - " devices=\"auto\", # use all available GPUs\n", - " accelerator=(\n", - " \"gpu\" if torch.cuda.is_available() else \"cpu\"\n", - " ), # decide training hardware\n", - " max_epochs=cfg[\"SOLVER\"][\"MAX_EPOCH\"], # how many passes over the training set\n", - " deterministic=True, # makes results reproducible every time you run\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Train the DrugBAN Model\n", - "After setting up the model and data loaders, we now start training the full DrugBAN model using the PyTorch Lightning Trainer via calling `trainer.fit()`.\n", - "\n", - "#### What Happens Here?\n", - "- The model receives batches of drug-protein pairs from the training data loader.\n", - "\n", - "- During each step, the GCN, CNN, BAN layer, and MLP classifier are updated to improve interaction prediction.\n", - "\n", - "- Validation is automatically run at the end of each epoch to track performance and save the best model based on AUROC." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:pytorch_lightning.loggers.tensorboard:Missing logger folder: /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs\n", - "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "INFO:pytorch_lightning.callbacks.model_summary:\n", - " | Name | Type | Params | Mode \n", - "---------------------------------------------------------------------\n", - "0 | model | DrugBAN | 1.0 M | train\n", - "1 | domain_discriminator | DomainNetSmallImage | 133 K | train\n", - "2 | random_layer | RandomLayer | 66.0 K | train\n", - "3 | valid_metrics | MetricCollection | 0 | train\n", - "4 | test_metrics | MetricCollection | 0 | train\n", - "---------------------------------------------------------------------\n", - "1.2 M Trainable params\n", - "0 Non-trainable params\n", - "1.2 M Total params\n", - "4.847 Total estimated model params size (MB)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "de1e5fc3654547e4b7078fd9511d4764", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 9280. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", - "/usr/local/lib/python3.11/dist-packages/torchmetrics/utilities/prints.py:43: TorchMetricsUserWarning: You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently not supported. 
The tensor will be copied to the CPU memory to compute it and then copied back to GPU. Expect some slowdowns.\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "8ef34bbc1d944d8398a243f6c45ab415", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: | | 0/? [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[15:53:49] Unusual charge on atom 0 number of radical electrons set to zero\n", - "[15:53:55] Unusual charge on atom 0 number of radical electrons set to zero\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fc8cb437cdc34a5eb335d42c56f2975c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: | | 0/? [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3190. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", - "[15:54:38] Unusual charge on atom 0 number of radical electrons set to zero\n", - "[15:55:25] Unusual charge on atom 0 number of radical electrons set to zero\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7d6c70a42d194a1f956778d4543fe9ff", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation: | | 0/? [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.\n" - ] - } - ], - "source": [ - "trainer.fit(\n", - " drugban_trainer,\n", - " train_dataloaders=training_generator,\n", - " val_dataloaders=valid_generator,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Evaluate\n", - "\n", - "Once training is complete, we evaluate the model on the test set using `trainer.test()`.\n", - "\n", - "#### What is included in this step?\n", - "- The best model checkpoint (based on validation AUROC) is automatically loaded.\n", - "\n", - "- The model runs on the test data to generate predictions.\n", - "\n", - "- Final classification metrics, including AUROC, F1 score, accuracy, sensitivity, and specificity, are calculated and logged." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5640.ckpt\n", - "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5640.ckpt\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4fb3eaa50982466696a329fb30165511", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Testing: | | 0/? 
[00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.11/dist-packages/kale/pipeline/drugban_trainer.py:377: RuntimeWarning: invalid value encountered in divide\n", - " precision = tpr / (tpr + fpr)\n", - "/usr/local/lib/python3.11/dist-packages/torchmetrics/utilities/prints.py:43: TorchMetricsUserWarning: You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU. Expect some slowdowns.\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "text/html": [ - "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃ Test metric ┃ DataLoader 0 ┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ test_BinaryAUROC │ 0.5640328526496887 │\n", - "│ test_BinaryAccuracy │ 0.5082690119743347 │\n", - "│ test_BinaryF1Score │ 0.15209124982357025 │\n", - "│ test_BinaryRecall │ 0.08791209012269974 │\n", - "│ test_BinarySpecificity │ 0.9314159154891968 │\n", - "│ test_accuracy_sklearn │ 0.5226019620895386 │\n", - "│ test_auroc_sklearn │ 0.5640328526496887 │\n", - "│ test_f1_sklearn │ 0.6693024635314941 │\n", - "│ test_loss │ 0.8550940155982971 │\n", - "│ test_optim_threshold │ 0.08681917190551758 │\n", - "│ test_sensitivity │ 0.07300885021686554 │\n", - "│ test_specificity │ 0.9692307710647583 │\n", - "└───────────────────────────┴───────────────────────────┘\n", - "\n" - ], - "text/plain": [ - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5640328526496887 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5082690119743347 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.15209124982357025 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.08791209012269974 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9314159154891968 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5226019620895386 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5640328526496887 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6693024635314941 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8550940155982971 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.08681917190551758 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.07300885021686554 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9692307710647583 \u001b[0m\u001b[35m \u001b[0m│\n", - "└───────────────────────────┴───────────────────────────┘\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "[{'test_loss': 0.8550940155982971,\n", - " 'test_auroc_sklearn': 0.5640328526496887,\n", - " 'test_accuracy_sklearn': 0.5226019620895386,\n", - " 'test_f1_sklearn': 0.6693024635314941,\n", - " 'test_sensitivity': 
0.07300885021686554,\n",
-       " 'test_specificity': 0.9692307710647583,\n",
-       " 'test_optim_threshold': 0.08681917190551758,\n",
-       " 'test_BinaryAUROC': 0.5640328526496887,\n",
-       " 'test_BinaryF1Score': 0.15209124982357025,\n",
-       " 'test_BinaryRecall': 0.08791209012269974,\n",
-       " 'test_BinarySpecificity': 0.9314159154891968,\n",
-       " 'test_BinaryAccuracy': 0.5082690119743347}]"
-      ]
-     },
-     "execution_count": 17,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path=\"best\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Compare with Baselines\n",
-    "\n",
-    "To assess the robustness and generalisability of DrugBAN, we compare its performance against baseline models. In this example, DrugBAN was trained for 100 epochs across multiple random seeds.\n",
-    "\n",
-    "The figure below presents the comparison on the BioSNAP and BindingDB datasets.\n",
-    "\n",
-    "- The left plot shows model performance based on AUROC (Area Under the Receiver Operating Characteristic Curve).\n",
-    "\n",
-    "- The right plot shows performance based on AUPRC (Area Under the Precision–Recall Curve)."
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    ""
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "## Interpretation Study - Extracting Embeddings from DrugBAN\n",
-    "After training and evaluating the model, we can study how DrugBAN represents drug and protein information at different stages by extracting key embeddings: **drug embedding**, **protein embedding**, and **joint interaction embedding**. This helps us understand how structural and sequential features are captured and how drug-protein interactions are encoded.\n",
-    "\n",
-    "The DrugBAN model provides these embeddings through its `forward()` function, returning intermediate outputs before and after the Bilinear Attention Network (BAN) layer.\n",
-    "\n",
-    "### How Are the Embeddings Computed?\n",
-    "The `forward()` function sequentially applies three main modules:\n",
-    "\n",
-    "- The **GCN-based drug extractor** processes the molecular graph, learning structural features and generating the **drug embedding**.\n",
-    "\n",
-    "- The **CNN-based protein extractor** processes the protein sequence, capturing local and global sequence patterns as the **protein embedding**.\n",
-    "\n",
-    "- The **BAN layer** fuses the drug and protein embeddings using bilinear attention, creating a **joint embedding** that highlights interaction-specific features.\n",
-    "\n",
-    "You should save embeddings during the evaluation phase (validation or test), not during training. This ensures you are extracting embeddings from a model that is not updating its weights, and you avoid interfering with training performance. A rough sketch of this workflow is shown below."
-   ]
-  },
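As a rough sketch of this workflow: the four-way unpacking below assumes `DrugBAN.forward()` returns the drug embedding, protein embedding, fused BAN representation, and prediction score, in that order; verify the actual return signature against `kale.embed.ban` before relying on it.

```python
import torch

# Sketch: collect embeddings at evaluation time, with gradients disabled.
model.eval()
drug_embs, protein_embs, joint_embs = [], [], []
with torch.no_grad():
    for drug_graph, protein_seq, label in test_generator:
        # Assumed return order (drug, protein, fused, score) -- check kale.embed.ban.
        v_d, v_p, fused, score = model(drug_graph, protein_seq)
        drug_embs.append(v_d.cpu())
        protein_embs.append(v_p.cpu())
        joint_embs.append(fused.cpu())
```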
Reload the dataset and re-run training and testing.\n",
- "\n",
- "**What to explore:**\n",
- "- Does model performance change?\n",
- "- Is the dataset more imbalanced?\n",
- "- How does training time compare?\n",
- "\n",
- "> Tip: See if the model struggles more or less with the new dataset. It can reveal how generalisable DrugBAN is.\n",
- "\n"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/tutorials/drug-target-interaction/notebook.ipynb b/tutorials/drug-target-interaction/notebook.ipynb
new file mode 100644
index 0000000..fbc582c
--- /dev/null
+++ b/tutorials/drug-target-interaction/notebook.ipynb
@@ -0,0 +1,882 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ }
+ },
+ "cells": [
+ {
+ "metadata": {},
+ "source": [
+ "# Drug–Target Interaction Prediction\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Introduction"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ ""
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "In this tutorial, we show how to use `PyKale`’s standard pipeline to integrate **multimodal data** from **drugs** and **proteins** for drug–target interaction (DTI) prediction. DTI prediction plays a key role in drug discovery and in identifying potential therapeutic targets. This example is based on the **DrugBAN** framework by [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1)."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Problem Formulation\n",
+ "\n",
+ "The DTI prediction problem is formulated as a **binary classification task**, where the goal is to predict whether a given **drug–protein pair interacts or not**. The DrugBAN framework tackles this problem using two key ideas:\n",
+ "\n",
+ "- **Bilinear Attention Network (BAN)**, which learns detailed feature representations for both drugs and proteins and captures local interaction patterns between them.\n",
+ "\n",
+ "- **Adversarial Domain Adaptation**, which helps the model generalise to out-of-distribution datasets, improving its ability to predict interactions on unseen drug–target pairs.\n",
+ "\n",
+ "With `PyKale`, implementing such a multimodal DTI prediction pipeline is straightforward. The library provides ready-to-use modules and configuration support, making it easy to apply advanced techniques like BAN and domain adaptation with minimal custom coding."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 0: Environment preparation"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "To prepare the helper functions and necessary materials, we download them from the GitHub repository.\n",
+ "\n",
+ "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. The additional helper script is:\n",
+ "- [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file."
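+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "To see how such an override works, here is a minimal, self-contained sketch of the `yacs` pattern that `configs.py` is built on. The key names and the `overrides.yaml` filename below are purely illustrative; the real defaults live in `configs.py`.\n",
+ "\n",
+ "```python\n",
+ "from yacs.config import CfgNode\n",
+ "\n",
+ "_C = CfgNode()\n",
+ "_C.SOLVER = CfgNode()\n",
+ "_C.SOLVER.MAX_EPOCH = 100  # a default, defined once in the config module\n",
+ "\n",
+ "cfg = _C.clone()\n",
+ "cfg.merge_from_list([\"SOLVER.MAX_EPOCH\", 2])  # or: cfg.merge_from_file(\"overrides.yaml\")\n",
+ "print(cfg.SOLVER.MAX_EPOCH)  # -> 2\n",
+ "```"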
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "!rm -rf /content/mmai-tutorial\n",
+ "!git clone --branch drug-target https://github.com/pykale/mmai-tutorial.git\n",
+ "%cd /content/mmai-tutorial/tutorials/drug-target-interaction"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Package Installation"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "The main package required for this tutorial is `PyKale`.\n",
+ "\n",
+ "`PyKale` is an open-source interdisciplinary machine learning library developed at the University of Sheffield, with a focus on applications in biomedical and scientific domains."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "!pip install --quiet git+https://github.com/pykale/pykale@main\\\n",
+ " && echo \"PyKale installed successfully ✅\" \\\n",
+ " || echo \"Failed to install PyKale ❌\""
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Then, we install `PyG` (PyTorch Geometric) and related packages.\n",
+ "\n",
+ "Please **do not** re-run this installation cell after it has completed. Running it multiple times will trigger issues related to `PyG`. If you need to re-run the installation, click `Runtime` in the top menu and choose `Disconnect and delete runtime` before installing again."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import torch\n",
+ "import os\n",
+ "os.environ['TORCH'] = torch.__version__\n",
+ "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
+ "!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html\n",
+ "\n",
+ "!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \\\n",
+ " && echo \"PyG installed successfully ✅\" \\\n",
+ " || echo \"Failed to install PyG ❌\"\n",
+ "\n",
+ "!pip install rdkit-pypi \\\n",
+ " && echo \"rdkit installed successfully ✅\" \\\n",
+ " || echo \"Failed to install rdkit ❌\""
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Then, we install the other required packages listed in [`mmai-tutorial/requirements.txt`](https://github.com/pykale/mmai-tutorial/blob/main/requirements.txt)."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "%cd /content/mmai-tutorial/tutorials/drug-target-interaction\n",
+ "\n",
+ "!pip install --quiet -r /content/mmai-tutorial/requirements.txt \\\n",
+ " && echo \"Required packages installed successfully ✅\" \\\n",
+ " || echo \"Failed to install required packages ❌\""
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Please run the following block to reinstall `NumPy` to avoid version-compatibility bugs. Note that it forcibly restarts the runtime so the reinstalled version takes effect."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import os\n",
+ "\n",
+ "!pip install --upgrade --force-reinstall numpy==2.0.0\n",
+ "os.kill(os.getpid(), 9)  # kill the kernel so the runtime restarts with the reinstalled NumPy"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Install `yacs`, which provides the `CfgNode` configuration system used by `configs.py`."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "!pip install yacs"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "We then hide warning messages to keep the output clean."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import os\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "os.environ[\"PYTHONWARNINGS\"] = \"ignore\""
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Exercise: Check NumPy Version"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import numpy as np\n",
+ "\n",
+ "print(\"NumPy version:\", np.__version__) # numpy should be 2.0.0"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Configuration\n",
+ "\n",
+ "To minimize the footprint of the notebook when specifying configurations, we provide a [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`experiments/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml)."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "%cd /content/mmai-tutorial/tutorials/drug-target-interaction\n",
+ "\n",
+ "from configs import get_cfg_defaults\n",
+ "\n",
+ "cfg = get_cfg_defaults() # Load the default settings from configs.py\n",
+ "cfg.merge_from_file(\n",
+ " \"experiments/DA_cross_domain.yaml\"\n",
+ ") # Update (or override) some of those settings using a custom YAML file"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "In this tutorial, we list the hyperparameters we would like you to experiment with outside the `.yaml` file:\n",
+ "- `cfg.SOLVER.MAX_EPOCH`: Number of epochs in the training stage. You can reduce the number of training epochs to shorten runtime.\n",
+ "- `cfg.DATA.DATASET`: The dataset used in the study. This can be `bindingdb` or `biosnap`.\n",
+ "\n",
+ "As a quick exercise, please take a moment to review and understand the parameters in [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "cfg.SOLVER.MAX_EPOCH = 2"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "You can also switch to a different dataset."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "cfg.DATA.DATASET = \"biosnap\""
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Exercise: Now print the full configuration to check all current hyperparameter and dataset settings."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "print(cfg)"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 1: Data loading and preparation\n",
+ "\n",
+ "In this tutorial, we use the **BioSNAP** dataset for the main demonstration and the **BindingDB** dataset for the exercise at the end."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Data downloading\n",
+ "\n",
+ "Please use the following code to download the necessary datasets."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "!rm -rf data\n",
+ "!mkdir data\n",
+ "\n",
+ "!pip install -q gdown\n",
+ "!gdown --id 1ogOcxZn-1q418LOT-gQ94aHQV0Y1sOmk --output data/drug-target-interaction.zip\n",
+ "!unzip data/drug-target-interaction.zip -d data/"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Exercise: Check that the data is ready"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import os\n",
+ "\n",
+ "print(\"Contents of the folder:\")\n",
+ "for item in os.listdir(\"data/drug-target-interaction\"):\n",
+ " print(item)"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Each dataset folder follows the structure:\n",
+ "\n",
+ "```sh\n",
+ " ├───dataset_name\n",
+ " │ ├───cluster\n",
+ " │ │ ├───source_train.csv\n",
+ " │ │ ├───target_train.csv\n",
+ " │ │ ├───target_test.csv\n",
+ " │ ├───random\n",
+ " │ │ ├───test.csv\n",
+ " │ │ ├───train.csv\n",
+ " │ │ ├───val.csv\n",
+ " │ ├───full.csv\n",
+ "```"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "We use the cluster dataset folder for cross-domain prediction, containing three parts:\n",
+ "\n",
+ "- Train samples from the source domain: Drug–protein pairs the model learns from.\n",
+ "\n",
+ "- Train samples from the target domain: Additional training data from a different distribution to improve generalisation.\n",
+ "\n",
+ "- Test samples from the target domain: Unseen drug–protein pairs used to evaluate model performance on new data."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Data loading"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Here’s what each CSV file looks like in table format:\n",
+ "\n",
+ "| SMILES | Protein Sequence | Y |\n",
+ "|--------------------|--------------------------|---|\n",
+ "| Fc1ccc(C2(COC…) | MDNVLPVDSDLS… | 1 |\n",
+ "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n",
+ "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n",
+ "\n",
+ "Each row of the dataset contains three key pieces of information:\n",
+ "\n",
+ "**Drugs**: \nDrugs are often written as SMILES strings, which are like chemical formulas in text format (for example, `\"CC(=O)OC1=CC=CC=C1C(=O)O\"` is aspirin). \n",
+ "\n",
+ "\n",
+ "**Protein Sequence** \nThis is a string of letters where each letter stands for an amino acid, one of the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n",
+ "\n",
+ "\n",
+ "**Y (Labels)**: \nEach drug–protein pair is given a label:\n",
+ "- `1` if they interact\n",
+ "- `0` if they do not\n",
+ "\n",
+ "\n",
+ "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**): whether or not the drug and protein interact."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "You can load CSV files into Python using tools like `pandas`."
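+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "As a quick aside before loading the CSVs: RDKit (installed earlier) can parse and sanity-check a SMILES string. A minimal sketch, using the aspirin example from above:\n",
+ "\n",
+ "```python\n",
+ "from rdkit import Chem\n",
+ "\n",
+ "mol = Chem.MolFromSmiles(\"CC(=O)OC1=CC=CC=C1C(=O)O\")  # returns None for an invalid SMILES\n",
+ "print(\"Number of heavy atoms:\", mol.GetNumAtoms())\n",
+ "print(\"Canonical SMILES:\", Chem.MolToSmiles(mol))\n",
+ "```"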
+ ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "import pandas as pd\n", + "\n", + "dataFolder = os.path.join(\n", + " f\"data/drug-target-interaction/{cfg.DATA.DATASET}\", str(cfg.DATA.SPLIT)\n", + ")\n", + "\n", + "df_train_source = pd.read_csv(os.path.join(dataFolder, \"source_train.csv\"))\n", + "df_train_target = pd.read_csv(os.path.join(dataFolder, \"target_train.csv\"))\n", + "df_test_target = pd.read_csv(os.path.join(dataFolder, \"target_test.csv\"))" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### Data preprocessing\n", + "\n", + "We convert drug SMILES strings into molecular graphs using `kale.loaddata.molecular_datasets.smiles_to_graph`, encoding atom-level features as node attributes and bond types as edges.\n", + "\n", + "\n", + "Protein sequences are transformed into fixed-length integer arrays using `kale.prepdata.chem_transform.integer_label_protein`, with each amino acid mapped to an integer and sequences padded or truncated to a uniform length.\n", + "\n", + "Finally, the `kale.loaddata.molecular_datasets.DTIDataset` class packages drugs, proteins, and labels into a PyTorch-ready dataset." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "**Note:** If you encounter an error related to requiring numpy `<2.0`, simply ignore it and re-run this block until it completes successfully." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from kale.loaddata.molecular_datasets import DTIDataset\n", + "\n", + "# Create preprocessed datasets\n", + "train_dataset = DTIDataset(df_train_source.index.values, df_train_source)\n", + "train_target_dataset = DTIDataset(df_train_target.index.values, df_train_target)\n", + "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "We load data in small, manageable pieces called batches to save memory and speed up training. We use `kale.loaddata.sampler.MultiDataLoader` from PyKale to load one batch from the source domain and one from the target domain at each training step." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "First, we specify a few DataLoader parameters:\n", + "- Batch size: Number of samples per batch\n", + "- Shuffle: Randomly shuffle data\n", + "- Number of workers: Parallel data loading\n", + "- Drop last: Discard the last incomplete batch for consistent batch sizes\n", + "- Collate function: Use graph_collate_func to batch variable-sized molecular graphs" + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from torch.utils.data import DataLoader\n", + "from kale.loaddata.molecular_datasets import graph_collate_func\n", + "from kale.loaddata.sampler import MultiDataLoader\n", + "\n", + "params = {\n", + " \"batch_size\": cfg.SOLVER.BATCH_SIZE,\n", + " \"shuffle\": True,\n", + " \"num_workers\": cfg.SOLVER.NUM_WORKERS,\n", + " \"drop_last\": True,\n", + " \"collate_fn\": graph_collate_func,\n", + "}" + ], + "cell_type": "code", + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "Then, we create a DataLoader from both the source and target datasets for training." 
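+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Before running it, here is a toy sketch of the pairing idea behind `MultiDataLoader` (an illustration only, not PyKale's actual implementation): both loaders are cycled so that every training step yields one source batch together with one target batch, for `n_batches` steps.\n",
+ "\n",
+ "```python\n",
+ "from itertools import cycle\n",
+ "\n",
+ "def paired_batches(source_loader, target_loader, n_batches):\n",
+ "    # cycle() replays a loader once it is exhausted, so the shorter one simply restarts\n",
+ "    src, tgt = iter(cycle(source_loader)), iter(cycle(target_loader))\n",
+ "    for _ in range(n_batches):\n",
+ "        yield next(src), next(tgt)\n",
+ "```"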
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "if not cfg.DA.USE:\n",
+ " training_generator = DataLoader(train_dataset, **params)\n",
+ "else:\n",
+ " source_generator = DataLoader(train_dataset, **params)\n",
+ " target_generator = DataLoader(train_target_dataset, **params)\n",
+ "\n",
+ " # Get the number of batches in the longer dataset to align both\n",
+ " n_batches = max(len(source_generator), len(target_generator))\n",
+ "\n",
+ " # Combine the source and target data loaders using MultiDataLoader\n",
+ " training_generator = MultiDataLoader(\n",
+ " dataloaders=[source_generator, target_generator], n_batches=n_batches\n",
+ " )"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "Lastly, we set up DataLoaders for validation and testing. Since we don’t want to shuffle or drop any samples, we adjust the parameters accordingly."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "# Update parameters for validation/testing (no shuffling, keep all data)\n",
+ "params.update({\"shuffle\": False, \"drop_last\": False})\n",
+ "\n",
+ "# Create validation and test data loaders\n",
+ "valid_generator = DataLoader(test_target_dataset, **params)\n",
+ "test_generator = DataLoader(test_target_dataset, **params)"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Exercise: Dataset Inspection\n",
+ "\n",
+ "Once the dataset is ready, let’s inspect one sample from the training data to check the input graph, protein sequence, and label format."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "# Get the first batch (contains one batch from source and one from target)\n",
+ "first_batch = next(iter(training_generator))\n",
+ "\n",
+ "# Unpack source and target batches\n",
+ "source_batch, target_batch = first_batch\n",
+ "\n",
+ "# Inspect the first sample from the source batch\n",
+ "print(\"First sample from source batch:\")\n",
+ "print(\"Drug graph:\", source_batch[0][0])\n",
+ "print(\"Protein sequence:\", source_batch[1][0])\n",
+ "print(\"Label:\", source_batch[2][0])"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "This sample is a tuple with three parts:\n",
+ "\n",
+ "1. **Drug Graph**\n",
+ "- `x=[290, 7]`: Feature matrix with 290 atoms (nodes) and 7 features per atom.\n",
+ "- `edge_index=[2, 58]`: Lists the 58 edges as pairs of source and target node indices.\n",
+ "- `edge_attr=[58, 1]`: Each edge has 1 bond feature, such as bond type.\n",
+ "- `num_nodes=290`: Confirms the graph has 290 nodes.\n",
+ "\n",
+ "2. **Protein Features (array)**\n",
+ "- Example values: `[11., 1., 18., ..., 0., 0., 0.]`: A fixed-length numeric array representing the protein sequence. Each position holds an integer-encoded amino acid, with zeros for padding.\n",
+ "\n",
+ "3. **Label (float)**\n",
+ "- `0.0`: The ground-truth interaction label indicating no interaction."
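+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "To make the protein array above concrete, here is a toy re-implementation of the integer-encoding idea. The alphabet, mapping, and maximum length below are illustrative only; `kale.prepdata.chem_transform.integer_label_protein` defines its own.\n",
+ "\n",
+ "```python\n",
+ "AMINO_ACIDS = \"ACDEFGHIKLMNPQRSTVWY\"  # the 20 standard amino acids\n",
+ "CHAR_TO_INT = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}  # 0 is reserved for padding\n",
+ "\n",
+ "def encode_protein(seq, max_len=12):\n",
+ "    codes = [CHAR_TO_INT.get(aa, 0) for aa in seq[:max_len]]  # truncate; unknown letters -> 0\n",
+ "    return codes + [0] * (max_len - len(codes))  # pad with zeros to a fixed length\n",
+ "\n",
+ "print(encode_protein(\"MGYTSLLT\"))  # the short example sequence from the Data loading section\n",
+ "```"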
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 2: Model definition"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Embed\n",
+ "\n",
+ "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n",
+ "\n",
+ "We define the DrugBAN class in `kale.embed.ban`."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "from kale.embed.ban import DrugBAN\n",
+ "\n",
+ "model = DrugBAN(**cfg)\n",
+ "print(model)"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Predict\n",
+ "We use the `DrugbanTrainer` class from `kale.pipeline.drugban_trainer`, which handles model training, domain adaptation, and evaluation for DrugBAN."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "from kale.pipeline.drugban_trainer import DrugbanTrainer\n",
+ "\n",
+ "drugban_trainer = DrugbanTrainer(\n",
+ " model=DrugBAN(**cfg),\n",
+ " solver_lr=cfg[\"SOLVER\"][\"LEARNING_RATE\"],\n",
+ " num_classes=cfg[\"DECODER\"][\"BINARY\"],\n",
+ " batch_size=cfg[\"SOLVER\"][\"BATCH_SIZE\"],\n",
+ " is_da=cfg[\"DA\"][\"USE\"],\n",
+ " solver_da_lr=cfg[\"SOLVER\"][\"DA_LEARNING_RATE\"],\n",
+ " da_init_epoch=cfg[\"DA\"][\"INIT_EPOCH\"],\n",
+ " da_method=cfg[\"DA\"][\"METHOD\"],\n",
+ " original_random=cfg[\"DA\"][\"ORIGINAL_RANDOM\"],\n",
+ " use_da_entropy=cfg[\"DA\"][\"USE_ENTROPY\"],\n",
+ " da_random_layer=cfg[\"DA\"][\"RANDOM_LAYER\"],\n",
+ " da_random_dim=cfg[\"DA\"][\"RANDOM_DIM\"],\n",
+ " decoder_in_dim=cfg[\"DECODER\"][\"IN_DIM\"],\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "import pytorch_lightning as pl\n",
+ "from pytorch_lightning.callbacks import ModelCheckpoint\n",
+ "\n",
+ "checkpoint_callback = ModelCheckpoint(\n",
+ " filename=\"{epoch}-{step}-{val_BinaryAUROC:.4f}\",\n",
+ " monitor=\"val_BinaryAUROC\",\n",
+ " mode=\"max\",\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "We now create the `Trainer`."
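+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "(A quick note before that: once training in Step 3 has finished, the checkpoint callback records where the best model was written, so you can locate and reuse it without retraining. `best_model_path` and `best_model_score` are standard PyTorch Lightning `ModelCheckpoint` attributes.)\n",
+ "\n",
+ "```python\n",
+ "# Run after trainer.fit(...) in Step 3 has completed:\n",
+ "print(checkpoint_callback.best_model_path)   # filesystem path of the top-scoring checkpoint\n",
+ "print(checkpoint_callback.best_model_score)  # best val_BinaryAUROC observed during training\n",
+ "```"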
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "trainer = pl.Trainer(\n",
+ " callbacks=[checkpoint_callback],\n",
+ " devices=\"auto\",\n",
+ " accelerator=\"auto\",\n",
+ " max_epochs=cfg[\"SOLVER\"][\"MAX_EPOCH\"],\n",
+ " deterministic=True,\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 3: Model training"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Train\n",
+ "\n",
+ "After setting up the model and data loaders, we now start training the full DrugBAN model using the PyTorch Lightning Trainer by calling `trainer.fit()`.\n",
+ "\n",
+ "#### What Happens Here?\n",
+ "- The model receives batches of drug–protein pairs from the training data loader.\n",
+ "\n",
+ "- During each step, the GCN, CNN, BAN layer, and MLP classifier are updated to improve interaction prediction.\n",
+ "\n",
+ "- Validation is automatically run at the end of each epoch to track performance and save the best model based on AUROC.\n",
+ "\n",
+ "\n",
+ "This code block takes approximately 5 minutes to complete."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "trainer.fit(\n",
+ " drugban_trainer,\n",
+ " train_dataloaders=training_generator,\n",
+ " val_dataloaders=valid_generator,\n",
+ ")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 4: Evaluate\n",
+ "\n",
+ "Once training is complete, we evaluate the model on the test set using `trainer.test()`.\n",
+ "\n",
+ "#### What is included in this step?\n",
+ "- The best model checkpoint (based on validation AUROC) is automatically loaded.\n",
+ "\n",
+ "- The model runs on the test data to generate predictions.\n",
+ "\n",
+ "- Final classification metrics, including AUROC, F1 score, accuracy, sensitivity, and specificity, are calculated and logged."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path=\"best\")"
+ ],
+ "cell_type": "code",
+ "outputs": [],
+ "execution_count": null
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Performance comparison\n",
+ "\n",
+ "The earlier example was a simple demonstration. To properly evaluate DrugBAN against baseline models, we train it for 100 epochs across multiple random seeds.\n",
+ "\n",
+ "The figure below shows the performance of different models on the BioSNAP and BindingDB datasets:\n",
+ "- Left plot: AUROC (Area Under the ROC Curve)\n",
+ "- Right plot: AUPRC (Area Under the Precision–Recall Curve)"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ ""
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Step 5: Interpretation\n",
+ "\n",
+ "Although we don’t perform this step in the tutorial, it is possible to explore how DrugBAN internally represents drugs and proteins by extracting intermediate embeddings.\n",
+ "\n",
+ "The model first processes drug graphs and protein sequences through separate modules, then fuses them using bilinear attention to create a joint representation. These embeddings (the drug embedding, protein embedding, and joint interaction embedding) help reveal what structural and sequence features the model has learned and how it encodes drug–protein interactions.\n",
+ "\n",
+ "This is typically done during the evaluation phase to avoid affecting the model’s weights or training behaviour."
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "## Extra Tasks\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Task 1\n",
+ "\n",
+ "To use the BindingDB dataset, modify the relevant line in the Configuration section of Step 0 as shown below.\n",
+ "\n",
+ "```python\n",
+ "cfg.DATA.DATASET = \"bindingdb\"\n",
+ "```\n",
+ "\n",
+ "Reload the dataset and re-run training and testing.\n",
+ "\n",
+ "> Tip: See if the model struggles more or less with the new dataset. This can reveal how generalisable DrugBAN is.\n"
+ ],
+ "cell_type": "markdown"
+ },
+ {
+ "metadata": {},
+ "source": [
+ "### Task 2\n",
+ "\n",
+ "Turn off domain adaptation by updating the config file and re-running training and testing.\n",
+ "\n",
+ "Replace `experiments/DA_cross_domain.yaml` with `experiments/non_DA_cross_domain.yaml` in the Configuration section of Step 0 as shown below.\n",
+ "\n",
+ "```python\n",
+ "cfg.merge_from_file(\"experiments/non_DA_cross_domain.yaml\")\n",
+ "```\n",
+ "> Tip: Compare the results with and without domain adaptation to see how it affects model performance."
+ ],
+ "cell_type": "markdown"
+ }
+ ]
+}