In [2]:
!pip install "numpy<2.0" "transformers==4.30.2" --force-reinstall

In [1]:
# ✅ Now run this AFTER restarting runtime
import numpy as np
print("NumPy version:", np.__version__)  # should be <2.0

NumPy version: 1.26.4


# **Drug–Target Interaction Prediction** (non-DA, in-domain example)

Welcome to this tutorial on drug–target interaction (DTI) prediction using **PyTorch Geometric**. We demonstrate how to predict whether a given drug and protein pair interact, using graph-based deep learning.

This notebook is inspired by the work of [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1), which introduced a **Deep Bilinear Attention Network (BAN)** with **adversarial domain adaptation**. The model is designed to:

- **Capture fine-grained pairwise interactions** between drug molecules and target proteins
- **Generalise to out-of-distribution data**, improving performance on unseen drug–target pairs

---

## 🔍 What You'll Learn

- How to use `kale.loaddata.molecular_datasets.DTIDataset` to encode **drugs** and **protein sequences**.
- How to implement the **BAN** network.
- How to evaluate model performance on benchmark datasets.

---

Let’s get started!

In the following sections, we'll walk through data preprocessing, model implementation, and evaluation, showing how these components come together for DTI prediction.


## Setup

As a starting point, we install the required packages and load the helper code used throughout this tutorial. To keep the output clean and focused on interpretation, we also suppress warnings.

The helper code can be inspected directly in the `.py` files located in the notebook’s current directory. The one used in this notebook is:

- `configs.py`: Defines the base configuration settings, which can be overridden using a custom `.yaml` file.


In [2]:
import os
import warnings

warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

[Optional] If you are using Google Colab, run the following cell to clone the repository containing the demo data and code files.

In [3]:
!git clone --branch drug-target-interaction https://github.com/pykale/embc-mmai25.git
%cd /content/embc-mmai25/tutorials/drug-target-interaction

fatal: destination path 'embc-mmai25' already exists and is not an empty directory.
/content/embc-mmai25/tutorials/drug-target-interaction


## 📦 Packages

The main packages required for this tutorial are **PyKale**, **PyTorch Geometric**, and **RDKit**.

- **PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, designed for applications in biomedical and scientific domains.
- **PyG** (PyTorch Geometric) is a library built on top of PyTorch for building and training Graph Neural Networks (GNNs) on structured data.
- **RDKit** is a cheminformatics toolkit for handling and processing molecular structures, particularly useful for working with SMILES strings and molecular graphs.

📄 Other dependencies are listed in [`embc-mmai25/requirements.txt`](https://github.com/pykale/embc-mmai25/blob/main/requirements.txt).


In [4]:
!pip install --quiet git+https://github.com/pykale/pykale@main \
    && echo "PyKale installed successfully ✅" \
    || echo "Failed to install PyKale ❌"

!pip install --quiet -r /content/embc-mmai25/requirements.txt \
    && echo "Required packages installed successfully ✅" \
    || echo "Failed to install required packages ❌"

# Install the PyG extension wheels that match the installed PyTorch version.
import torch
os.environ['TORCH'] = torch.__version__
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html

!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \
    && echo "PyG installed successfully ✅" \
    || echo "Failed to install PyG ❌"

!pip install rdkit-pypi \
    && echo "RDKit installed successfully ✅" \
    || echo "Failed to install RDKit ❌"


## Configuration

To keep the notebook compact, default parameters are defined in `configs.py`. They can be customized by supplying a `.yaml` configuration file, such as `experiments/non_DA_in_domain.yaml`.
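
The defaults-plus-override pattern can be illustrated without the tutorial's files. The sketch below is a pure-Python stand-in, not the actual `get_cfg_defaults` implementation: `DEFAULTS`, `merge_overrides`, and the default value of `MAX_EPOCH` are hypothetical, and only a few keys are mirrored for brevity.

```python
import copy

# Hypothetical defaults mirroring a few keys of the tutorial's configuration.
DEFAULTS = {
    "SOLVER": {"BATCH_SIZE": 64, "MAX_EPOCH": 100, "SEED": 20},
    "DA": {"USE": False, "METHOD": "CDAN"},
}


def merge_overrides(defaults, overrides):
    """Return a copy of `defaults` with nested `overrides` applied.

    Unknown keys raise KeyError so that typos in an override file surface early.
    """
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if key not in merged:
            raise KeyError(f"Unknown config key: {key}")
        if isinstance(value, dict) and isinstance(merged[key], dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


# Stand-in for the parsed contents of a YAML override file.
overrides = {"SOLVER": {"MAX_EPOCH": 2}}
cfg_sketch = merge_overrides(DEFAULTS, overrides)
print(cfg_sketch["SOLVER"])
```

Only the overridden key changes; every other default is kept, which is why an experiment file can stay short.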

In [5]:
from configs import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("experiments/non_DA_in_domain.yaml")

# Temporarily reduce the number of epochs to shorten training time for this demo
cfg.SOLVER.MAX_EPOCH = 2
print(cfg)


BCN:
  HEADS: 2
COMET:
  API_KEY: <redacted>
  EXPERIMENT_NAME: Non_DA_in_domain
  PROJECT_NAME: drugban-23-May
  TAG: DrugBAN_Vanilla
  USE: True
DA:
  INIT_EPOCH: 10
  LAMB_DA: 1
  METHOD: CDAN
  ORIGINAL_RANDOM: False
  RANDOM_DIM: None
  RANDOM_LAYER: False
  TASK: False
  USE: False
  USE_ENTROPY: True
DATA:
  DATASET: bindingdb
  SPLIT: random
DECODER:
  BINARY: 1
  HIDDEN_DIM: 512
  IN_DIM: 256
  NAME: MLP
  OUT_DIM: 128
DRUG:
  HIDDEN_LAYERS: [128, 128, 128]
  MAX_NODES: 290
  NODE_IN_EMBEDDING: 128
  NODE_IN_FEATS: 7
  PADDING: True
PROTEIN:
  EMBEDDING_DIM: 128
  KERNEL_SIZE: [3, 6, 9]
  NUM_FILTERS: [128, 128, 128]
  PADDING: True
RESULT:
  SAVE_MODEL: True
SOLVER:
  BATCH_SIZE: 64
  DA_LEARNING_RATE: 0.001
  LEARNING_RATE: 5e-05
  MAX_EPOCH: 2
  NUM_WORKERS: 0
  SEED: 20


## Data Loading

We use the DTI benchmark dataset BindingDB, provided by the authors of the DrugBAN paper in their [repository](https://github.com/peizhenbai/DrugBAN/tree/main).

The `bindingdb` dataset is structured as follows:

```sh
    ├───bindingdb
    │   ├───cluster
    │   │   ├───source_train.csv
    │   │   ├───target_train.csv
    │   │   └───target_test.csv
    │   ├───random
    │   │   ├───test.csv
    │   │   ├───train.csv
    │   │   └───val.csv
    │   └───full.csv
```

Each CSV file contains the following columns:

- **SMILES**: Drug molecule represented in SMILES (Simplified Molecular Input Line Entry System) format  
- **Protein Sequence**: Protein represented as an amino acid sequence  
- **Y**: Binary interaction label (`1` = interaction, `0` = no interaction)


A few example rows from the BindingDB dataset are shown below, with SMILES strings and protein sequences truncated.

**Table 1**: Example rows from the BindingDB DTI dataset.

| SMILES             | Protein Sequence         | Y |
|--------------------|--------------------------|---|
| Fc1ccc(C2(COC…)    | MDNVLPVDSDLS…            | 1 |
| O=c1oc2c(O)c(…)    | MMYSKLLTLTTL…            | 0 |
| CC(C)Oc1cc(N…)     | MGMACLTMTEME…            | 1 |
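
As a quick illustration of this layout, the sketch below parses a tiny in-memory CSV and summarises it. The rows are made-up toy values, not BindingDB entries, and the column names are taken from the list above, so they may differ from the headers in the actual files.

```python
import csv
import io
from collections import Counter

# Toy rows with made-up values; column names follow the description above.
toy_csv = io.StringIO(
    "SMILES,Protein Sequence,Y\n"
    "CCO,MKTAYIAK,1\n"
    "c1ccccc1,MGSSHHHH,0\n"
    "CC(=O)O,MKTAYIAK,1\n"
)

rows = list(csv.DictReader(toy_csv))
label_counts = Counter(int(row["Y"]) for row in rows)
unique_drugs = {row["SMILES"] for row in rows}
unique_proteins = {row["Protein Sequence"] for row in rows}

print(f"{len(rows)} pairs, {len(unique_drugs)} drugs, {len(unique_proteins)} proteins")
print("Label counts:", dict(label_counts))
```

Checking the label balance and the number of distinct drugs and proteins in this way is a useful first step before training on a real split.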



### Preprocessing

- **Drugs** are converted from SMILES strings to molecular graphs using **RDKit** and **PyTorch Geometric**.  
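- **Proteins** are encoded as integer sequences: each amino acid is mapped to an integer index and the sequence is zero-padded to a fixed length; the model's embedding layer later turns these indices into vectors.  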
- **Labels** are binary (`0` or `1`).

The `DTIDataset` class handles this preprocessing pipeline.
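
The protein encoding step can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the `DTIDataset` implementation: the alphabet ordering in `AMINO_ACIDS` and the helper `encode_protein` are hypothetical, and the library's actual index mapping and maximum length may differ.

```python
# Hypothetical alphabet: the 20 standard amino acids, indexed from 1.
# Index 0 is reserved for padding. The mapping inside DTIDataset may differ.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
CHAR_TO_INDEX = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}


def encode_protein(sequence, max_length):
    """Map each residue to an integer, then truncate or zero-pad to max_length.

    Residues outside the alphabet are mapped to 0, the same value as padding.
    """
    encoded = [CHAR_TO_INDEX.get(residue, 0) for residue in sequence[:max_length]]
    return encoded + [0] * (max_length - len(encoded))


print(encode_protein("MKV", max_length=6))  # [11, 9, 18, 0, 0, 0]
```

Padding every sequence to the same length lets proteins of different sizes be stacked into one batch tensor.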


In [6]:
from kale.loaddata.molecular_datasets import DTIDataset
import pandas as pd

# Folder holding the chosen dataset and split, e.g. ./datasets/bindingdb/random
data_folder = os.path.join(f"./datasets/{cfg.DATA.DATASET}", str(cfg.DATA.SPLIT))

df_train = pd.read_csv(os.path.join(data_folder, "train.csv"))
df_val = pd.read_csv(os.path.join(data_folder, "val.csv"))
df_test = pd.read_csv(os.path.join(data_folder, "test.csv"))

# DTIDataset takes the row indices to use and the DataFrame they refer to.
train_dataset = DTIDataset(df_train.index.values, df_train)
valid_dataset = DTIDataset(df_val.index.values, df_val)
test_dataset = DTIDataset(df_test.index.values, df_test)

### Dataset Inspection

After loading the dataset, we can quickly inspect its structure and contents using the following code:


In [7]:
print(f"Train samples: {len(train_dataset)}, Validation samples: {len(valid_dataset)}, Test samples: {len(test_dataset)}")
print("Example sample:\n", train_dataset[0])

Train samples: 34439, Validation samples: 4920, Test samples: 9840
Example sample:
 (Data(x=[290, 7], edge_index=[2, 85], edge_attr=[85, 1], num_nodes=290), array([11.,  5., 14., ...,  0.,  0.,  0.]), 1)
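
The `edge_index=[2, 85]` field in the sample above is PyTorch Geometric's coordinate-format edge list: row 0 holds source node indices and row 1 holds target node indices. The sketch below builds such a list from a bond list using plain Python lists. The helper `bonds_to_edge_index` is hypothetical and does not reproduce the node features, padding, or any other processing that `DTIDataset` applies.

```python
def bonds_to_edge_index(bonds):
    """Convert undirected bonds [(i, j), ...] into a 2 x num_edges edge list.

    Each bond is stored in both directions, the usual way to represent an
    undirected graph in coordinate format.
    """
    sources, targets = [], []
    for i, j in bonds:
        sources += [i, j]
        targets += [j, i]
    return [sources, targets]


# Ethanol (SMILES "CCO"): atoms 0 = C, 1 = C, 2 = O, with bonds C-C and C-O.
edge_index = bonds_to_edge_index([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```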


### Batching

We use PyTorch’s `DataLoader` to efficiently load molecular graph data in batches. A custom `graph_collate_func` is used to correctly batch variable-sized graph structures. Separate data loaders are created for training (with shuffling) and for validation/test (without shuffling).


In [8]:
from torch.utils.data import DataLoader
from kale.loaddata.molecular_datasets import graph_collate_func

# Shared DataLoader settings; shuffling and drop_last are enabled for training only.
params = {
    "batch_size": cfg.SOLVER.BATCH_SIZE,
    "shuffle": True,
    "num_workers": cfg.SOLVER.NUM_WORKERS,
    "drop_last": True,
    "collate_fn": graph_collate_func,
}

training_generator = DataLoader(train_dataset, **params)

# Validation and test loaders keep every sample in a fixed order.
params.update({"shuffle": False, "drop_last": False})
valid_generator = DataLoader(valid_dataset, **params)
test_generator = DataLoader(test_dataset, **params)
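
To see what `drop_last` means in practice, the sketch below computes the number of batches per epoch from the sample counts printed during dataset inspection and the configured batch size of 64. The helper `num_batches` is hypothetical, and the step count assumes one optimizer step per training batch.

```python
import math


def num_batches(num_samples, batch_size, drop_last):
    """Number of batches a data loader yields per epoch."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


batch_size = 64  # cfg.SOLVER.BATCH_SIZE
train_batches = num_batches(34439, batch_size, drop_last=True)
valid_batches = num_batches(4920, batch_size, drop_last=False)
test_batches = num_batches(9840, batch_size, drop_last=False)

print(train_batches, valid_batches, test_batches)  # 538 77 154
print("Training steps over 2 epochs:", 2 * train_batches)  # 1076
```

With `drop_last=True`, the 7 training samples left over after 538 full batches are skipped in each epoch. The total of 1076 steps agrees with the `step=1076` in the checkpoint name logged during testing.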

## Setup Model

The **DrugBAN** model consists of the following components:

- A **GCN** for drug molecular graphs  
- A **CNN** for protein sequences  
- A **Bilinear Attention Network (BAN)** for feature fusion  
- An **MLP** for classification

The model is built from the configuration object `cfg` loaded above.
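
As a small worked example for the protein CNN, the sketch below tracks how the sequence length shrinks through three stacked 1D convolutions with the `PROTEIN.KERNEL_SIZE` values from the configuration (`[3, 6, 9]`). It assumes stride 1, no padding, and no dilation, and the input length of 1200 is a hypothetical value chosen for illustration.

```python
def conv1d_output_length(length, kernel_size, stride=1, padding=0):
    """Output length of a 1D convolution without dilation."""
    return (length + 2 * padding - kernel_size) // stride + 1


sequence_length = 1200  # hypothetical padded protein length
for kernel_size in [3, 6, 9]:  # cfg.PROTEIN.KERNEL_SIZE
    sequence_length = conv1d_output_length(sequence_length, kernel_size)
    print(f"kernel_size={kernel_size}: length {sequence_length}")
```

Under these assumptions each layer removes `kernel_size - 1` positions, so the three layers together shorten the sequence by 15.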


In [9]:
from kale.embed.ban import DrugBAN

model = DrugBAN(**cfg)
print(model)

DrugBAN(
  (drug_extractor): MolecularGCN(
    (init_transform): Linear(in_features=7, out_features=128, bias=False)
    (gcn_layers): ModuleList(
      (0-2): 3 x GCNConv(128, 128)
    )
  )
  (protein_extractor): ProteinCNN(
    (embedding): Embedding(26, 128, padding_idx=0)
    (conv1): Conv1d(128, 128, kernel_size=(3,), stride=(1,))
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(128, 128, kernel_size=(6,), stride=(1,))
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv3): Conv1d(128, 128, kernel_size=(9,), stride=(1,))
    (bn3): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (bcn): BANLayer(
    (v_net): FCNet(
      (main): Sequential(
        (0): Dropout(p=0.2, inplace=False)
        (1): Linear(in_features=128, out_features=768, bias=True)
        (2): ReLU()
      )
    )
    (q_net): FCNet(
      (main): Sequential(
     

## Setup Trainer

We use a PyTorch Lightning trainer for structured training and evaluation.

In [10]:
from kale.pipeline.drugban_trainer import DrugbanTrainer

drugban_trainer = DrugbanTrainer(
    model=DrugBAN(**cfg),
    solver_lr=cfg["SOLVER"]["LEARNING_RATE"],
    num_classes=cfg["DECODER"]["BINARY"],
    batch_size=cfg["SOLVER"]["BATCH_SIZE"],
    # --- domain adaptation parameters ---
    is_da=cfg["DA"]["USE"],
    solver_da_lr=cfg["SOLVER"]["DA_LEARNING_RATE"],
    da_init_epoch=cfg["DA"]["INIT_EPOCH"],
    da_method=cfg["DA"]["METHOD"],
    original_random=cfg["DA"]["ORIGINAL_RANDOM"],
    use_da_entropy=cfg["DA"]["USE_ENTROPY"],
    da_random_layer=cfg["DA"]["RANDOM_LAYER"],
    # --- discriminator parameters ---
    da_random_dim=cfg["DA"]["RANDOM_DIM"],
    decoder_in_dim=cfg["DECODER"]["IN_DIM"],
)

In [11]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filename="{epoch}-{step}-{val_BinaryAUROC:.4f}",
    monitor="val_BinaryAUROC",
    mode="max",
)

trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    devices="auto",
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    max_epochs=cfg["SOLVER"]["MAX_EPOCH"],
    deterministic=True,  # for reproducibility
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


## Training

In [12]:
trainer.fit(drugban_trainer, train_dataloaders=training_generator, val_dataloaders=valid_generator)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | DrugBAN          | 1.0 M  | train
1 | valid_metrics | MetricCollection | 0      | train
2 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
1.0 M     Trainable params
0         Non-trainable params
1.0 M     Total params
4.049     Total estimated model params size (MB)
54        Modules in train mode
0         Modules in eval mode


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


## Testing

### Results and Interpretation

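After training, `trainer.test` reports evaluation metrics such as **AUROC**, **F1 score**, **recall**, and **specificity** on the test set, using the checkpoint with the best validation AUROC (`ckpt_path="best"`).  
Attention maps from the BAN layer can also be visualised for interpretation, although this notebook does not cover that step.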
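
For intuition about AUROC, it can be read as the probability that a randomly chosen interacting pair receives a higher score than a randomly chosen non-interacting pair. The sketch below computes it directly from that definition on toy values. The function `pairwise_auroc` is a hypothetical illustration, not the metric implementation used by the trainer.

```python
def pairwise_auroc(scores, labels):
    """AUROC as the fraction of (positive, negative) pairs ranked correctly.

    Ties count as half. This O(n^2) version is for illustration only.
    """
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUROC needs at least one positive and one negative")
    correct = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                correct += 1.0
            elif p == n:
                correct += 0.5
    return correct / (len(positives) * len(negatives))


toy_scores = [0.9, 0.8, 0.3, 0.1]
toy_labels = [1, 0, 1, 0]
print(pairwise_auroc(toy_scores, toy_labels))  # 0.75
```

Three of the four positive-negative pairs are ordered correctly here, giving 0.75. Because AUROC depends only on the ranking of scores, it is unaffected by the decision threshold.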

In [13]:
trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path="best")

INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1076-val_BinaryAUROC=0.8926.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.rank_zero:Loaded model weights from the checkpoint at /content/embc-mmai25/tutorials/drug-target-interaction/lightning_logs/version_0/checkpoints/epoch=1-step=1076-val_BinaryAUROC=0.8926.ckpt


[{'test_loss': 0.41549041867256165,
  'test_auroc_sklearn': 0.8900876045227051,
  'test_accuracy_sklearn': 0.7900406718254089,
  'test_f1_sklearn': 0.8168386816978455,
  'test_sensitivity': 0.7275407314300537,
  'test_specificity': 0.8784977793693542,
  'test_optim_threshold': 0.31936752796173096,
  'test_BinaryAUROC': 0.8900876045227051,
  'test_BinaryF1Score': 0.7566662430763245,
  'test_BinaryRecall': 0.7209131121635437,
  'test_BinarySpecificity': 0.8695802688598633,
  'test_BinaryAccuracy': 0.8080284595489502}]
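
Two F1 values appear above (`test_f1_sklearn` and `test_BinaryF1Score`). One plausible reason, suggested by the reported `test_optim_threshold`, is that the two groups of metrics binarise the predicted probabilities at different thresholds; this is an assumption about the trainer, not something the log confirms. The sketch below uses toy values and a hypothetical helper, `metrics_at_threshold`, to show how the threshold alone changes sensitivity, specificity, and F1.

```python
def metrics_at_threshold(probabilities, labels, threshold):
    """Sensitivity, specificity and F1 after binarising at `threshold`."""
    predictions = [int(p >= threshold) for p in probabilities]
    pairs = list(zip(predictions, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)
    return {
        "sensitivity": tp / (tp + fn) if tp + fn else 0.0,
        "specificity": tn / (tn + fp) if tn + fp else 0.0,
        "f1": 2 * tp / (2 * tp + fp + fn) if tp else 0.0,
    }


toy_probabilities = [0.95, 0.70, 0.40, 0.35, 0.20, 0.05]
toy_labels = [1, 1, 1, 0, 0, 0]
for threshold in (0.5, 0.3):
    print(threshold, metrics_at_threshold(toy_probabilities, toy_labels, threshold))
```

Lowering the threshold from 0.5 to 0.3 raises sensitivity and lowers specificity on this toy data, while AUROC would stay the same.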

## Summary

- We loaded and preprocessed the **BindingDB** dataset for DTI prediction.  
- We built a **DrugBAN** model using **GCN**, **CNN**, and **BAN** components.  
- We trained and evaluated the model using **PyTorch Lightning**.  
- The configuration file allows for easy reproduction and modification of experiments.  

For more details, see the [original codebase](https://github.com/peizhenbai/DrugBAN) and the accompanying paper in *Nature Machine Intelligence*.

