## TrimNet: learning molecular representation from triplet messages for biomedicine

<b>Motivation</b>

Computational methods accelerate drug discovery and play an important role in biomedicine, for example in molecular property prediction and compound–protein interaction (CPI) identification. A key challenge is learning useful molecular representations. Early on, molecular properties were mainly calculated with quantum mechanics or predicted with traditional machine learning methods, both of which require expert knowledge and are often labor-intensive. More recently, graph neural networks have received significant attention because of their ability to learn representations from graph data. Nevertheless, current graph-based methods have limitations that need to be addressed, such as large parameter counts and insufficient extraction of bond information.

<b>Results</b>

In this study, we propose a graph-based approach, named triplet message networks (TrimNet), that employs a novel triplet message mechanism to learn molecular representations efficiently. We show that TrimNet can accurately complete multiple molecular representation learning tasks, including quantum property, bioactivity, physiology and CPI prediction, with a significant reduction in parameters. In the experiments, TrimNet outperforms the previous state-of-the-art methods by a significant margin on various datasets. Besides its small parameter count and high prediction accuracy, TrimNet can focus on the atoms essential to the target properties, providing a clear interpretation of the prediction tasks. These advantages establish TrimNet as a powerful and useful computational tool for the challenging problem of molecular representation learning.

Link to paper: https://bit.ly/3idpppO

Credit: https://github.com/yvquanli/TrimNet

In [None]:
# Clone the repository and cd into directory
!git clone https://github.com/yvquanli/TrimNet.git
%cd TrimNet

Cloning into 'TrimNet'...
remote: Enumerating objects: 67, done.
remote: Counting objects: 100% (67/67), done.
remote: Compressing objects: 100% (65/65), done.
remote: Total 67 (delta 24), reused 0 (delta 0), pack-reused 0
Unpacking objects: 100% (67/67), done.
/content/TrimNet


In [None]:
# Install dependencies / requirements
!pip install torch==1.4.0

!pip install torch-geometric \
  torch-sparse==latest+cu101 \
  torch-scatter==latest+cu101 \
  torch-cluster==latest+cu101 \
  -f https://pytorch-geometric.com/whl/torch-1.4.0.html

# Install RDKit
!pip install rdkit-pypi==2021.3.1.5
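
Before moving on, it can help to confirm that the packages installed above are actually visible to the notebook kernel. The following is a minimal sketch using only the standard library; the import names in `REQUIRED` are assumptions based on the usual convention (dashes become underscores, and `rdkit-pypi` is imported as `rdkit`), and `check_packages` is a hypothetical helper, not part of TrimNet.

```python
import importlib.util

# Assumed import names for the packages installed in the cell above.
REQUIRED = [
    "torch",
    "torch_geometric",
    "torch_sparse",
    "torch_scatter",
    "torch_cluster",
    "rdkit",
]


def check_packages(names):
    """Map each top-level module name to True if it can be located, else False.

    find_spec only locates the module; it does not import it, so this check
    is cheap even for large packages.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages(REQUIRED).items():
        print(f"{name:16s} {'ok' if found else 'MISSING'}")
```

Any entry reported as `MISSING` suggests the corresponding `pip install` line did not succeed and should be re-run before training.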

### Usage example
For the quantum dataset (QM9)

In [None]:
# Use an absolute path so the cell works regardless of the current directory
%cd /content/TrimNet/trimnet_quantum/src

# Download the QM9 dataset
!wget https://s3-us-west-1.amazonaws.com/deepchem.io/datasets/molnet_publish/qm9.zip

In [None]:
# unzip the file to trimnet_quantum/dataset/raw
!unzip qm9.zip -d ../dataset/raw
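
If the `unzip` shell command is not available, the same step can be sketched in pure Python. This is an illustrative alternative, not part of the TrimNet repository: `extract_and_list` is a hypothetical helper, and the `qm9.zip` and `../dataset/raw` paths are taken from the cells above. It also prints what ended up in the target directory, which is a quick way to confirm the extraction worked.

```python
import zipfile
from pathlib import Path


def extract_and_list(zip_path, dest_dir):
    """Extract zip_path into dest_dir and return the files now under dest_dir.

    Paths are returned relative to dest_dir, sorted, with forward slashes.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    return sorted(
        p.relative_to(dest).as_posix() for p in dest.rglob("*") if p.is_file()
    )


if __name__ == "__main__":
    archive_path = Path("qm9.zip")
    if archive_path.exists():
        for name in extract_and_list(archive_path, "../dataset/raw"):
            print(name)
    else:
        print("qm9.zip not found; run the download cell first")
```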

In [17]:
!python run.py --depth 3 --seed 1234 --gpu 0

################################################################################
################################################################################
train set num:1201    valid set num:157    test set num: 155
total parameters:98018
################################################################################
################################################################################
Training start...
100% 10/10 [00:00<00:00, 11.16it/s]
100% 2/2 [00:00<00:00, 13.24it/s]
100% 2/2 [00:00<00:00, 50.73it/s]
Epoch:0 bace trn_loss:0.744 trn_roc:0.500 trn_prc:0.459 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Epoch:0 bace val_loss:0.695 val_roc:0.515 val_prc:0.432 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Epoch:0 bace test_loss:0.699 test_roc:0.543 test_prc:0.486 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Model saved at epoch 0
100% 10/10 [00:00<00:00, 27.88it/s]
100% 2/2 [00:00<00:00, 69.25it/s]
100% 2/2 [00:00<00:00, 72.66it/s]
Epoch:1 bace trn_loss:0.701 

For the drug dataset (BACE)

In [13]:
%cd /content/TrimNet/trimnet_drug/source
!python run.py --dataset bace --gpu 0

/content/TrimNet/trimnet_drug/source
################################################################################
################################################################################
train set num:1216    valid set num:143    test set num: 154
total parameters:98018
################################################################################
################################################################################
Training start...
100% 10/10 [00:01<00:00,  8.82it/s]
100% 2/2 [00:00<00:00,  9.91it/s]
100% 2/2 [00:00<00:00, 52.19it/s]
Epoch:0 bace trn_loss:0.711 trn_roc:0.527 trn_prc:0.476 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Epoch:0 bace val_loss:0.709 val_roc:0.357 val_prc:0.377 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Epoch:0 bace test_loss:0.696 test_roc:0.492 test_prc:0.395 lr_cur:0.00100 time elapsed 0.00 hrs (0.0 mins)
Model saved at epoch 0
100% 10/10 [00:00<00:00, 27.38it/s]
100% 2/2 [00:00<00:00, 73.88it/s]
100% 2/2 [00:00<00:00, 67
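
The per-epoch lines in the training output follow a regular `Epoch:<n> <dataset> key:value ...` pattern, so they can be collected into records for later comparison. The sketch below assumes only the line format visible in the output above; `parse_log` and `best_epoch` are hypothetical helpers and are not provided by TrimNet.

```python
import re

# "Epoch:<n> <dataset> <rest of line>"
EPOCH_LINE = re.compile(r"^Epoch:(\d+)\s+(\S+)\s+(.*)$")
# "key:number" pairs such as "val_roc:0.357" or "lr_cur:0.00100"
METRIC = re.compile(r"(\w+):([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)")


def parse_log(text):
    """Return one dict per 'Epoch:' line with epoch, dataset and numeric metrics."""
    records = []
    for line in text.splitlines():
        match = EPOCH_LINE.match(line.strip())
        if not match:
            continue  # progress bars, banners, "Model saved" lines, etc.
        record = {"epoch": int(match.group(1)), "dataset": match.group(2)}
        for key, value in METRIC.findall(match.group(3)):
            record[key] = float(value)
        records.append(record)
    return records


def best_epoch(records, metric="val_roc"):
    """Return the record with the highest value of `metric`, or None if absent."""
    candidates = [r for r in records if metric in r]
    return max(candidates, key=lambda r: r[metric]) if candidates else None


if __name__ == "__main__":
    sample = (
        "Epoch:0 bace trn_loss:0.711 trn_roc:0.527 trn_prc:0.476 lr_cur:0.00100 "
        "time elapsed 0.00 hrs (0.0 mins)\n"
        "Epoch:0 bace val_loss:0.709 val_roc:0.357 val_prc:0.377 lr_cur:0.00100 "
        "time elapsed 0.00 hrs (0.0 mins)\n"
        "Model saved at epoch 0\n"
    )
    for rec in parse_log(sample):
        print(rec)
```

To use it on a real run, the cell output could be captured to a file (for example by redirecting the `run.py` output) and its contents passed to `parse_log`.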