<a href="https://colab.research.google.com/github/zhaoqichang/ColdstartCPI/blob/main/Demo/ColdstartCPI_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ColdstartCPI Running Demo

| [Open In Colab](https://colab.research.google.com/github/zhaoqichang/ColdstartCPI/blob/main/Demo/ColdstartCPI_demo.ipynb) (click `Runtime` → `Run all (Ctrl+F9)`) |

This notebook is a code demo of the ColdstartCPI framework for compound-protein interaction prediction. Running the whole pipeline without training the model takes about 10 minutes. Training the model takes about 1 hour.

## Setup

The first few blocks of code set up the notebook execution environment. The first cell checks whether the notebook is running on Google Colab and, if so, installs the required packages and clones the repository.

In [1]:
if 'google.colab' in str(get_ipython()):
    print('Running on CoLab')
    !pip uninstall --yes yellowbrick
    !pip install -U -q psutil
    !pip install -U spacy==2.1.0
    !pip install rdkit-pypi
    !pip install Mol2Vec
    !pip install bio_embeddings
    !pip install tqdm
    !pip install prefetch_generator
    !git clone https://github.com/zhaoqichang/ColdstartCPI.git
    %cd ColdstartCPI/Demo
else:
    print('Not running on CoLab')

Running on CoLab
Collecting bio_embeddings
  Using cached bio_embeddings-0.1.6-py3-none-any.whl (73 kB)
Collecting biopython<2.0,>=1.76 (from bio_embeddings)
  Using cached biopython-1.83-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
Collecting gensim<4.0.0,>=3.8.2 (from bio_embeddings)
  Using cached gensim-3.8.3.tar.gz (23.4 MB)
  Preparing metadata (setup.py) ... done
Collecting h5py<3.0.0,>=2.10.0 (from bio_embeddings)
  Using cached h5py-2.10.0.tar.gz (301 kB)
  Preparing metadata (setup.py) ... done
Collecting importlib_metadata<2.0.0,>=1.7.0 (from bio_embeddings)
  Using cached importlib_metadata-1.7.0-py2.py3-none-any.whl (31 kB)
INFO: pip is looking at multiple versions of bio-embeddings to determine which version is compatible with other requirements. This could take a while.
Collecting bio_embeddings
  Using cached bio_embeddings-0.1.5-py3-none-any.whl (72 kB)
  Using cached bio_embeddings-0.1.4-py3-none-any.whl (84 kB)
Collectin
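
The Colab check in the setup cell relies on `get_ipython()`, which is only defined inside an IPython session. As a minimal sketch of an alternative that also runs in a plain Python script, the snippet below looks for the `google.colab` package among the already loaded modules. The function name `running_on_colab` and its `modules` parameter are hypothetical and are not part of the ColdstartCPI code.

```python
import sys


def running_on_colab(modules=None):
    """Return True if the google.colab package is already loaded.

    `modules` is a hypothetical parameter that makes the check testable;
    by default the function inspects sys.modules.
    """
    modules = sys.modules if modules is None else modules
    return "google.colab" in modules


print("Running on Colab" if running_on_colab() else "Not running on Colab")
```

This only reports whether the package has been imported into the current process, so treat it as a heuristic rather than a guarantee.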

## Import required modules.

In [2]:
import warnings
warnings.filterwarnings("ignore")
import random
import os
from model import ColdstartCPI
from dataset import load_dataset
from prefetch_generator import BackgroundGenerator
from tqdm import tqdm
import numpy as np
print(np.__version__)
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score,precision_recall_curve, auc
from sklearn import metrics

1.26.4


## Configuration

To keep the runtime of the whole pipeline short, this demo uses a small subset of the data, located at `Demo/Dataset/demo_data.txt`.

In [3]:
# Seed every random number generator in use so that runs are repeatable.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Epoch = 100
Batch_size = 32
Learning_rate = 0.0001
Early_stopping_patience = 5
save_path = "./Results/"
os.makedirs(save_path, exist_ok=True)

## Data Loader

The train, validation, and test sets are defined with `CustomDataSet()` and loaded with the `load_dataset()` function, which returns one loader per split.

In [4]:
train_dataset_load, valid_dataset_load, test_dataset_load = load_dataset(batch_size=Batch_size)

Number of samples in the train set:  200
Number of samples in the validation set:  200
Number of samples in the test set:  200


## Setup Model and Optimizer

Here, we use the previously defined configuration to set up the model and optimizer we will subsequently train.


In [5]:
model = ColdstartCPI(unify_num=128,head_num=2).to(device)
Loss = nn.CrossEntropyLoss(weight=None)
optimizer = torch.optim.Adam(model.parameters(), lr=Learning_rate)
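
`nn.CrossEntropyLoss` takes the model's raw class scores (logits) together with integer class labels, applies a softmax, and averages the negative log-probability of the correct class. The sketch below reproduces that calculation in plain Python for a two-class output, purely as an illustration. The helper names `softmax` and `cross_entropy` are hypothetical and are not part of the ColdstartCPI code.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(batch_logits, labels):
    """Mean negative log-probability of the correct class over a batch."""
    losses = [
        -math.log(softmax(logits)[label])
        for logits, label in zip(batch_logits, labels)
    ]
    return sum(losses) / len(losses)


# A model that cannot tell the classes apart assigns equal scores.
uninformed = cross_entropy([[0.0, 0.0], [0.0, 0.0]], [0, 1])
# A model that is confidently right has a loss close to zero.
confident = cross_entropy([[4.0, -4.0], [-4.0, 4.0]], [0, 1])

print(f"{uninformed:.4f}")  # 0.6931, i.e. ln 2
print(f"{confident:.4f}")
```

An uninformed binary classifier therefore starts at a loss of about ln 2 ≈ 0.693, which is a useful reference point when reading the first epochs of a training log.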

## Model Training and validation

This section optimizes the model parameters and checks validation performance after every epoch.
Training takes about 1 hour.

In [6]:
def roc_auc(y, pred):
    # Area under the ROC curve, from true labels and positive-class scores.
    fpr, tpr, _ = metrics.roc_curve(y, pred)
    return metrics.auc(fpr, tpr)

def pr_auc(y, pred):
    # Area under the precision-recall curve (trapezoidal rule over recall).
    precision, recall, _ = metrics.precision_recall_curve(y, pred)
    return metrics.auc(recall, precision)
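
The value returned by `roc_auc` has a direct interpretation: it equals the probability that a randomly chosen positive sample receives a higher score than a randomly chosen negative one, with ties counted as one half. The sketch below computes that quantity by brute force in plain Python. `pairwise_auc` is a hypothetical helper for illustration, not part of the demo code, and it is quadratic in the number of samples, so it only suits small inputs.

```python
def pairwise_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # a tie counts as half a win
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(labels, scores))  # 0.75
```

Three of the four positive-negative pairs are ordered correctly here, which gives 0.75.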

In [7]:
def test_precess(model,pbar,LOSS):
    model.eval()
    test_losses = []
    Y, P, S = [], [], []
    with torch.no_grad():
        for i, data in pbar:
            '''data preparation '''
            input_batch, labels = data
            labels = labels.to(device)
            input_batch = [d.to(device) for d in input_batch]
            predicted_scores = model(input_batch)
            loss = LOSS(predicted_scores, labels)
            correct_labels = labels.to('cpu').data.numpy()
            predicted_scores = F.softmax(predicted_scores, 1).to('cpu').data.numpy()
            predicted_labels = np.argmax(predicted_scores, axis=1)
            predicted_scores = predicted_scores[:, 1]

            Y.extend(correct_labels)
            P.extend(predicted_labels)
            S.extend(predicted_scores)
            test_losses.append(loss.item())
    Precision = precision_score(Y, P)
    Recall = recall_score(Y, P)
    F1_score = f1_score(Y, P)
    AUC = roc_auc(Y, S)  # area under the ROC curve
    PRC = pr_auc(Y, S)   # area under the precision-recall curve
    Accuracy = accuracy_score(Y, P)
    test_loss = np.average(test_losses)
    return Y, P, test_loss, Accuracy, Precision, Recall, F1_score, AUC, PRC

In [8]:
def test_model(dataset_load, LOSS):
    test_pbar = tqdm(
        enumerate(
            BackgroundGenerator(dataset_load)),
        total=len(dataset_load))
    T, P, loss_test, Accuracy_test, Precision_test, Recall_test, F1_score_test, AUC_test, PRC_test = \
        test_precess(model,test_pbar, LOSS)
    results = 'Loss:{:.5f};Accuracy:{:.5f};Precision:{:.5f};Recall:{:.5f};F1 score:{:.5f};AUC:{:.5f};PRC:{:.5f}.' \
        .format(loss_test, Accuracy_test, Precision_test, Recall_test, F1_score_test, AUC_test, PRC_test)
    print(results)
    return results,loss_test, Accuracy_test, Precision_test, Recall_test, F1_score_test, AUC_test, PRC_test

In [11]:
patience = 0
best_score = 0
best_epoch = 0
"""Start training."""
print('Training...')
epoch_len = len(str(Epoch))
for epoch in range(Epoch):
  train_pbar = tqdm(
      enumerate(
          BackgroundGenerator(train_dataset_load)),
      total=len(train_dataset_load))
  """train"""
  train_losses_in_epoch = []
  model.train()
  for train_i, train_data in train_pbar:
    '''data preparation '''
    input_batch, train_labels = train_data
    input_batch = [d.to(device) for d in input_batch]
    train_labels = train_labels.to(device)
    optimizer.zero_grad()
    predicted_interaction = model(input_batch)
    train_loss = Loss(predicted_interaction, train_labels)
    train_losses_in_epoch.append(train_loss.item())
    train_loss.backward()
    optimizer.step()
  # Average the per-batch losses once the epoch is complete.
  train_loss_a_epoch = np.average(train_losses_in_epoch)
  """valid"""
  valid_pbar = tqdm(
      enumerate(
          BackgroundGenerator(valid_dataset_load)),
      total=len(valid_dataset_load))
  _,_,valid_loss_a_epoch, _, _, _, _, AUC_dev, PRC_dev = test_precess(model,valid_pbar,Loss)
  valid_score = AUC_dev + PRC_dev
  print_msg = (f'[{epoch + 1:>{epoch_len}}/{Epoch:>{epoch_len}}] ' +
                  f'patience: {patience} ' +
                  f'train_loss: {train_loss_a_epoch:.5f} ' +
                  f'valid_loss: {valid_loss_a_epoch:.5f} ' +
                  f'valid_AUC: {AUC_dev:.5f} ' +
                  f'valid_AUPR: {PRC_dev:.5f} '
                  )

  print("\n")
  print(print_msg)
  if valid_score > best_score:
    best_score = valid_score
    patience = 0
    best_epoch = epoch + 1
    torch.save(model.state_dict(), save_path + 'valid_best_checkpoint.pth')
  else:
    patience += 1
    if patience == Early_stopping_patience:
      break

Training...


100%|██████████| 25/25 [01:51<00:00,  4.46s/it]
100%|██████████| 25/25 [00:26<00:00,  1.06s/it]




[  1/100] patience: 0 train_loss: 0.69358 valid_loss: 0.67396 valid_AUC: 0.69680 valid_AUPR: 0.71302 


100%|██████████| 25/25 [01:37<00:00,  3.90s/it]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]




[  2/100] patience: 0 train_loss: 0.68502 valid_loss: 0.66262 valid_AUC: 0.70120 valid_AUPR: 0.71435 


100%|██████████| 25/25 [01:19<00:00,  3.16s/it]
100%|██████████| 25/25 [00:30<00:00,  1.22s/it]



[  3/100] patience: 0 train_loss: 0.66849 valid_loss: 0.64993 valid_AUC: 0.74600 valid_AUPR: 0.74160 



100%|██████████| 25/25 [01:17<00:00,  3.10s/it]
100%|██████████| 25/25 [00:25<00:00,  1.02s/it]



[  4/100] patience: 0 train_loss: 0.66480 valid_loss: 0.62706 valid_AUC: 0.79940 valid_AUPR: 0.78644 



100%|██████████| 25/25 [01:26<00:00,  3.44s/it]
100%|██████████| 25/25 [00:22<00:00,  1.12it/s]




[  5/100] patience: 0 train_loss: 0.63424 valid_loss: 0.58244 valid_AUC: 0.84150 valid_AUPR: 0.83059 


100%|██████████| 25/25 [01:28<00:00,  3.54s/it]
100%|██████████| 25/25 [00:22<00:00,  1.11it/s]



[  6/100] patience: 0 train_loss: 0.59485 valid_loss: 0.56669 valid_AUC: 0.84730 valid_AUPR: 0.83869 



100%|██████████| 25/25 [01:17<00:00,  3.10s/it]
100%|██████████| 25/25 [00:21<00:00,  1.19it/s]



[  7/100] patience: 0 train_loss: 0.59688 valid_loss: 0.51810 valid_AUC: 0.87350 valid_AUPR: 0.87035 



100%|██████████| 25/25 [01:16<00:00,  3.04s/it]
100%|██████████| 25/25 [00:21<00:00,  1.19it/s]




[  8/100] patience: 0 train_loss: 0.57454 valid_loss: 0.51766 valid_AUC: 0.90540 valid_AUPR: 0.90515 


100%|██████████| 25/25 [01:15<00:00,  3.03s/it]
100%|██████████| 25/25 [00:21<00:00,  1.19it/s]




[  9/100] patience: 0 train_loss: 0.49304 valid_loss: 0.41695 valid_AUC: 0.93690 valid_AUPR: 0.93641 


100%|██████████| 25/25 [01:12<00:00,  2.89s/it]
100%|██████████| 25/25 [00:23<00:00,  1.08it/s]




[ 10/100] patience: 0 train_loss: 0.44953 valid_loss: 0.36855 valid_AUC: 0.95940 valid_AUPR: 0.95602 


100%|██████████| 25/25 [01:13<00:00,  2.93s/it]
100%|██████████| 25/25 [00:20<00:00,  1.21it/s]




[ 11/100] patience: 0 train_loss: 0.40788 valid_loss: 0.33729 valid_AUC: 0.96250 valid_AUPR: 0.96781 


100%|██████████| 25/25 [01:13<00:00,  2.93s/it]
100%|██████████| 25/25 [00:26<00:00,  1.06s/it]



[ 12/100] patience: 0 train_loss: 0.36511 valid_loss: 0.27949 valid_AUC: 0.97520 valid_AUPR: 0.97799 



100%|██████████| 25/25 [01:19<00:00,  3.19s/it]
100%|██████████| 25/25 [00:24<00:00,  1.03it/s]




[ 13/100] patience: 0 train_loss: 0.35555 valid_loss: 0.22040 valid_AUC: 0.98980 valid_AUPR: 0.98970 


100%|██████████| 25/25 [01:13<00:00,  2.96s/it]
100%|██████████| 25/25 [00:20<00:00,  1.19it/s]




[ 14/100] patience: 0 train_loss: 0.26124 valid_loss: 0.17287 valid_AUC: 0.99740 valid_AUPR: 0.99743 


100%|██████████| 25/25 [01:14<00:00,  2.97s/it]
100%|██████████| 25/25 [00:20<00:00,  1.20it/s]



[ 15/100] patience: 0 train_loss: 0.20725 valid_loss: 0.14185 valid_AUC: 0.99780 valid_AUPR: 0.99771 



100%|██████████| 25/25 [01:15<00:00,  3.02s/it]
100%|██████████| 25/25 [00:21<00:00,  1.14it/s]




[ 16/100] patience: 0 train_loss: 0.22997 valid_loss: 0.12860 valid_AUC: 0.99880 valid_AUPR: 0.99879 


100%|██████████| 25/25 [01:19<00:00,  3.17s/it]
100%|██████████| 25/25 [00:21<00:00,  1.14it/s]



[ 17/100] patience: 0 train_loss: 0.17457 valid_loss: 0.16494 valid_AUC: 0.99920 valid_AUPR: 0.99920 



100%|██████████| 25/25 [01:12<00:00,  2.90s/it]
100%|██████████| 25/25 [00:23<00:00,  1.05it/s]




[ 18/100] patience: 0 train_loss: 0.13845 valid_loss: 0.14355 valid_AUC: 0.98800 valid_AUPR: 0.98501 


100%|██████████| 25/25 [01:23<00:00,  3.35s/it]
100%|██████████| 25/25 [00:22<00:00,  1.12it/s]



[ 19/100] patience: 1 train_loss: 0.14915 valid_loss: 0.10569 valid_AUC: 0.99920 valid_AUPR: 0.99921 



100%|██████████| 25/25 [01:12<00:00,  2.91s/it]
100%|██████████| 25/25 [00:22<00:00,  1.11it/s]




[ 20/100] patience: 0 train_loss: 0.14454 valid_loss: 0.07015 valid_AUC: 0.99940 valid_AUPR: 0.99940 


100%|██████████| 25/25 [01:21<00:00,  3.26s/it]
100%|██████████| 25/25 [00:23<00:00,  1.06it/s]




[ 21/100] patience: 0 train_loss: 0.07480 valid_loss: 0.05497 valid_AUC: 1.00000 valid_AUPR: 1.00000 


100%|██████████| 25/25 [01:18<00:00,  3.14s/it]
100%|██████████| 25/25 [00:28<00:00,  1.12s/it]




[ 22/100] patience: 0 train_loss: 0.05307 valid_loss: 0.04542 valid_AUC: 1.00000 valid_AUPR: 1.00000 


100%|██████████| 25/25 [01:12<00:00,  2.90s/it]
100%|██████████| 25/25 [00:24<00:00,  1.02it/s]




[ 23/100] patience: 1 train_loss: 0.05648 valid_loss: 0.03046 valid_AUC: 1.00000 valid_AUPR: 1.00000 


100%|██████████| 25/25 [01:19<00:00,  3.18s/it]
100%|██████████| 25/25 [00:22<00:00,  1.11it/s]




[ 24/100] patience: 2 train_loss: 0.09718 valid_loss: 0.12893 valid_AUC: 0.99920 valid_AUPR: 0.99923 


100%|██████████| 25/25 [01:23<00:00,  3.33s/it]
100%|██████████| 25/25 [00:24<00:00,  1.00it/s]




[ 25/100] patience: 3 train_loss: 0.06685 valid_loss: 0.03947 valid_AUC: 0.99960 valid_AUPR: 0.99960 


100%|██████████| 25/25 [01:15<00:00,  3.04s/it]
100%|██████████| 25/25 [00:23<00:00,  1.04it/s]



[ 26/100] patience: 4 train_loss: 0.02975 valid_loss: 0.01968 valid_AUC: 1.00000 valid_AUPR: 1.00000 
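
The stopping rule can be read off the log above: the model is checkpointed whenever `AUC + AUPR` on the validation set strictly improves, and training stops once `Early_stopping_patience` consecutive epochs bring no improvement. Each log line prints the patience counter before that epoch's update, which is why epoch 26 shows `patience: 4` and is still the last epoch. The sketch below replays this logic on a list of scores. `EarlyStopper` is a hypothetical helper, not part of the ColdstartCPI code, and the scores are made-up values for illustration.

```python
class EarlyStopper:
    """Track the best validation score and count epochs without improvement."""

    def __init__(self, patience):
        self.patience = patience
        self.best_score = 0.0
        self.best_epoch = 0
        self.bad_epochs = 0

    def update(self, epoch, score):
        """Record one epoch; return True when training should stop."""
        if score > self.best_score:  # strict improvement, ties do not count
            self.best_score = score
            self.best_epoch = epoch
            self.bad_epochs = 0
            return False
        self.bad_epochs += 1
        return self.bad_epochs == self.patience


# Made-up validation scores (AUC + AUPR), one per epoch.
scores = [1.40, 1.75, 2.00, 2.00, 1.99, 2.00, 2.00, 2.00, 1.50]
stopper = EarlyStopper(patience=5)
stopped_at = None
for epoch, score in enumerate(scores, start=1):
    if stopper.update(epoch, score):
        stopped_at = epoch
        break

print(stopper.best_epoch, stopped_at)  # 3 8
```

Because a tie with the best score does not reset the counter, a model that reaches the maximum score of 2.0 stops five epochs later, and the checkpoint kept is the first epoch that reached it.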





In [12]:
"""Test the best model"""
print('load trained model...')
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(save_path + 'valid_best_checkpoint.pth', map_location=device))

trainset_test_results, Loss_train, Accuracy_train, Precision_train, Recall_train, F1_score_train, AUC_train, PRC_train = test_model(train_dataset_load, Loss)
with open(save_path + 'results.txt', 'a') as f:
  f.write("The result of train set:"+ trainset_test_results + '\n')

testset_test_results, Loss_test, Accuracy_test, Precision_test, Recall_test, F1_score_test, AUC_test, PRC_test = test_model(test_dataset_load, Loss)
with open(save_path + 'results.txt', 'a') as f:
  f.write("The result of test set:" + testset_test_results + '\n')

load trained model...


100%|██████████| 25/25 [00:22<00:00,  1.13it/s]


Loss:0.05497;Accuracy:0.99000;Precision:1.00000;Recall:0.98000;F1 score:0.98990;AUC:1.00000;PRC:1.00000.


100%|██████████| 25/25 [00:23<00:00,  1.06it/s]

Loss:0.05497;Accuracy:0.99000;Precision:1.00000;Recall:0.98000;F1 score:0.98990;AUC:1.00000;PRC:1.00000.
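
The reported accuracy, precision, recall, and F1 score all derive from the four confusion-matrix counts. The sketch below recomputes them in plain Python from counts that are an assumption chosen for illustration: a balanced split of 100 positive and 100 negative samples, with 98 true positives and no false positives. The notebook does not print the actual confusion matrix, but these counts reproduce the figures on the result line above. `binary_metrics` is a hypothetical helper, not part of the demo code.

```python
def binary_metrics(tp, fp, fn, tn):
    """Return (accuracy, precision, recall, F1) from confusion-matrix counts."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return accuracy, precision, recall, f1


# Assumed counts: 100 positives (98 found, 2 missed), 100 negatives (all correct).
acc, prec, rec, f1 = binary_metrics(tp=98, fp=0, fn=2, tn=100)
print(f"Accuracy:{acc:.5f};Precision:{prec:.5f};Recall:{rec:.5f};F1 score:{f1:.5f}")
# Accuracy:0.99000;Precision:1.00000;Recall:0.98000;F1 score:0.98990
```

Under the same balanced-split assumption, two missed positives out of 200 samples are enough to explain an accuracy of 0.99 alongside an AUC of 1.0, since AUC depends only on the ranking of the scores and not on the decision threshold.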





## Expected Output

Awesome! You have completed all the demo steps and should see output like the following. Note that the exact numbers may differ because the Colab environment changes over time.

```
load trained model...
100%|███████████████████████████████████████████| 13/13 [00:02<00:00,  5.80it/s]
Loss:0.05332;Accuracy:0.99500;Precision:0.99010;Recall:1.00000;F1 score:0.99502;AUC:1.00000;PRC:1.00000.
100%|███████████████████████████████████████████| 13/13 [00:01<00:00,  7.23it/s]
Loss:0.05443;Accuracy:0.99500;Precision:0.99010;Recall:1.00000;F1 score:0.99502;AUC:1.00000;PRC:1.00000.
```

Finally, the results are saved in the Colab temporary directory `/content/ColdstartCPI/Demo/Results`. You can access it by clicking the `Files` tab on the left side of the Colab interface.