# **Drug–Target Interaction Prediction** (non-DA, in-domain example)

Welcome to this tutorial on drug–target interaction (DTI) prediction using **PyTorch Geometric**. We demonstrate how to predict whether a given drug and protein pair interact, using graph-based deep learning.

This notebook is inspired by the work of [**Bai et al. (_Nature Machine Intelligence_, 2023)**](https://www.nature.com/articles/s42256-022-00605-1), which introduced a **Deep Bilinear Attention Network (BAN)** with **adversarial domain adaptation**. The model is designed to:

- **Capture fine-grained pairwise interactions** between drug molecules and target proteins
- **Generalise to out-of-distribution data**, improving performance on unseen drug–target pairs

---

## 🔍 What You'll Learn

- How to use `kale.loaddata.molecular_datasets.DTIDataset` to encode **drugs** and **protein sequences**
- How to implement the **BAN** network
- How to evaluate model performance on benchmark datasets

---

Let’s get started!

In the following sections, we'll walk through data preprocessing, model implementation, and evaluation, showing how these components come together for DTI prediction.


## Setup

As a starting point, we will install the required packages and load a set of helper functions to assist throughout this tutorial. To keep the output clean and focused on interpretation, we will also suppress warnings.

We also provide helper code that can be inspected directly in the `.py` files located in the notebook’s current directory. The main helper script is:

- `configs.py`: Defines the base configuration settings, which can be overridden using a custom `.yaml` file.


In [1]:
import os
import warnings

warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

[Optional] If you are using Google Colab, use the following code to download the necessary demo data and code files.

In [2]:
!git clone --branch drug-target-interaction https://github.com/pykale/embc-mmai25.git
%cd /content/embc-mmai25/tutorials/drug-target-interaction

Cloning into 'embc-mmai25'...
remote: Enumerating objects: 697, done.
remote: Counting objects: 100% (99/99), done.
remote: Compressing objects: 100% (74/74), done.
remote: Total 697 (delta 44), reused 62 (delta 22), pack-reused 598 (from 1)
Receiving objects: 100% (697/697), 125.59 MiB | 10.98 MiB/s, done.
Resolving deltas: 100% (289/289), done.
/content/embc-mmai25/tutorials/drug-target-interaction


## 📦 Packages

The main packages required for this tutorial are **PyKale**, **PyTorch Geometric**, and **RDKit**.

- **PyKale** is an open-source interdisciplinary machine learning library developed at the University of Sheffield, designed for applications in biomedical and scientific domains.
- **PyG** (PyTorch Geometric) is a library built on top of PyTorch for building and training Graph Neural Networks (GNNs) on structured data.
- **RDKit** is a cheminformatics toolkit for handling and processing molecular structures, particularly useful for working with SMILES strings and molecular graphs.

📄 Other dependencies are listed in [`embc-mmai25/requirements.txt`](https://github.com/pykale/embc-mmai25/blob/main/requirements.txt).


In [None]:
!pip install --quiet git+https://github.com/pykale/pykale@main \
    && echo "PyKale installed successfully ✅" \
    || echo "Failed to install PyKale ❌"

!pip install --quiet -r /content/embc-mmai25/requirements.txt \
    && echo "Required packages installed successfully ✅" \
    || echo "Failed to install required packages ❌"

!pip install --upgrade --force-reinstall numpy

import torch
os.environ['TORCH'] = torch.__version__
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html

!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git \
    && echo "PyG installed successfully ✅" \
    || echo "Failed to install PyG ❌"

!pip install --quiet rdkit \
    && echo "RDKit installed successfully ✅" \
    || echo "Failed to install RDKit ❌"

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done

In [None]:
import pandas as pd

In [None]:
# Standard imports
import os
import torch

import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from yacs.config import CfgNode

# PyKale and custom modules (make sure your PYTHONPATH is set correctly)
from kale.loaddata.molecular_datasets import DTIDataset, graph_collate_func
from kale.embed.ban import DrugBAN
from kale.pipeline.drugban_trainer import DrugbanTrainer

## Configuration

To keep the notebook short, default parameters are defined in `configs.py` (imported below) and can be overridden with a `.yaml` configuration file; this tutorial uses `experiments/non_DA_in_domain.yaml`.

In [None]:
from configs import get_cfg_defaults

cfg = get_cfg_defaults()
cfg.merge_from_file("experiments/non_DA_in_domain.yaml")

# Temporarily reduce the number of epochs to shorten training time for this demo
cfg.SOLVER.MAX_EPOCH = 2
print(cfg)


## Data Loading

We use the DTI benchmark dataset BindingDB, provided by the authors of the DrugBAN paper in their [repository](https://github.com/peizhenbai/DrugBAN/tree/main).

The `bindingdb` dataset is structured as follows:

```sh
    ├───bindingdb
    │   ├───cluster
    │   │   ├───source_train.csv
    │   │   ├───target_train.csv
    │   │   ├───target_test.csv
    │   ├───random
    │   │   ├───test.csv
    │   │   ├───train.csv
    │   │   ├───val.csv
    │   ├───full.csv

```

Each CSV file contains the following columns:

- **SMILES**: Drug molecule represented in SMILES (Simplified Molecular Input Line Entry System) format  
- **Protein**: Target protein represented as an amino acid sequence  
- **Y**: Binary interaction label (`1` = interaction, `0` = no interaction)
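Before training, it can help to sanity-check each split against this schema and look at its label balance. The sketch below assumes the column names `SMILES`, `Protein` and `Y` (the header used by the DrugBAN CSVs); adjust `REQUIRED_COLUMNS` if your files differ. The function name `check_dti_frame` and the toy rows are hypothetical and only for illustration.

```python
import pandas as pd

# Assumption: column names follow the DrugBAN CSV header (SMILES, Protein, Y).
REQUIRED_COLUMNS = ["SMILES", "Protein", "Y"]


def check_dti_frame(df: pd.DataFrame) -> float:
    """Validate a DTI split and return the fraction of positive (interacting) pairs."""
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    if df[REQUIRED_COLUMNS].isnull().any().any():
        raise ValueError("Found empty SMILES, protein or label entries")
    labels = set(df["Y"].unique().tolist())
    if not labels <= {0, 1}:
        raise ValueError(f"Labels must be 0 or 1, got {sorted(labels)}")
    return float(df["Y"].mean())


# Toy rows for illustration only (not real BindingDB records)
toy = pd.DataFrame(
    {
        "SMILES": ["CCO", "c1ccccc1", "CC(=O)O"],
        "Protein": ["MKTAYIAK", "MDNVLPVD", "MGMACLTM"],
        "Y": [1, 0, 1],
    }
)
positive_fraction = check_dti_frame(toy)
print(f"Positive fraction: {positive_fraction:.3f}")
```

Once the real splits are loaded, calling the same function on each DataFrame shows how balanced the labels are, which matters when reading threshold-based metrics such as F1.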


An example of the BindingDB data format is shown below (long strings are truncated with `…`).

**Table 1**: Example rows from the BindingDB DTI dataset.

| SMILES             | Protein                  | Y |
|--------------------|--------------------------|---|
| Fc1ccc(C2(COC…)    | MDNVLPVDSDLS…            | 1 |
| O=c1oc2c(O)c(…)    | MMYSKLLTLTTL…            | 0 |
| CC(C)Oc1cc(N…)     | MGMACLTMTEME…            | 1 |



### Preprocessing

- **Drugs** are converted from SMILES strings to molecular graphs using **RDKit** and **PyTorch Geometric**.  
- **Proteins** are encoded as integer sequences (one integer per amino acid, padded to a fixed length), which the model then maps to learned embeddings.  
- **Labels** are binary (`0` or `1`).
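To make the drug step concrete, the sketch below turns a SMILES string into a graph with RDKit: one node per heavy atom and two directed edges per chemical bond. It keeps only the atomic number as a node feature for readability; the featurisation actually used by `DTIDataset` is richer and is not reproduced here. The function name `smiles_to_graph` is hypothetical.

```python
from typing import List, Tuple

from rdkit import Chem


def smiles_to_graph(smiles: str) -> Tuple[List[int], List[Tuple[int, int]]]:
    """Return per-atom atomic numbers and a directed edge list (both directions per bond)."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # RDKit returns None for SMILES it cannot parse
        raise ValueError(f"Invalid SMILES: {smiles!r}")
    node_features = [atom.GetAtomicNum() for atom in mol.GetAtoms()]
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.extend([(i, j), (j, i)])  # GNN message passing usually expects both directions
    return node_features, edges


atoms, edges = smiles_to_graph("CCO")  # ethanol
print(atoms)  # [6, 6, 8]
print(edges)  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```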

The `DTIDataset` class handles this preprocessing pipeline.
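The protein step can be sketched as a lookup table plus truncation and zero-padding. The alphabet (the 20 standard amino acids) and the default `max_length=1200` below are assumptions for illustration; `DTIDataset` may use a different character set and length. In this simplified version unknown residues share index `0` with padding.

```python
from typing import List

# Assumption: the 20 standard amino acids; index 0 is reserved for padding/unknown residues.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INT = {aa: idx + 1 for idx, aa in enumerate(AMINO_ACIDS)}


def encode_protein(sequence: str, max_length: int = 1200) -> List[int]:
    """Map residues to integers, truncate to max_length and right-pad with zeros."""
    codes = [AA_TO_INT.get(aa, 0) for aa in sequence.upper()[:max_length]]
    return codes + [0] * (max_length - len(codes))


print(encode_protein("MKTAY", max_length=8))  # [11, 9, 17, 1, 20, 0, 0, 0]
```

A fixed length lets protein tensors be stacked into a regular batch, and an embedding layer in the model can then learn a vector for each integer code.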


In [None]:
from kale.loaddata.molecular_datasets import DTIDataset

dataFolder = os.path.join(f"./datasets/{cfg.DATA.DATASET}", str(cfg.DATA.SPLIT))

df_train = pd.read_csv(os.path.join(dataFolder, "train.csv"))
df_val = pd.read_csv(os.path.join(dataFolder, "val.csv"))
df_test = pd.read_csv(os.path.join(dataFolder, "test.csv"))

train_dataset = DTIDataset(df_train.index.values, df_train)
valid_dataset = DTIDataset(df_val.index.values, df_val)
test_dataset = DTIDataset(df_test.index.values, df_test)

### Dataset Inspection

After loading the dataset, we can quickly inspect its structure and contents using the following code:


In [None]:
print(f"Train samples: {len(train_dataset)}, Validation samples: {len(valid_dataset)}, Test samples: {len(test_dataset)}")
print("Example sample:\n", train_dataset[0])

### Batching

We use PyTorch’s `DataLoader` to efficiently load molecular graph data in batches. A custom `graph_collate_func` is used to correctly batch variable-sized graph structures. Separate data loaders are created for training (with shuffling) and for validation/test (without shuffling).
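Molecular graphs have different numbers of atoms, so they cannot simply be stacked. A common strategy, used for example by PyG's `Batch.from_data_list`, is to merge the graphs of a batch into one large disconnected graph, shifting each graph's edge indices by the number of nodes that precede it and recording which graph every node came from. Whether `graph_collate_func` calls PyG directly is not shown in this notebook; the pure-Python sketch below only illustrates the idea, and `collate_graphs` is a hypothetical name.

```python
from typing import List, Tuple

# A graph here is (node_features, directed_edges); a toy stand-in for a PyG Data object.
Graph = Tuple[List[List[float]], List[Tuple[int, int]]]


def collate_graphs(graphs: List[Graph]):
    """Merge graphs into one disconnected graph, shifting edge indices by a running node offset."""
    x: List[List[float]] = []
    edges: List[Tuple[int, int]] = []
    batch: List[int] = []  # batch[i] = index of the graph that node i came from
    offset = 0
    for graph_idx, (node_feats, edge_list) in enumerate(graphs):
        x.extend(node_feats)
        edges.extend((src + offset, dst + offset) for src, dst in edge_list)
        batch.extend([graph_idx] * len(node_feats))
        offset += len(node_feats)
    return x, edges, batch


g1: Graph = ([[6.0], [6.0], [8.0]], [(0, 1), (1, 0), (1, 2), (2, 1)])  # 3 atoms
g2: Graph = ([[7.0], [6.0]], [(0, 1), (1, 0)])  # 2 atoms
x, edges, batch = collate_graphs([g1, g2])
print(edges)  # [(0, 1), (1, 0), (1, 2), (2, 1), (3, 4), (4, 3)]
print(batch)  # [0, 0, 0, 1, 1]
```

The `batch` vector is what lets a GNN pool node embeddings back into one representation per molecule.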


In [None]:
from torch.utils.data import DataLoader
from kale.loaddata.molecular_datasets import graph_collate_func

params = {
    "batch_size": cfg.SOLVER.BATCH_SIZE,
    "shuffle": True,
    "num_workers": cfg.SOLVER.NUM_WORKERS,
    "drop_last": True,
    "collate_fn": graph_collate_func,
}


training_generator = DataLoader(train_dataset, **params)
params.update({"shuffle": False, "drop_last": False})
valid_generator = DataLoader(valid_dataset, **params)
test_generator = DataLoader(test_dataset, **params)
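The only difference between the training and evaluation loaders is `shuffle` and `drop_last`. The small sketch below shows how `drop_last` changes the number of batches per epoch; the sample counts are hypothetical, not the actual BindingDB split sizes.

```python
import math


def batches_per_epoch(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches a DataLoader yields per epoch (map-style dataset, default sampler)."""
    if drop_last:
        return num_samples // batch_size  # the final incomplete batch is discarded
    return math.ceil(num_samples / batch_size)  # the final batch may be smaller


# Hypothetical sizes, not the actual BindingDB split sizes
print(batches_per_epoch(1000, 64, drop_last=True))  # 15 (the last 40 samples are skipped this epoch)
print(batches_per_epoch(1000, 64, drop_last=False))  # 16
```

Because the training loader shuffles, a different remainder is dropped each epoch, so no sample is systematically ignored; validation and test keep every sample so that metrics cover the whole split.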

## Setup Model

The **DrugBAN** model consists of the following components:

- A **GCN** for drug molecular graphs  
- A **CNN** for protein sequences  
- A **Bilinear Attention Network (BAN)** for feature fusion  
- An **MLP** for classification

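Model hyperparameters come from the merged configuration (the `configs.py` defaults plus the YAML overrides) and are passed to `DrugBAN` as keyword arguments.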
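The key idea of bilinear attention is to score every (atom, residue) pair with a low-rank bilinear form and then use that interaction map to pool both sets of features into one joint vector. The NumPy sketch below is a simplified single-head illustration under stated assumptions: weights are random stand-ins for learned parameters, there is no dropout or normalisation, and channels are sum-pooled in groups of `pool`. PyKale's `DrugBAN` differs in its details (multiple heads, learned weights, batched tensors), so treat this as a conceptual aid only.

```python
import numpy as np


def bilinear_attention(drug_feats: np.ndarray, prot_feats: np.ndarray,
                       k: int = 6, pool: int = 3, seed: int = 0):
    """Simplified single-head bilinear attention with sum pooling (illustrative only).

    drug_feats: (N, d_drug) atom embeddings, e.g. from a GCN
    prot_feats: (M, d_prot) residue embeddings, e.g. from a CNN
    """
    if k % pool != 0:
        raise ValueError("k must be divisible by pool")
    rng = np.random.default_rng(seed)
    # Random matrices stand in for weights that would be learned during training
    w_drug = rng.standard_normal((drug_feats.shape[1], k)) / np.sqrt(drug_feats.shape[1])
    w_prot = rng.standard_normal((prot_feats.shape[1], k)) / np.sqrt(prot_feats.shape[1])
    h = rng.standard_normal(k)
    bias = 0.0

    u = np.maximum(drug_feats @ w_drug, 0.0)  # (N, k) projected atoms (ReLU)
    v = np.maximum(prot_feats @ w_prot, 0.0)  # (M, k) projected residues (ReLU)

    # Pairwise interaction map: A[i, j] = h^T (u_i * v_j) + bias, shape (N, M)
    att = (u * h) @ v.T + bias
    # Joint representation: f[c] = sum_ij u[i, c] * A[i, j] * v[j, c], shape (k,)
    joint = np.einsum("ic,ij,jc->c", u, att, v)
    # Sum-pool consecutive groups of `pool` channels, shape (k // pool,)
    pooled = joint.reshape(-1, pool).sum(axis=1)
    return att, pooled


rng = np.random.default_rng(42)
drug = rng.standard_normal((5, 8))  # 5 atoms, 8-dim features
prot = rng.standard_normal((7, 16))  # 7 residues, 16-dim features
att_map, fused = bilinear_attention(drug, prot)
print(att_map.shape, fused.shape)  # (5, 7) (2,)
```

The `(N, M)` map is what one would visualise to see which atom–residue pairs drive a prediction, and the pooled vector is what an MLP classifier would consume.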


In [None]:
from kale.embed.ban import DrugBAN

model = DrugBAN(**cfg)
print(model)

## Setup Trainer
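`DrugbanTrainer` wraps the model in a PyTorch Lightning module that defines the training, validation and test steps, while `pl.Trainer` runs the loops. The domain adaptation (DA) arguments are passed for completeness; in this non-DA example, DA is expected to be disabled by the YAML configuration (`cfg.DA.USE`).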

In [None]:
from kale.pipeline.drugban_trainer import DrugbanTrainer

drugban_trainer = DrugbanTrainer(
    model=model,  # reuse the DrugBAN instance created and printed above
    solver_lr=cfg["SOLVER"]["LEARNING_RATE"],
    num_classes=cfg["DECODER"]["BINARY"],
    batch_size=cfg["SOLVER"]["BATCH_SIZE"],
    # --- domain adaptation parameters ---
    is_da=cfg["DA"]["USE"],
    solver_da_lr=cfg["SOLVER"]["DA_LEARNING_RATE"],
    da_init_epoch=cfg["DA"]["INIT_EPOCH"],
    da_method=cfg["DA"]["METHOD"],
    original_random=cfg["DA"]["ORIGINAL_RANDOM"],
    use_da_entropy=cfg["DA"]["USE_ENTROPY"],
    da_random_layer=cfg["DA"]["RANDOM_LAYER"],
    # --- discriminator parameters ---
    da_random_dim=cfg["DA"]["RANDOM_DIM"],
    decoder_in_dim=cfg["DECODER"]["IN_DIM"],
)

In [None]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filename="{epoch}-{step}-{val_BinaryAUROC:.4f}",
    monitor="val_BinaryAUROC",
    mode="max",
)

trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    devices="auto",
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    max_epochs=cfg["SOLVER"]["MAX_EPOCH"],
    deterministic=True,  # for reproducibility
)

## Training

In [None]:
trainer.fit(drugban_trainer, train_dataloaders=training_generator, val_dataloaders=valid_generator)

## Testing

### Results and Interpretation

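Passing `ckpt_path="best"` to `trainer.test` reloads the checkpoint with the highest validation AUROC (the metric monitored by `ModelCheckpoint`) before evaluating on the test set. You can then inspect evaluation metrics such as **AUROC**, **F1 score**, **Recall**, and others.  
For interpretation, you can also visualise the BAN **attention maps** between drug atoms and protein residues, or **feature importances**, as needed.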
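To clarify what these metrics measure, the pure-Python sketch below computes AUROC as the probability that a randomly chosen positive pair is scored above a randomly chosen negative one (ties count half), and precision, recall and F1 after thresholding scores at 0.5. The threshold and the toy predictions are assumptions for illustration; the metrics logged by `DrugbanTrainer` may differ in implementation details.

```python
from typing import Dict, Sequence


def auroc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(y_score, y_true) if y == 1]
    neg = [s for s, y in zip(y_score, y_true) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def threshold_metrics(y_true: Sequence[int], y_score: Sequence[float],
                      threshold: float = 0.5) -> Dict[str, float]:
    """Precision, recall and F1 after binarising scores at `threshold`."""
    y_pred = [1 if s >= threshold else 0 for s in y_score]
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Toy predictions, not model outputs
y_true = [1, 0, 1, 0, 1]
y_score = [0.9, 0.4, 0.35, 0.8, 0.7]
print(auroc(y_true, y_score))  # 0.5
print(threshold_metrics(y_true, y_score))  # precision, recall and F1 are all about 0.667
```

AUROC is threshold-free and reflects ranking quality, whereas F1 and recall depend on the chosen cut-off; with imbalanced labels the two families of metrics can tell different stories.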

In [None]:
trainer.test(drugban_trainer, dataloaders=test_generator, ckpt_path="best")

## Summary

- We loaded and preprocessed the **BindingDB** dataset for DTI prediction.  
- We built a **DrugBAN** model using **GCN**, **CNN**, and **BAN** components.  
- We trained and evaluated the model using **PyTorch Lightning**.  
- The configuration file allows for easy reproduction and modification of experiments.  

For more details, see the [original codebase](https://github.com/peizhenbai/DrugBAN) and the accompanying paper in *Nature Machine Intelligence*.

