diff --git a/tutorials/drug-target-interaction/notebook-cross-domain.ipynb b/tutorials/drug-target-interaction/notebook-cross-domain.ipynb index e6b988c..56ad4c0 100644 --- a/tutorials/drug-target-interaction/notebook-cross-domain.ipynb +++ b/tutorials/drug-target-interaction/notebook-cross-domain.ipynb @@ -11,358 +11,63 @@ { "metadata": {}, "source": [ - "!pip install \"numpy<2.0\" \"transformers==4.30.2\" --force-reinstall --quiet" + "# Drug–Target Interaction Prediction\n" ], - "cell_type": "code", - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting numpy<2.0\n", - " Downloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m61.0/61.0 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting transformers==4.30.2\n", - " Downloading transformers-4.30.2-py3-none-any.whl.metadata (113 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m113.6/113.6 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting filelock (from transformers==4.30.2)\n", - " Downloading filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)\n", - "Collecting huggingface-hub<1.0,>=0.14.1 (from transformers==4.30.2)\n", - " Downloading huggingface_hub-0.33.0-py3-none-any.whl.metadata (14 kB)\n", - "Collecting packaging>=20.0 (from transformers==4.30.2)\n", - " Downloading packaging-25.0-py3-none-any.whl.metadata (3.3 kB)\n", - "Collecting pyyaml>=5.1 (from transformers==4.30.2)\n", - " Downloading PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.1 kB)\n", - "Collecting regex!=2019.12.17 (from transformers==4.30.2)\n", - " Downloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (40 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting requests (from transformers==4.30.2)\n", - " Downloading requests-2.32.4-py3-none-any.whl.metadata (4.9 kB)\n", - "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers==4.30.2)\n", - " Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", - "Collecting safetensors>=0.3.1 (from transformers==4.30.2)\n", - " Downloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.8 kB)\n", - "Collecting tqdm>=4.27 (from transformers==4.30.2)\n", - " Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting fsspec>=2023.5.0 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", - " Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)\n", - "Collecting typing-extensions>=3.7.4.3 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", - " Downloading typing_extensions-4.14.0-py3-none-any.whl.metadata (3.0 kB)\n", - "Collecting hf-xet<2.0.0,>=1.1.2 (from huggingface-hub<1.0,>=0.14.1->transformers==4.30.2)\n", - " Downloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (879 bytes)\n", - "Collecting charset_normalizer<4,>=2 (from 
requests->transformers==4.30.2)\n", - " Downloading charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (35 kB)\n", - "Collecting idna<4,>=2.5 (from requests->transformers==4.30.2)\n", - " Downloading idna-3.10-py3-none-any.whl.metadata (10 kB)\n", - "Collecting urllib3<3,>=1.21.1 (from requests->transformers==4.30.2)\n", - " Downloading urllib3-2.5.0-py3-none-any.whl.metadata (6.5 kB)\n", - "Collecting certifi>=2017.4.17 (from requests->transformers==4.30.2)\n", - " Downloading certifi-2025.6.15-py3-none-any.whl.metadata (2.4 kB)\n", - "Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m99.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading numpy-1.26.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.3/18.3 MB\u001b[0m \u001b[31m110.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading huggingface_hub-0.33.0-py3-none-any.whl (514 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m514.8/514.8 kB\u001b[0m \u001b[31m36.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading packaging-25.0-py3-none-any.whl (66 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m66.5/66.5 kB\u001b[0m \u001b[31m5.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (762 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m763.0/763.0 kB\u001b[0m \u001b[31m41.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (792 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m792.7/792.7 kB\u001b[0m \u001b[31m47.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (471 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m471.6/471.6 kB\u001b[0m \u001b[31m38.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m75.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading tqdm-4.67.1-py3-none-any.whl (78 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.5/78.5 kB\u001b[0m \u001b[31m6.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading filelock-3.18.0-py3-none-any.whl (16 kB)\n", - "Downloading requests-2.32.4-py3-none-any.whl (64 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.8/64.8 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading certifi-2025.6.15-py3-none-any.whl (157 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m157.7/157.7 kB\u001b[0m \u001b[31m12.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading 
charset_normalizer-3.4.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (147 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m147.3/147.3 kB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading fsspec-2025.5.1-py3-none-any.whl (199 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m199.1/199.1 kB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading hf_xet-1.1.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.1/3.1 MB\u001b[0m \u001b[31m97.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading idna-3.10-py3-none-any.whl (70 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m70.4/70.4 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading typing_extensions-4.14.0-py3-none-any.whl (43 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.8/43.8 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading urllib3-2.5.0-py3-none-any.whl (129 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m129.8/129.8 kB\u001b[0m \u001b[31m10.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: tokenizers, urllib3, typing-extensions, tqdm, safetensors, regex, pyyaml, packaging, numpy, idna, hf-xet, fsspec, filelock, charset_normalizer, certifi, requests, huggingface-hub, transformers\n", - " Attempting uninstall: tokenizers\n", - " Found existing installation: tokenizers 0.21.1\n", - " Uninstalling tokenizers-0.21.1:\n", - " Successfully uninstalled tokenizers-0.21.1\n", - " Attempting uninstall: urllib3\n", - " Found existing installation: urllib3 2.4.0\n", - " Uninstalling urllib3-2.4.0:\n", - " Successfully uninstalled urllib3-2.4.0\n", - " Attempting uninstall: typing-extensions\n", - " Found existing installation: typing_extensions 4.14.0\n", - " Uninstalling typing_extensions-4.14.0:\n", - " Successfully uninstalled typing_extensions-4.14.0\n", - " Attempting uninstall: tqdm\n", - " Found existing installation: tqdm 4.67.1\n", - " Uninstalling tqdm-4.67.1:\n", - " Successfully uninstalled tqdm-4.67.1\n", - " Attempting uninstall: safetensors\n", - " Found existing installation: safetensors 0.5.3\n", - " Uninstalling safetensors-0.5.3:\n", - " Successfully uninstalled safetensors-0.5.3\n", - " Attempting uninstall: regex\n", - " Found existing installation: regex 2024.11.6\n", - " Uninstalling regex-2024.11.6:\n", - " Successfully uninstalled regex-2024.11.6\n", - " Attempting uninstall: pyyaml\n", - " Found existing installation: PyYAML 6.0.2\n", - " Uninstalling PyYAML-6.0.2:\n", - " Successfully uninstalled PyYAML-6.0.2\n", - " Attempting uninstall: packaging\n", - " Found existing installation: packaging 24.2\n", - " Uninstalling packaging-24.2:\n", - " Successfully uninstalled packaging-24.2\n", - " Attempting uninstall: numpy\n", - " Found existing installation: numpy 2.0.2\n", - " Uninstalling numpy-2.0.2:\n", - " Successfully uninstalled numpy-2.0.2\n", - " Attempting uninstall: idna\n", - " Found existing installation: idna 3.10\n", - " Uninstalling idna-3.10:\n", - " Successfully uninstalled idna-3.10\n", - " Attempting uninstall: hf-xet\n", 
- " Found existing installation: hf-xet 1.1.3\n", - " Uninstalling hf-xet-1.1.3:\n", - " Successfully uninstalled hf-xet-1.1.3\n", - " Attempting uninstall: fsspec\n", - " Found existing installation: fsspec 2025.3.2\n", - " Uninstalling fsspec-2025.3.2:\n", - " Successfully uninstalled fsspec-2025.3.2\n", - " Attempting uninstall: filelock\n", - " Found existing installation: filelock 3.18.0\n", - " Uninstalling filelock-3.18.0:\n", - " Successfully uninstalled filelock-3.18.0\n", - " Attempting uninstall: charset_normalizer\n", - " Found existing installation: charset-normalizer 3.4.2\n", - " Uninstalling charset-normalizer-3.4.2:\n", - " Successfully uninstalled charset-normalizer-3.4.2\n", - " Attempting uninstall: certifi\n", - " Found existing installation: certifi 2025.6.15\n", - " Uninstalling certifi-2025.6.15:\n", - " Successfully uninstalled certifi-2025.6.15\n", - " Attempting uninstall: requests\n", - " Found existing installation: requests 2.32.3\n", - " Uninstalling requests-2.32.3:\n", - " Successfully uninstalled requests-2.32.3\n", - " Attempting uninstall: huggingface-hub\n", - " Found existing installation: huggingface-hub 0.33.0\n", - " Uninstalling huggingface-hub-0.33.0:\n", - " Successfully uninstalled huggingface-hub-0.33.0\n", - " Attempting uninstall: transformers\n", - " Found existing installation: transformers 4.52.4\n", - " Uninstalling transformers-4.52.4:\n", - " Successfully uninstalled transformers-4.52.4\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n", - "gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2025.5.1 which is incompatible.\n", - "langchain-core 0.3.65 requires packaging<25,>=23.2, but you have packaging 25.0 which is incompatible.\n", - "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\n", - "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cublas-cu12==12.4.5.8; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cublas-cu12 12.5.3.2 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cuda-cupti-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-cupti-cu12 12.5.82 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-nvrtc-cu12 12.5.82 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cuda-runtime-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cuda-runtime-cu12 12.5.82 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cudnn-cu12==9.1.0.70; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cudnn-cu12 9.3.0.75 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cufft-cu12==11.2.1.3; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cufft-cu12 11.2.3.61 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-curand-cu12==10.3.5.147; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-curand-cu12 10.3.6.82 
which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cusolver-cu12==11.6.1.9; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusolver-cu12 11.6.3.83 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-cusparse-cu12==12.3.1.170; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-cusparse-cu12 12.5.1.3 which is incompatible.\n", - "torch 2.6.0+cu124 requires nvidia-nvjitlink-cu12==12.4.127; platform_system == \"Linux\" and platform_machine == \"x86_64\", but you have nvidia-nvjitlink-cu12 12.5.82 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed certifi-2025.6.15 charset_normalizer-3.4.2 filelock-3.18.0 fsspec-2025.5.1 hf-xet-1.1.5 huggingface-hub-0.33.0 idna-3.10 numpy-1.26.4 packaging-25.0 pyyaml-6.0.2 regex-2024.11.6 requests-2.32.4 safetensors-0.5.3 tokenizers-0.13.3 tqdm-4.67.1 transformers-4.30.2 typing-extensions-4.14.0 urllib3-2.5.0\n" - ] - }, - { - "output_type": "display_data", - "data": { - "application/vnd.colab-display-data+json": { - "pip_warning": { - "packages": [ - "certifi", - "numpy", - "packaging" - ] - }, - "id": "45157e696c714aad87be51c98f340caf" - } - }, - "metadata": {} - } - ], - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "# ✅ Now run this AFTER restarting runtime\n", - "import numpy as np\n", - "\n", - "print(\"NumPy version:\", np.__version__) # should be <2.0" - ], - "cell_type": "code", - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "NumPy version: 1.26.4\n" - ] - } - ], - "execution_count": null + "cell_type": "markdown" }, { "metadata": {}, "source": [ - "# **Drug–Target Interaction Prediction**\n", + "## Introduction\n", + "\n", + "In this tutorial, we demonstrate the standard pipeline in `PyKale` and show how to integrate multimodal data from **drugs** and **proteins** to perform **drug-target interaction (DTI) prediction**.\n", "\n", - "Welcome to this tutorial on drug–target interaction (DTI) prediction using the **PyKale library**.\n", "\n", - "**PyKale** is a Python toolkit that helps make machine learning more approachable, especially for researchers working in interdisciplinary fields. It is particularly useful when dealing with **multimodal data**, which simply means combining different types of data — for example, information about drugs and proteins — to learn patterns from them together.\n", - "\n", - "Even if you’re new to Python or machine learning, don’t worry — we’ll explain key concepts as we go.\n", - "\n", - " \n", - "\n", - "---\n", - "\n", - " \n", "\n", "This tutorial builds on the work of [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1), which introduced the **DrugBAN** framework. The DrugBAN includes two key ideas:\n", "\n", "- A **bilinear attention network (BAN)**. This is a model that learns the features of both the drug and the protein, and how these features interact locally.\n", "\n", "\n", - "- **Adversarial domain adaptation**. This is a method that helps the model generalise to data that is different from what it was trained on (also known as out-of-distribution data), improving its performance on unseen drug–target pairs.\n", - "\n", - " \n", - "\n", - "---\n", - "\n", - "\n", - " \n", - "\n", - "## 🔍 What You'll Learn\n", - "\n", - "In the sections that follow, we’ll guide you through the PyKale development pipeline. 
Specifically, you will learn how to use PyKale to:\n", - "\n", - "- Load and preprocess the data\n", - "\n", - "- Set up the model and the training process\n", - "\n", - "- Train and test the model\n", - "\n", - "Finally, we will compare the results from DrugBAN with those from other established models.\n", - "\n", - "Let’s get started!\n" + "- **Adversarial domain adaptation**. This is a method that helps the model generalise to data that is different from what it was trained on (also known as out-of-distribution data), improving its performance on unseen drug–target pairs.\n" ], "cell_type": "markdown" }, { "metadata": {}, "source": [ - "## Setup\n", + "## Problem Formulation\n", "\n", - "To begin, we will install the necessary packages required for this tutorial. To maintain clarity and focus on interpretation, we will also suppress any warnings." - ], - "cell_type": "markdown" - }, - { - "metadata": {}, - "source": [ - "import os\n", - "import warnings\n", + "This tutorial focuses on the drug–target interaction (DTI) prediction problem, which is framed as a binary classification task. The inputs are drug SMILES strings and protein amino acid sequences, and the output is a binary label (1 or 0) indicating whether an interaction occurs.\n", "\n", - "warnings.filterwarnings(\"ignore\")\n", - "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"" - ], - "cell_type": "code", - "outputs": [], - "execution_count": null - }, - { - "metadata": {}, - "source": [ - "[Optional] If you are using Google Colab, please using the following codes to load necessary demo data and code files." + "We will work with two datasets: **BioSNAP** and **BindingDB**. The main tutorial will use the BioSNAP dataset, while BindingDB is provided as an additional dataset for you to explore and reproduce results in your own time after completing the tutorial." ], "cell_type": "markdown" }, { "metadata": {}, "source": [ - "!git clone --branch drug-target-interaction https://github.com/pykale/embc-mmai25.git\n", - "%cd /content/embc-mmai25/tutorials/drug-target-interaction" + "## Objective\n", + "- Understand the standard pipeline of `PyKale` library.\n" ], - "cell_type": "code", - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Cloning into 'embc-mmai25'...\n", - "remote: Enumerating objects: 1165, done.\u001b[K\n", - "remote: Counting objects: 100% (186/186), done.\u001b[K\n", - "remote: Compressing objects: 100% (128/128), done.\u001b[K\n", - "remote: Total 1165 (delta 87), reused 119 (delta 54), pack-reused 979 (from 1)\u001b[K\n", - "Receiving objects: 100% (1165/1165), 128.89 MiB | 34.22 MiB/s, done.\n", - "Resolving deltas: 100% (543/543), done.\n", - "/content/embc-mmai25/tutorials/drug-target-interaction\n" - ] - } - ], - "execution_count": null + "cell_type": "markdown" }, { "metadata": {}, "source": [ - "from google.colab import drive\n", - "\n", - "drive.mount(\"/content/drive\")\n", - "\n", - "shared_drives_path = (\n", - " \"/content/drive/Shared drives/EMBC-MMAI 25 Workshop/data/drug-target-interaction\"\n", - ")\n", + "## Environment Preparation\n", "\n", - "import os\n", - "import shutil\n", + "As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. 
To keep the output clean and focused on interpretation, we will also suppress warnings.\n", "\n", "Moreover, we provide helper functions that can be inspected directly in the `.py` files located in the notebook's current directory. The additional helper script is:\n", "- [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py): Defines the base configuration settings, which can be overridden using a custom `.yaml` file." ], "cell_type": "markdown" }, { "metadata": {}, "source": [ "### Package Installation\n", "\n", "The main packages required for this tutorial are **PyKale**, **PyTorch Geometric**, and **RDKit**.\n", @@ -370,21 +75,35 @@ "- **PyG** (PyTorch Geometric) is a library built on top of PyTorch for building and training Graph Neural Networks (GNNs) on structured data.\n", "- **RDKit** is a cheminformatics toolkit for handling and processing molecular structures, particularly useful for working with SMILES strings and molecular graphs.\n", "\n", - "Other dependencies are listed in [`embc-mmai25/requirements.txt`](https://github.com/pykale/embc-mmai25/blob/main/requirements.txt).\n" + "Other required packages can be found in [`embc-mmai25/requirements.txt`](https://github.com/pykale/embc-mmai25/blob/main/requirements.txt).\n", + "\n", + "--------------------------------------------------------------------------------\n", + "\n", + "#### **WARNING**\n", + "Please don't re-run this cell after the installation has completed. Running the installation multiple times will trigger issues related to `PyG`. 
If you need to re-run the installation, please open the `Runtime` menu at the top and choose `Disconnect and delete runtime` first.\n" ], "cell_type": "markdown" }, { "metadata": {}, "source": [ - "!pip install --quiet git+https://github.com/pykale/pykale@main\\\n", - " && echo \"PyKale installed successfully ✅\" \\\n", - " || echo \"Failed to install PyKale ❌\"\n", + "import os\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"\n", + "\n", + "!git clone https://github.com/pykale/embc-mmai25.git\n", + "%cd /content/embc-mmai25/tutorials/drug-target-interaction\n", "\n", "!pip install --quiet -r /content/embc-mmai25/requirements.txt \\\n", " && echo \"Required packages installed successfully ✅\" \\\n", " || echo \"Failed to install required packages ❌\"\n", "\n", + "!pip install --quiet git+https://github.com/pykale/pykale@main\\\n", + " && echo \"PyKale installed successfully ✅\" \\\n", + " || echo \"Failed to install PyKale ❌\"\n", + "\n", "import torch\n", "os.environ['TORCH'] = torch.__version__\n", "!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html\n", @@ -396,7 +115,12 @@ "\n", "!pip install rdkit-pypi \\\n", - " && echo \"PyG installed successfully ✅\" \\\n", - " || echo \"Failed to install PyG ❌\"" + " && echo \"RDKit installed successfully ✅\" \\\n", + " || echo \"Failed to install RDKit ❌\"\n", + "\n", + "\n", + "# !pip install \"numpy<2.0\" \"transformers==4.30.2\" --force-reinstall --quiet\n", + "!pip install --upgrade --force-reinstall numpy==2.0.0\n", + "os.kill(os.getpid(), 9)  # hard-restart the runtime so the reinstalled NumPy takes effect" ], "cell_type": "code", "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ + "Cloning into 'embc-mmai25'...\n", + "remote: Enumerating objects: 1444, done.\u001b[K\n", + "remote: Counting objects: 100% (193/193), done.\u001b[K\n", + "remote: Compressing objects: 100% (108/108), done.\u001b[K\n", + "remote: Total 1444 (delta 111), reused 136 (delta 85), pack-reused 1251 (from 1)\u001b[K\n", + "Receiving objects: 100% (1444/1444), 15.80 MiB | 16.94 MiB/s, done.\n", + "Resolving deltas: 100% (752/752), done.\n", + "/content/embc-mmai25/tutorials/drug-target-interaction\n", " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m812.3/812.3 kB\u001b[0m \u001b[31m25.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m779.2/779.2 MB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m85.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m64.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m42.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m7.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m13.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m7.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m176.2/176.2 MB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.1/168.1 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m71.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m67.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m962.6/962.6 kB\u001b[0m \u001b[31m41.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Building wheel for pykale (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "torchaudio 2.6.0+cu124 requires torch==2.6.0, but you have torch 2.3.0 which is incompatible.\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.0/45.0 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m88.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.4/10.4 MB\u001b[0m \u001b[31m69.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m80.3/80.3 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.2/83.2 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m104.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.0/434.0 kB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m105.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m86.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.4/40.4 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.4/127.4 kB\u001b[0m \u001b[31m11.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m78.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m63.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + 
"\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "thinc 8.3.6 requires numpy<3.0.0,>=2.0.0, but you have numpy 1.26.4 which is incompatible.\n", "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mPyKale installed successfully ✅\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[0mRequired packages installed successfully ✅\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.0/45.0 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m103.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.5/10.5 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m80.3/80.3 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.2/83.2 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m93.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m434.0/434.0 kB\u001b[0m \u001b[31m28.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.6/4.6 MB\u001b[0m \u001b[31m103.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m89.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.4/40.4 kB\u001b[0m \u001b[31m3.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.4/127.4 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m60.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.4/1.4 MB\u001b[0m \u001b[31m68.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequired packages installed successfully ✅\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.9/10.9 MB\u001b[0m \u001b[31m74.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.1/5.1 MB\u001b[0m \u001b[31m41.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", + "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mPyKale installed successfully ✅\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", + "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", " Building wheel for torch-geometric (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "PyG installed successfully ✅\n", - "Collecting rdkit-pypi\n", - " Downloading rdkit_pypi-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from rdkit-pypi) (1.26.4)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33m WARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n", + "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", + "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mPyG installed successfully ✅\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: rdkit-pypi in /usr/local/lib/python3.11/dist-packages (2022.9.5)\n", + "Collecting numpy (from rdkit-pypi)\n", + " Using cached numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (62 kB)\n", "Requirement already satisfied: Pillow in /usr/local/lib/python3.11/dist-packages (from rdkit-pypi) (11.2.1)\n", - "Downloading rdkit_pypi-2022.9.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m29.4/29.4 MB\u001b[0m \u001b[31m71.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: rdkit-pypi\n", - "Successfully installed rdkit-pypi-2022.9.5\n", - "PyG installed successfully ✅\n" + "Using cached numpy-2.3.1-cp311-cp311-manylinux_2_28_x86_64.whl (16.9 MB)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mInstalling collected packages: numpy\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", + "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n", + "cupy-cuda12x 13.3.0 requires numpy<2.3,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "tensorflow 2.18.0 requires numpy<2.1.0,>=1.26.0, but you have numpy 2.3.1 which is incompatible.\n", + "numba 0.60.0 requires numpy<2.1,>=1.22, but you have numpy 2.3.1 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed numpy\n", + "PyG installed successfully ✅\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0m\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting numpy==2.0.0\n", + " Using cached numpy-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)\n", + "Using cached numpy-2.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (19.3 MB)\n", + "\u001b[33mWARNING: Ignoring invalid distribution ~umpy (/usr/local/lib/python3.11/dist-packages)\u001b[0m\u001b[33m\n", + "\u001b[0mInstalling collected packages: numpy\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "pykale 0.1.2 requires torch-geometric==2.3.0, but you have torch-geometric 2.7.0 which is incompatible.\n", + "sentence-transformers 4.1.0 requires transformers<5.0.0,>=4.41.0, but you have transformers 4.30.2 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed numpy-2.0.0\n" ] } ], @@ -470,30 +257,85 @@ { "metadata": {}, "source": [ - "### ⚙️ Configuration\n", + "# (Optional: Numpy version check) ✅ Now run this AFTER restarting runtime\n", + "import numpy as np\n", "\n", - "Before running any model or data processing, we need to tell the code **what settings to use**. To make this easier, we provide a file called `config.py`. Think of this file as a **menu of default settings** that the rest of the code can refer to — for example, where to find the data, which model to use, and how many times to train it.\n", + "print(\"NumPy version:\", np.__version__) # should be <2.0" + ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "NumPy version: 2.0.0\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### Mount Data (Optional)\n", "\n", - "You can find `config.py` in the same folder as this notebook. You don’t need to change it directly. Instead, we use a **YAML file** to customise the settings.\n", + "If you are using Google Colab, please using the following codes to load necessary datasets." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ + "from google.colab import drive\n", "\n", - "> **What is a YAML file?** \n", - "> YAML is a simple text file format often used for configuration. It lets you list out settings in a way that’s easier to read than raw Python code.\n", + "drive.mount(\"/content/drive\")\n", "\n", - "For example, we have a YAML file called `experiments/non_da_in_domain.yaml`. 
You can change this file to adjust things like:\n", "\n", - "- Which dataset to use \n", - "- How long to train the model \n", - "- Which model settings to apply \n", + "import os\n", + "import shutil\n", "\n", - "This helps keep your work organised and flexible. You don’t need to modify the original Python files — just change the YAML file instead.\n", + "print(\"Contents of the folder:\")\n", + "for item in os.listdir(shared_drives_path):\n", + " print(item)" ], + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n", + "Contents of the folder:\n", + "bindingdb\n", + "biosnap\n" + ] + } + ], + "execution_count": null + }, + { + "metadata": {}, + "source": [ + "### Configuration\n", "\n", - "Now let’s see how we actually load and apply the settings from a YAML file in Python." + "To minimize the footprint of the notebook when specifying configurations, we provide a [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py) file that defines default parameters. These can be customized by supplying a `.yaml` configuration file, such as [`experiments/DA_cross_domain.yaml`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/experiments/DA_cross_domain.yaml) as an example.\n", + "\n", + "Below, we list the hyperparameters that you can experiment with directly in the notebook, outside the `.yaml` file:\n", + "- `cfg.SOLVER.MAX_EPOCH`: The number of training epochs.\n", + "- `cfg.DATA.DATASET`: The dataset to use: either `bindingdb` or `biosnap`.\n", + "\n", + "As a quick exercise, please take a moment to review and understand the parameters in [`configs.py`](https://github.com/pykale/embc-mmai25/blob/main/tutorials/drug-target-interaction/configs.py)." ], "cell_type": "markdown" }, { "metadata": {}, "source": [ + "%cd /content/embc-mmai25/tutorials/drug-target-interaction\n", + "\n", "from configs import get_cfg_defaults\n", "\n", "# Load the default settings from configs.py\n", "cfg = get_cfg_defaults()\n", "\n", "# Update (or override) some of those settings using a custom YAML file\n", "cfg.merge_from_file(\"experiments/DA_cross_domain.yaml\")\n", "\n", - "# Example: temporarily shorten the training time by setting fewer training rounds\n", + "# ------ Hyperparameters to play with -----\n", + "# You can reduce the training epochs to decrease training time if necessary\n", "cfg.SOLVER.MAX_EPOCH = 2\n", "\n", - "# Example: switch the dataset to Biosnap\n", + "# You can switch to a different dataset\n", "cfg.DATA.DATASET = \"biosnap\"\n", "\n", - "# Print the current settings to check what’s being used\n", + "# -----------------------------------------\n", "print(cfg)" ], "cell_type": "code", "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ + "/content/embc-mmai25/tutorials/drug-target-interaction\n", "BCN:\n", " HEADS: 2\n", "COMET:\n", - " API_KEY: InDQ1UsqJt7QMiANWg55Ulebe\n", + " API_KEY: \n", " EXPERIMENT_NAME: DA_cross_domain\n", " PROJECT_NAME: drugban-23-May\n", " TAG: DrugBAN_CDAN\n", ... @@ -572,13 +416,16 @@ { "metadata": {}, "source": [ - "## Data Overview\n", + "## Data Loading and Pre-processing\n", "\n", - "In this tutorial, we use a benchmark dataset called **Biosnap**, which contains information about how well different drugs interact with specific proteins. 
This dataset has been preprocessed and provided by the authors of the **DrugBAN** paper. You can also find it in their [GitHub repository](https://github.com/peizhenbai/DrugBAN/tree/main).\n", "\n", - "### 📁 Folder Structure\n", "\n", - "The dataset is stored in a folder called `biosnap`, which contains a few subfolders. Each subfolder corresponds to a different experimental setting for training and testing machine learning models.\n", + "In this tutorial, we use a benchmark dataset called **Biosnap**, which contains information about how well different drugs interact with specific proteins. This dataset has been preprocessed and provided by the authors of the **DrugBAN** paper. You can also find it in their [GitHub repository](https://github.com/peizhenbai/DrugBAN/tree/main)." ], "cell_type": "markdown" }, { "metadata": {}, "source": [ + "The Biosnap dataset is stored in a folder called `biosnap`, which contains a few subfolders. Each subfolder corresponds to a different experimental setting for training and testing machine learning models.\n", "\n", "Here is a simplified view of the folder structure:\n", "\n", @@ -601,24 +448,8 @@ "\n", "Each file listed here is in **CSV format**, which you can open using spreadsheet software (like Excel) or load into Python using tools like `pandas`. These files contain rows of data, with each row representing one drug–protein pair.\n", "\n", - "### 🧬 What’s Inside Each File?\n", - "\n", - "Each row of the dataset contains three key pieces of information:\n", - "\n", - "**SMILES** \n", - "This is a way to describe the structure of a drug molecule using a short string of letters and symbols. It’s a compact format called *Simplified Molecular Input Line Entry System*. You don’t need to understand chemistry to use this, but just know that this string uniquely represents a drug.\n", - "\n", - "**Protein Sequence** \n", - "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n", - "\n", - "**Y** \n", - "This is the label or answer. It tells us whether the drug and the protein interact. \n", - "`1` means yes, they interact. \n", - "`0` means no, they do not interact.\n", - "\n", - "### 📊 Sample of the Data\n", "\n", + "Here’s what each CSV file looks like in a table format:\n", "\n", "| SMILES | Protein Sequence | Y |\n", "|--------------------|--------------------------|---|\n", "| CC1=CC(=CC(=C1)…) | MGLSDGEWQLVLNV… | 1 |\n", "| O=c1oc2c(O)c(…) | MMYSKLLTLTTL… | 0 |\n", "| CC(C)Oc1cc(N…) | MGMACLTMTEME… | 1 |\n", "\n", + "Each row of the dataset contains three key pieces of information:\n", "\n", "**Drugs**: \n", "Drugs are often written as SMILES strings, which are like chemical formulas in text format (for example, \"CC(=O)OC1=CC=CC=C1C(=O)O\" is aspirin). \n",
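+ "\n", + "If you would like to sanity-check a SMILES string yourself, RDKit (one of the packages installed above) can parse it directly. A minimal sketch:\n", + "\n", + "```python\n", + "from rdkit import Chem\n", + "\n", + "mol = Chem.MolFromSmiles(\"CC(=O)OC1=CC=CC=C1C(=O)O\")  # aspirin\n", + "print(mol.GetNumAtoms())  # number of heavy atoms RDKit parsed from the string\n", + "```\n",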
\n", "\n", - "To make this information useful for machine learning, we convert each SMILES string into a **molecular graph**. In a molecular graph:\n", - "- Each **atom** is a node\n", - "- Each **bond** is an edge between nodes \n", - "\n", - " \n", - "\n", - "---\n", - "\n", - " \n", - "\n", - "**Proteins**: \n", - "Proteins are sequences of amino acids. We convert each sequence into numbers using:\n", - "\n", - "- **One-hot encoding**, which assigns each amino acid a unique numerical representation. A full sequence is then turned into an embedding-like vector, similar to how sentences are represented in natural language processing.\n", "\n", + "**Protein Sequence** \n", + "This is a string of letters where each letter stands for an amino acid, the building blocks of proteins. For example, `MGYTSLLT...` is a short protein sequence.\n", "\n", - " \n", "\n", - "---\n", + "**Y (Labels)**: \n", + "Each drug–protein pair is given a label:\n", + "- `1` if they interact\n", + "- `0` if they do not\n", "\n", - " \n", "\n", + "Each row shows one drug–protein pair. The goal of our machine learning model is to predict the last column (**Y**) — whether or not the drug and protein interact." + ], + "cell_type": "markdown" + }, + { + "metadata": {}, + "source": [ "\n", - "**Labels**: \n", - "Each drug–protein pair is given a label:\n", - "- `1` if they interact (i.e. the drug affects the protein)\n", - "- `0` if they do not\n", "\n", - " \n", + "To generate the drug graphs, we use the `kale.loaddata.molecular_datasets.smiles_to_graph` function, which converts SMILES strings into graph structures. For each molecule, atom-level features such as atomic number, degree, valence, and aromaticity are encoded as node features. Bond information is represented through edge indices and edge attributes. The function automatically adds self-loops to all nodes to ensure that each node has at least one connection. For molecules with fewer atoms than the maximum allowed, the function applies node padding by adding virtual nodes with zero features.\n", "\n", - "---\n", "\n", - " \n", + "We use the kale.prepdata.`chem_transform.integer_label_protein` function to convert protein sequences into fixed-length integer arrays. Each amino acid is mapped to a unique integer based on a predefined dictionary (CHARPROTSET). Sequences longer than the maximum length (default: 1200) are truncated, while shorter sequences are zero-padded. Unknown characters are treated as padding, ensuring all protein inputs have a consistent numerical format.\n", "\n", - "#### How is preprocessing handled in the code?\n", "\n", - "We use a class called `DTIDataset`, provided by **PyKale**, to handle all of this preprocessing for us. It takes care of:\n", - "- Reading the data\n", - "- Converting drugs to molecular graphs\n", - "- Encoding protein sequences\n", - "- Assigning labels to each pair\n" + "We then use the `kale.loaddata.molecular_datasets.DTIDataset` class to integrate these steps by organising the drug-protein-label triplets into a dataset format compatible with PyTorch. During training and evaluation, the DataLoader calls graph_collate_func to batch the molecular graphs, protein sequences, and labels into a single batch. The output is a batched drug graph, a stacked protein sequence tensor, and a label tensor, ready for input into the DrugBAN model." 
], "cell_type": "markdown" }, @@ -709,13 +515,24 @@ "test_target_dataset = DTIDataset(df_test_target.index.values, df_test_target)" ], "cell_type": "code", - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-scatter'. Disabling its usage. Stacktrace: /usr/local/lib/python3.11/dist-packages/torch_scatter/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb\n", + " import torch_geometric.typing\n", + "/usr/local/lib/python3.11/dist-packages/torch_geometric/__init__.py:4: UserWarning: An issue occurred while importing 'torch-sparse'. Disabling its usage. Stacktrace: /usr/local/lib/python3.11/dist-packages/torch_sparse/_version_cuda.so: undefined symbol: _ZN5torch3jit17parseSchemaOrNameERKSsb\n", + " import torch_geometric.typing\n" + ] + } + ], "execution_count": null }, { "metadata": {}, "source": [ - "### 🗂️ Dataset Inspection\n", + "### Dataset Inspection\n", "\n", "Once we’ve loaded the dataset, it's useful to take a quick look at what it contains. This helps us understand the data format and what kind of information we’ll be working with in the rest of the tutorial.\n", "\n", @@ -753,7 +570,7 @@ "Train samples from source domain: 9766, Train samples from target domain: 3628, Test samples from target domain: 907\n", "\n", "An example sample from source domain:\n", - "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), 0.0)\n" + "(Data(x=[290, 7], edge_index=[2, 58], edge_attr=[58, 1], num_nodes=290), array([11., 1., 18., ..., 0., 0., 0.]), np.float64(0.0))\n" ] } ], @@ -762,8 +579,6 @@ { "metadata": {}, "source": [ - "### 🧾 Example Sample Explained\n", - "\n", "Let’s break down what this example from the **source domain** means:\n", "\n", "```\n", @@ -773,9 +588,9 @@ "\n", "This sample is a tuple with **three parts**:\n", "\n", - "---\n", "\n", - "#### 1. **Drug Graph (Data object)**\n", + "\n", + "1. **Drug Graph (Data object)**\n", "\n", "This part is a graph-based representation of the **drug**, built using the PyTorch Geometric `Data` object:\n", "\n", @@ -794,18 +609,17 @@ "- `num_nodes=290` \n", " This confirms that the graph has **290 atoms (nodes)**.\n", "\n", - "---\n", "\n", - "#### 2. **Protein Features (array)**\n", + "\n", + "2. **Protein Features (array)**\n", "\n", "- This is a **1D array** (or vector) representing the **protein**. \n", "- It contains numerical features extracted from the protein sequence or structure. \n", "- Example values: `[11., 1., 18., ..., 0., 0., 0.]` \n", " These could represent biochemical or structural properties, with padding at the end (zeros) to ensure a consistent input size.\n", "\n", - "---\n", "\n", - "#### 3. **Label (float)**\n", + "3. 
**Label (float)**\n", "\n", "- `0.0` \n", " This is the **label**, which tells us the ground truth: \n", @@ -813,7 +627,7 @@ "\n", " If the label were `1.0`, it would mean they **do interact**.\n", "\n", - "---\n", + "\n", "\n", "This format allows the model to learn from both structured graph data (the drug) and feature-based data (the protein), and predict whether they interact based on the label.\n" ], @@ -822,20 +636,13 @@ { "metadata": {}, "source": [ - "### 🧱 Batching\n", + "### Batching\n", "\n", "When training machine learning models, especially on large datasets like molecular graphs, it’s inefficient and memory-intensive to load everything at once. Instead, we split the data into **mini-batches** and feed them into the model one at a time. This process is called **batching**, i.e, loading data in manageable pieces.\n", "\n", - "In this tutorial, we use PyTorch’s `DataLoader` to help us do this. A `DataLoader` handles the process of batching, shuffling, and loading data efficiently during training and evaluation.\n", - "\n", - "However, because molecular data involves **graphs of different sizes and shapes**, we can't just stack them like regular tables or images. That’s where a custom helper function called `graph_collate_func` comes in. This function tells the `DataLoader` how to correctly combine graphs of different structures into a batch.\n", - "\n", - "#### 🔄 Training vs Testing\n", - "\n", - "- During **training**, we shuffle the data randomly. This helps the model generalise better and prevents it from learning the order of the data.\n", - "- During **validation and testing**, we **don’t** shuffle the data. This ensures consistent and reproducible evaluation.\n", + "We use PyTorch’s `DataLoader` to efficiently batch and load samples during training and evaluation. For training, we create two separate data loaders: one for the source domain and one for the target domain. To enable domain adaptation, we combine them using `kale.loaddata.sampler.MultiDataLoader`, which yields one batch from each domain at every training step and ensures a consistent number of batches per epoch by automatically restarting smaller datasets when needed.\n", "\n", - "Now let’s see how this looks in code." + "However, because molecular data involves graphs of varying sizes and structures, we cannot stack them like regular tensors or images. To handle this, we use a custom collate function called `kale.loaddata.molecular_datasets.graph_collate_func`, which tells the `DataLoader` how to correctly combine multiple graphs into a single batch that the model can process." ], "cell_type": "markdown" }, @@ -885,34 +692,20 @@ { "metadata": {}, "source": [ - "## Model and Trainer Overview\n", + "## Model Definition\n", "\n", - "In this section, we'll look at the model and trainer we're using and how to set it up in your code. Don’t worry if the names sound technical — we’ll break them down for you." - ], - "cell_type": "markdown" - }, - { - "metadata": {}, - "source": [ - "### 🏗️ Setting Up the DrugBAN Model\n", + "DrugBAN consists of three main components: a Graph Convolutional Network (GCN) for extracting structural features from drug molecular graphs, a Convolutional Neural Network (CNN) for encoding protein sequences, and a Bilinear Attention Network (BAN) for fusing drug and protein features. 
The fused representation is then passed through a Multi-Layer Perceptron (MLP) classifier to predict interaction scores.\n", "\n", - "The **DrugBAN** model is designed to predict whether a drug and protein interact. It brings together different parts of the data using specialised tools from deep learning.\n", + "We define the DrugBAN class in `kale.embed.ban`, which wraps all key modules of the DrugBAN pipeline based on the configuration.\n", + "This wrapper handles:\n", "\n", - "Here’s what DrugBAN is made of:\n", + "- Initialising the GCN-based drug feature extractor (MolecularGCN).\n", "\n", - "**1. GCN (Graph Convolutional Network)** \n", - "This handles the structure of drug molecules. It treats each molecule as a graph — where atoms are nodes and bonds are edges — and learns useful patterns from it.\n", + "- Building the CNN-based protein sequence encoder (ProteinCNN).\n", "\n", - "**2. CNN (Convolutional Neural Network)** \n", - "This works with the protein sequences. Think of it like scanning the sequence for patterns, just like image recognition scans for edges or shapes.\n", + "- Integrating the BAN layer for drug-protein feature fusion (BANLayer).\n", "\n", - "**3. BAN (Bilinear Attention Network)** \n", - "This connects the drug and protein features and helps the model learn **how parts of the drug interact with parts of the protein**.\n", - "\n", - "**4. MLP (Multilayer Perceptron)** \n", - "This is the final decision-maker. It takes all the features the model has learned and makes the final prediction: will this drug bind to this protein?\n", - "\n", - "Here’s how you can create the model in your code:" + "- Creating the MLP classifier for final prediction (MLPDecoder)." ], "cell_type": "markdown" }, @@ -928,84 +721,14 @@ "print(model)" ], "cell_type": "code", - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DrugBAN(\n", - " (drug_extractor): MolecularGCN(\n", - " (init_transform): Linear(in_features=7, out_features=128, bias=False)\n", - " (gcn_layers): ModuleList(\n", - " (0-2): 3 x GCNConv(128, 128)\n", - " )\n", - " )\n", - " (protein_extractor): ProteinCNN(\n", - " (embedding): Embedding(26, 128, padding_idx=0)\n", - " (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,))\n", - " (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (conv2): Conv1d(128, 128, kernel_size=(6,), stride=(1,))\n", - " (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (conv3): Conv1d(128, 128, kernel_size=(9,), stride=(1,))\n", - " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (bcn): BANLayer(\n", - " (v_net): FCNet(\n", - " (main): Sequential(\n", - " (0): Dropout(p=0.2, inplace=False)\n", - " (1): Linear(in_features=128, out_features=768, bias=True)\n", - " (2): ReLU()\n", - " )\n", - " )\n", - " (q_net): FCNet(\n", - " (main): Sequential(\n", - " (0): Dropout(p=0.2, inplace=False)\n", - " (1): Linear(in_features=128, out_features=768, bias=True)\n", - " (2): ReLU()\n", - " )\n", - " )\n", - " (p_net): AvgPool1d(kernel_size=(3,), stride=(3,), padding=(0,))\n", - " (bn): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (mlp_classifier): MLPDecoder(\n", - " (fc1): Linear(in_features=256, out_features=512, bias=True)\n", - " (bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (fc2): Linear(in_features=512, 
out_features=512, bias=True)\n", - " (bn2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (fc3): Linear(in_features=512, out_features=128, bias=True)\n", - " (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (fc4): Linear(in_features=128, out_features=2, bias=True)\n", - " )\n", - ")\n" - ] - } - ], + "outputs": [], "execution_count": null }, { "metadata": {}, "source": [ - "### 🏋️♀️ Setup Trainer\n", - "\n", - "In this section, we will set up the **training process** using **PyTorch Lightning**, a high-level library that simplifies training loops and experiment tracking in deep learning. Think of it as a way to organise all the messy training code into something tidy and reusable.\n", - "\n", - "We will use a training class called `DrugbanTrainer`, which is part of **PyKale**, and handles model training, domain adaptation, and evaluation.\n", - "\n", - "The values for the trainer's setup come from a configuration file written in YAML. If you are curious about what each setting means, check the YAML file. We've added comments there to explain each parameter.\n", - "\n", - "---\n", - "\n", - "Step 1: Initialise the Trainer\n", - "```python\n", - "from kale.pipeline.drugban_trainer import DrugbanTrainer\n", - "\n", - "drugban_trainer = DrugbanTrainer(\n", - " model=DrugBAN(**cfg),\n", - " solver_lr=cfg[\"SOLVER\"][\"LEARNING_RATE\"], # learning rate for training the model\n", - " num_classes=cfg[\"DECODER\"][\"BINARY\"], # number of output classes (1 for binary classification)\n", - " batch_size=cfg[\"SOLVER\"][\"BATCH_SIZE\"], # how many samples the model sees at once\n", - "\n", - " # Domain adaptation settings (you can think of this as helping the model\n" + "## Model Training\n", + "We use the training class `kale.pipeline.drugban_trainer`, which handles model training, domain adaptation, and evaluation for DrugBAN." ], "cell_type": "markdown" }, @@ -1034,15 +757,22 @@ ")" ], "cell_type": "code", - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/torch/nn/utils/weight_norm.py:28: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\n", + " warnings.warn(\"torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.\")\n" + ] + } + ], "execution_count": null }, { "metadata": {}, "source": [ - "Step 2: Setup Checkpointing\n", - "\n", - "We want to save the best model during training. This helps you avoid rerunning training if you want to reuse the model later." + "We want to save the best model during training so we can reuse it later without needing to retrain. PyTorch Lightning’s `ModelCheckpoint` does this by automatically saving the model whenever it achieves a new best validation AUROC score." ], "cell_type": "markdown" }, @@ -1066,15 +796,15 @@ { "metadata": {}, "source": [ - "Step 3: Launch the Trainer\n", - "\n", - "We now create the actual PyTorch Lightning trainer, which handles the training loop." + "We now create the `Trainer`." 
], "cell_type": "markdown" }, { "metadata": {}, "source": [ + "import torch\n", + "\n", "trainer = pl.Trainer(\n", " callbacks=[checkpoint_callback], # automatically save best model\n", " devices=\"auto\", # use all available GPUs\n", @@ -1102,30 +832,15 @@ { "metadata": {}, "source": [ - "## Training and Testing Overview\n", - "\n", - "Before we can make any predictions, we need to **train** the model using known examples of drug–protein interactions. This step helps the model learn patterns, so that it can later predict whether a new drug and protein might interact.\n", - "\n", - "### What is training?\n", + "### Train the DrugBAN Model\n", + "After setting up the model and data loaders, we now start training the full DrugBAN model using the PyTorch Lightning Trainer via calling `trainer.fit()`.\n", "\n", - "Training is the process where the model adjusts itself to improve its guesses. Imagine giving it many examples of drug–protein pairs along with the correct answers (whether they interact or not). The model learns from these examples by updating its internal settings to reduce mistakes.\n", + "#### What Happens Here?\n", + "- The model receives batches of drug-protein pairs from the training data loader.\n", "\n", - "### What is validation?\n", + "- During each step, the GCN, CNN, BAN layer, and MLP classifier are updated to improve interaction prediction.\n", "\n", - "Validation happens *during* training. We use a separate set of data (different from the training data) to check how well the model is doing as it learns. This helps us tune the model without accidentally letting it memorise all the training examples. It’s like checking your understanding by doing practice questions while revising.\n", - "\n", - "### What is testing?\n", - "\n", - "Testing is the final step, done *after* training is complete. We give the model new examples it has never seen before. This tells us how well it might perform in the real world when predicting new drug–protein interactions.\n" - ], - "cell_type": "markdown" - }, - { - "metadata": {}, - "source": [ - "### 🏋️♀️ Training\n", - "\n", - "The following code starts the training process. It uses a function called `fit` which is part of PyTorch Lightning's training system. You do not need to change anything here unless you're experimenting." + "- Validation is automatically run at the end of each epoch to track performance and save the best model based on AUROC." ], "cell_type": "markdown" }, @@ -1170,11 +885,20 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "5d1253045dbd49e981af08fe5e0a3866" + "model_id": "de1e5fc3654547e4b7078fd9511d4764" } }, "metadata": {} }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 9280. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", + "/usr/local/lib/python3.11/dist-packages/torchmetrics/utilities/prints.py:43: TorchMetricsUserWarning: You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU. 
Expect some slowdowns.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, { "output_type": "display_data", "data": { @@ -1184,7 +908,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "fa7446341eee40599cb1afd95f2e2b32" + "model_id": "8ef34bbc1d944d8398a243f6c45ab415" } }, "metadata": {} @@ -1193,8 +917,8 @@ "output_type": "stream", "name": "stderr", "text": [ - "[18:02:25] Unusual charge on atom 0 number of radical electrons set to zero\n", - "[18:02:32] Unusual charge on atom 0 number of radical electrons set to zero\n" + "[15:53:49] Unusual charge on atom 0 number of radical electrons set to zero\n", + "[15:53:55] Unusual charge on atom 0 number of radical electrons set to zero\n" ] }, { @@ -1206,7 +930,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "1184677423754ec48baa956fc28f7d37" + "model_id": "fc8cb437cdc34a5eb335d42c56f2975c" } }, "metadata": {} @@ -1215,8 +939,9 @@ "output_type": "stream", "name": "stderr", "text": [ - "[18:03:35] Unusual charge on atom 0 number of radical electrons set to zero\n", - "[18:04:00] Unusual charge on atom 0 number of radical electrons set to zero\n" + "/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3190. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n", + "[15:54:38] Unusual charge on atom 0 number of radical electrons set to zero\n", + "[15:55:25] Unusual charge on atom 0 number of radical electrons set to zero\n" ] }, { @@ -1228,7 +953,7 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "f7d4f8508ec847939a7475b44b7b7df4" + "model_id": "7d6c70a42d194a1f956778d4543fe9ff" } }, "metadata": {} @@ -1246,11 +971,16 @@ { "metadata": {}, "source": [ - "### 📊 Testing\n", + "### Evaluate\n", + "\n", + "Once training is complete, we evaluate the model on the test set using `trainer.test()`.\n", "\n", - "Once training is complete, the next step is to test how well the model performs on new, unseen data. \n", + "#### What is included in this step?\n", + "- The best model checkpoint (based on validation AUROC) is automatically loaded.\n", "\n", - "You can run the following code cell to do this:" + "- The model runs on the test data to generate predictions.\n", + "\n", + "- Final classification metrics, including AUROC, F1 score, accuracy, sensitivity, and specificity, are calculated and logged." 
], "cell_type": "markdown" }, @@ -1265,9 +995,9 @@ "output_type": "stream", "name": "stderr", "text": [ - "INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5570.ckpt\n", + "INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5640.ckpt\n", "INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5570.ckpt\n" + "INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1220-val_BinaryAUROC=0.5640.ckpt\n" ] }, { @@ -1279,11 +1009,21 @@ "application/vnd.jupyter.widget-view+json": { "version_major": 2, "version_minor": 0, - "model_id": "caaad7ec5cc4412fb6b05d0be40b9619" + "model_id": "4fb3eaa50982466696a329fb30165511" } }, "metadata": {} }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/kale/pipeline/drugban_trainer.py:377: RuntimeWarning: invalid value encountered in divide\n", + " precision = tpr / (tpr + fpr)\n", + "/usr/local/lib/python3.11/dist-packages/torchmetrics/utilities/prints.py:43: TorchMetricsUserWarning: You are trying to use a metric in deterministic mode on GPU that uses `torch.cumsum`, which is currently not supported. The tensor will be copied to the CPU memory to compute it and then copied back to GPU. 
Expect some slowdowns.\n", + " warnings.warn(*args, **kwargs)\n" + ] + }, { "output_type": "display_data", "data": { @@ -1291,36 +1031,36 @@ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5569921731948853 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5126791596412659 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.21071428060531616 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.12967033684253693 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8982300758361816 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.532524824142456 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5569921135902405 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6724442839622498 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7800763249397278 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.15476277470588684 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.09955751895904541 \u001b[0m\u001b[35m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9626373648643494 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAUROC \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5640328526496887 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryAccuracy \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5082690119743347 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryF1Score \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.15209124982357025 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinaryRecall \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.08791209012269974 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_BinarySpecificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9314159154891968 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_accuracy_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5226019620895386 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_auroc_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m 
\u001b[0m\u001b[35m 0.5640328526496887 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_f1_sklearn \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.6693024635314941 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.8550940155982971 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_optim_threshold \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.08681917190551758 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_sensitivity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.07300885021686554 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_specificity \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9692307710647583 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ], "text/html": [ "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃ Test metric ┃ DataLoader 0 ┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
- "│ test_BinaryAUROC │ 0.5569921731948853 │\n",
- "│ test_BinaryAccuracy │ 0.5126791596412659 │\n",
- "│ test_BinaryF1Score │ 0.21071428060531616 │\n",
- "│ test_BinaryRecall │ 0.12967033684253693 │\n",
- "│ test_BinarySpecificity │ 0.8982300758361816 │\n",
- "│ test_accuracy_sklearn │ 0.532524824142456 │\n",
- "│ test_auroc_sklearn │ 0.5569921135902405 │\n",
- "│ test_f1_sklearn │ 0.6724442839622498 │\n",
- "│ test_loss │ 0.7800763249397278 │\n",
- "│ test_optim_threshold │ 0.15476277470588684 │\n",
- "│ test_sensitivity │ 0.09955751895904541 │\n",
- "│ test_specificity │ 0.9626373648643494 │\n",
+ "│ test_BinaryAUROC │ 0.5640328526496887 │\n",
+ "│ test_BinaryAccuracy │ 0.5082690119743347 │\n",
+ "│ test_BinaryF1Score │ 0.15209124982357025 │\n",
+ "│ test_BinaryRecall │ 0.08791209012269974 │\n",
+ "│ test_BinarySpecificity │ 0.9314159154891968 │\n",
+ "│ test_accuracy_sklearn │ 0.5226019620895386 │\n",
+ "│ test_auroc_sklearn │ 0.5640328526496887 │\n",
+ "│ test_f1_sklearn │ 0.6693024635314941 │\n",
+ "│ test_loss │ 0.8550940155982971 │\n",
+ "│ test_optim_threshold │ 0.08681917190551758 │\n",
+ "│ test_sensitivity │ 0.07300885021686554 │\n",
+ "│ test_specificity │ 0.9692307710647583 │\n",
"└───────────────────────────┴───────────────────────────┘\n",
"\n"
]
@@ -1331,22 +1071,22 @@
"output_type": "execute_result",
"data": {
"text/plain": [
- "[{'test_loss': 0.7800763249397278,\n",
- " 'test_auroc_sklearn': 0.5569921135902405,\n",
- " 'test_accuracy_sklearn': 0.532524824142456,\n",
- " 'test_f1_sklearn': 0.6724442839622498,\n",
- " 'test_sensitivity': 0.09955751895904541,\n",
- " 'test_specificity': 0.9626373648643494,\n",
- " 'test_optim_threshold': 0.15476277470588684,\n",
- " 'test_BinaryAUROC': 0.5569921731948853,\n",
- " 'test_BinaryF1Score': 0.21071428060531616,\n",
- " 'test_BinaryRecall': 0.12967033684253693,\n",
- " 'test_BinarySpecificity': 0.8982300758361816,\n",
- " 'test_BinaryAccuracy': 0.5126791596412659}]"
+ "[{'test_loss': 0.8550940155982971,\n",
+ " 'test_auroc_sklearn': 0.5640328526496887,\n",
+ " 'test_accuracy_sklearn': 0.5226019620895386,\n",
+ " 'test_f1_sklearn': 0.6693024635314941,\n",
+ " 'test_sensitivity': 0.07300885021686554,\n",
+ " 'test_specificity': 0.9692307710647583,\n",
+ " 'test_optim_threshold': 0.08681917190551758,\n",
+ " 'test_BinaryAUROC': 0.5640328526496887,\n",
+ " 'test_BinaryF1Score': 0.15209124982357025,\n",
+ " 'test_BinaryRecall': 0.08791209012269974,\n",
+ " 'test_BinarySpecificity': 0.9314159154891968,\n",
+ " 'test_BinaryAccuracy': 0.5082690119743347}]"
]
},
"metadata": {},
- "execution_count": 15
+ "execution_count": 17
}
],
"execution_count": null
@@ -1354,102 +1094,15 @@
{
"metadata": {},
"source": [
- "### 📊 Understanding the Evaluation Metrics\n",
- "\n",
- "After testing the model, several performance metrics are displayed. \n",
- "These help you understand how well the model is making predictions. \n",
- "Below is a brief explanation of each metric:\n",
- "\n",
- "---\n",
- "\n",
- "#### **AUROC (Area Under the Receiver Operating Characteristic Curve)** \n",
- "Measures the model’s ability to distinguish between positive (interacting) and negative (non-interacting) pairs. \n",
- "A value close to 1.0 means excellent distinction. \n",
- "A value around 0.5 means the model is guessing randomly.\n",
- "\n",
- "---\n",
- "\n",
- "#### **Accuracy** \n",
- "Represents the percentage of correct predictions (both positive and negative) out of all predictions. \n",
- "While easy to understand, accuracy can be misleading if the classes are imbalanced.\n",
- "\n",
- "---\n",
- "\n",
- "#### **F1 Score** \n",
- "A balanced measure that combines **precision** and **recall**. \n",
- "It is especially useful when you care equally about false positives and false negatives.\n",
- "\n",
- "---\n",
- "\n",
- "#### **Recall (also called Sensitivity or True Positive Rate)** \n",
- "Shows the proportion of actual positives (interacting pairs) that the model correctly identified. \n",
- "High recall means the model is good at finding interactions.\n",
- "\n",
- "---\n",
- "\n",
- "#### **Specificity (also called True Negative Rate)** \n",
- "Shows the proportion of actual negatives (non-interacting pairs) that the model correctly identified. \n",
- "High specificity means the model is good at ruling out non-interactions.\n",
- "\n",
- "---\n",
- "\n",
- "#### **Optimised Threshold** \n",
- "During evaluation, the model can choose a threshold value for classification that maximises certain metrics like F1 score. \n",
- "This threshold is what the model uses to decide between \"interaction\" and \"no interaction\".\n",
- "\n",
- "---\n",
- "\n",
- "#### **Loss** \n",
- "This is a number the model tries to minimise during training. \n",
- "Lower loss generally means better performance, but it should always be considered alongside the other metrics.\n",
- "\n",
- "---\n",
- "\n",
- "> These metrics provide different perspectives on the model’s behaviour. Together, they help you judge how well the model performs on your task.\n"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {},
- "source": [
- "### 📊 Compare with Baselines\n",
- "\n",
- "To evaluate the robustness and generalisability of different models, the DrugBAN model was run for **100 epochs** across multiple random seeds and dataset splits. \n",
- "\n",
- "The figure below compares the performance of these models on the **BioSNAP** and **BindingDB** datasets.\n",
- "\n",
- "- The **left plot** shows results based on **AUROC** (Area Under the Receiver Operating Characteristic Curve).\n",
- "- The **right plot** shows results based on **AUPRC** (Area Under the Precision–Recall Curve).\n",
+ "### Compare with Baselines\n",
"\n",
- "---\n",
+ "To assess the robustness and generalisability of DrugBAN, we compare its performance against baseline models. In this example, DrugBAN was trained for 100 epochs across multiple random seeds。\n",
"\n",
- "### 🧪 Experimental Setup\n",
+ "The figure below presents the comparison on the BioSNAP and BindingDB datasets.\n",
"\n",
- "- Each model was trained and evaluated multiple times with different random seeds to capture performance variability.\n",
- "- Each box plot summarises results from these runs.\n",
- "- **DrugBAN** and **DrugBANDA** were trained for 100 epochs per run.\n",
- "- Performance on the **BioSNAP** dataset is shown in **blue**, and **BindingDB** results are shown in **orange**.\n",
+ "- The left plot shows model performance based on AUROC (Area Under the Receiver Operating Characteristic Curve).\n",
"\n",
- "---\n",
- "\n",
- "### 📈 How to Read the Box Plots\n",
- "\n",
- "- The **centre line** of each box represents the **median** performance.\n",
- "- The **green triangle** shows the **mean** performance.\n",
- "- The **lower and upper edges** of the box indicate the **first and third quartiles**.\n",
- "- The **whiskers** show the full range (excluding outliers).\n",
- "\n",
- "---\n",
- "\n",
- "### 🔍 Key Insights\n",
- "\n",
- "- **DrugBANDA** consistently achieves top performance across both metrics and datasets.\n",
- "- On the **BioSNAP** dataset (blue), performance varies more across models, highlighting its challenging nature.\n",
- "- Simpler models such as **SVM** and **Random Forest (RF)** show limited ability to generalise.\n",
- "- Deep learning models such as **GraphDTA** and **MolTrans** show competitive AUROC but less stability in AUPRC.\n",
- "- **Domain adaptation** improves the model's ability to generalise from BindingDB to BioSNAP, as seen in DrugBANDA's superior scores.\n",
- "\n",
- "---"
+ "- The right plot shows performance based on AUPRC (Area Under the Precision–Recall Curve)."
],
"cell_type": "markdown"
},
@@ -1463,50 +1116,35 @@
{
"metadata": {},
"source": [
- "## Summary\n",
- "\n",
- "In this tutorial, you learned how to use the **PyKale** library to build and evaluate a deep learning model for drug–target interaction (DTI) prediction.\n",
+ "## Interpretation Study - Extracting Embeddings from DrugBAN\n",
+ "After training and evaluating the model, we can study how DrugBAN represents drug and protein information at different stages by extracting key embeddings: **drug embedding**, **protein embedding**, and **joint interaction embedding**. This helps us understand how structural and sequential features are captured and how drug-protein interactions are encoded.\n",
"\n",
- "We walked through the pipeline in three key steps:\n",
+ "The DrugBAN model provides these embeddings through its `forward()` function, returning intermediate outputs before and after the Bilinear Attention Network (BAN) layer.\n",
"\n",
- "### 1. Data Overview \n",
- "You explored how to load and prepare drug and protein data using PyKale’s data handling tools.\n",
+ "### How Are the Embeddings Computed?\n",
+ "The `forward()` function sequentially applies three main modules:\n",
"\n",
- "### 2. Model and Trainer Overview \n",
- "You saw how to configure and use the **DrugBAN** model with PyKale. You also learned how PyKale’s trainer simplifies the training process, including logging and model saving.\n",
+ "- The **GCN-based drug extractor** processes the molecular graph, learning structural features and generating the **drug embedding**.\n",
"\n",
- "### 3. Training and Testing Overview \n",
- "You trained the model and evaluated its performance using commonly used metrics such as **AUROC**, **F1 score**, **recall**, and **specificity**.\n",
+ "- The **CNN-based protein extractor** processes the protein sequence, capturing local and global sequence patterns as the **protein embedding**.\n",
"\n",
- "---\n",
+ "- The **BAN layer** fuses the drug and protein embeddings using bilinear attention, creating a **joint embedding** that highlights interaction-specific features.\n",
"\n",
- "This notebook is designed as an accessible entry point for researchers who are new to Python or machine learning, and who want to explore how graph-based deep learning can be applied to biomedical problems.\n",
- "\n",
- "> 🧠 Tip: Try experimenting with the dataset, changing model settings, or applying PyKale to a new task. That’s the best way to learn!\n",
- "\n",
- "For more information, check out the [original DrugBAN codebase](https://github.com/peizhenbai/DrugBAN) and the full paper in *Nature Machine Intelligence*.\n"
+ "You should save embeddings during the evaluation phase (validation or test), not during training. This ensures you are extracting embeddings from a model that is not updating its weights, and you avoid interfering with training performance."
],
"cell_type": "markdown"
},
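+ {
+ "metadata": {},
+ "source": [
+ "Below is a minimal sketch of collecting the three embeddings over the test set. It assumes the trained weights are loaded into `model` (for example, restored from the best checkpoint) and that the test `DataLoader` is named `test_dataloader`; it also assumes, following the original DrugBAN code, that `forward()` returns the drug embedding, protein embedding, joint embedding, and interaction score in that order. Check the PyKale source if your version differs.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "model.eval()  # evaluation mode: dropout off, batch norm statistics frozen\n",
+ "drug_embs, prot_embs, joint_embs = [], [], []\n",
+ "\n",
+ "with torch.no_grad():  # no gradients, so the weights cannot be updated\n",
+ "    for batch_drug, batch_protein, labels in test_dataloader:  # assumed loader name\n",
+ "        # Assumed output order (drug, protein, joint, score), as in the original DrugBAN code\n",
+ "        v_d, v_p, f, score = model(batch_drug, batch_protein)\n",
+ "        drug_embs.append(v_d.cpu())\n",
+ "        prot_embs.append(v_p.cpu())\n",
+ "        joint_embs.append(f.cpu())\n",
+ "\n",
+ "# One row per drug–protein pair, ready for e.g. t-SNE or clustering\n",
+ "joint_embeddings = torch.cat(joint_embs, dim=0)\n",
+ "```"
+ ],
+ "cell_type": "markdown"
+ },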
{
"metadata": {},
"source": [
- "### Explore More: 3 Tasks to Try (~1 Hour Total)\n",
- "\n",
- "These tasks are designed to help you go beyond the tutorial and gain deeper, hands-on experience with model development, interpretation, and dataset handling. You don’t need prior experience in machine learning — just curiosity!"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {},
- "source": [
- "### 🔁 Task 1: Use the BindingDB Dataset (20 minutes)\n",
+ "## Extra Tasks\n",
"\n",
+ "### Task 1: Try the BindingDB Dataset\n",
"Swap the current dataset with **BindingDB**, a real-world dataset containing experimentally measured drug–target interactions.\n",
"\n",
"Steps:\n",
"1. Download or prepare the BindingDB dataset (if needed).\n",
- "2. Update the `data` fields in the YAML config file.\n",
+ "2. Update the relevant field in the YAML config file.\n",
"3. Reload the dataset and re-run training and testing.\n",
"\n",
"**What to explore:**\n",
@@ -1518,70 +1156,6 @@
"\n"
],
"cell_type": "markdown"
- },
- {
- "metadata": {},
- "source": [
- "### 🧪 Task 2: Inspect Misclassified Samples (20 minutes)\n",
- "\n",
- "Dive into the test results and check where the model made incorrect predictions. \n",
- "This helps you understand where the model struggles and whether those mistakes make sense.\n",
- "\n",
- "---\n",
- "\n",
- "#### ✅ Steps\n",
- "\n",
- "1. After testing, collect predicted probabilities and the true labels.\n",
- "2. Print out the predictions that the model got wrong.\n",
- "3. (Optional) Visualise the drug or protein graph for those samples.\n",
- "\n",
- "---\n",
- "\n",
- "#### 🔍 What to Explore\n",
- "\n",
- "- **Are the wrong predictions close to 0.5?** \n",
- " If so, the model was unsure. This can help you identify borderline cases.\n",
- "\n",
- "- **Are there more false positives or false negatives?** \n",
- " This tells you whether the model is more likely to over-predict interactions or miss them.\n",
- "\n",
- "- **Do certain types of drug–protein pairs seem harder to classify?** \n",
- " For example, some proteins or drugs may always appear in misclassified samples. This could point to noisy or hard-to-learn examples.\n",
- "\n",
- "---\n",
- "\n",
- "#### 🧬 (Optional) Visualise the Graph of a Misclassified Sample\n",
- "\n",
- "You can plot the structure of a drug or protein graph that the model got wrong. \n",
- "This helps you interpret what the model was seeing when it made the incorrect prediction.\n"
- ],
- "cell_type": "markdown"
- },
- {
- "metadata": {},
- "source": [
- "### 🧠 Task 3: Change the Protein Encoder (20 minutes)\n",
- "\n",
- "Explore how the model behaves when you change the way proteins are represented. \n",
- "This task helps you understand whether the model relies more on protein structure or other features when making predictions.\n",
- "\n",
- "---\n",
- "\n",
- "💡 Ideas to Try\n",
- "Use one hot embedding instead of integer encoding.\n",
- "\n",
- "Replace the protein graph with a flat vector input (e.g. sequence length, molecular weight, or hydrophobicity).\n",
- "\n",
- "\n",
- "🔍 What to Observe\n",
- "- Does model performance improve, stay the same, or get worse?\n",
- "\n",
- "- Which metric changes the most — AUROC, F1 score, or recall?\n",
- "\n",
- "- Does training take more time or less time with the new encoder?\n",
- "\n"
- ],
- "cell_type": "markdown"
}
]
}