<a href="https://colab.research.google.com/github/xy2119/BioKR2_Graph_Transformer/blob/main/GraphGPS/GraphGPS_TDC_explore.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GraphGPS Implementation on [Therapeutics Data Commons](https://tdcommons.ai/)
<img src="https://tdcommons.ai/img/tdc_triangle.png" width="600">

TDC collects ML tasks and associated datasets across therapeutic modalities and stages of discovery. 

TDC datasets fall into three problem categories: single-instance prediction, multi-instance prediction, and generation.

<img src="https://tdcommons.ai/img/tdc_overview2.png" width="600">

## Results on TDC ADME Regression

| GraphGPS mono-encoder performance | loss | mae | r2 | spearmanr | mse | rmse |
|-----------------------------------|---------|---------|---------|-----------|---------|---------|
| Caco-2                            | 0.33303 | 0.33303 | 0.74028 | 0.89824   | 0.18529 | 0.43045 |
| Lipophilicity                     | 0.44033 | 0.44033 | 0.71245 | 0.84066   | 0.38297 | 0.61885 |
| Solubility                        | 0.61756 | 0.61756 | 0.84549 | 0.91817   | 0.88549 | 0.941   |
| Hydration Free Energy             | 0.45752 | 0.45752 | 0.96486 | 0.98262   | 0.49524 | 0.70374 |
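
As a quick consistency check on the table above, the `rmse` column should equal the square root of the `mse` column. The sketch below copies the reported values by hand (the names `results` and `rmse_matches` are ours, not part of GraphGPS) and verifies this up to the printed rounding.

```python
import math

# (mse, rmse) pairs copied by hand from the results table above
results = {
    "Caco-2": (0.18529, 0.43045),
    "Lipophilicity": (0.38297, 0.61885),
    "Solubility": (0.88549, 0.941),
    "Hydration Free Energy": (0.49524, 0.70374),
}

def rmse_matches(mse, rmse, tol=1e-4):
    """Return True if rmse agrees with sqrt(mse) up to rounding."""
    return abs(math.sqrt(mse) - rmse) < tol

checks = {name: rmse_matches(mse, rmse) for name, (mse, rmse) in results.items()}
print(checks)
```

The `loss` and `mae` columns are also identical in every row, so the two columns carry the same information here.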

## **Graph Set Up**

`Graph Input`: SMILES strings

`Node`: atoms (features encoded by `smiles2graph`)

`Edge`: chemical bonds (features encoded by `smiles2graph`)

`Output`: regression task on property prediction
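
To make this layout concrete, here is a dependency-free sketch for ethanol (`CCO`) that builds, by hand, a dictionary shaped like the one `smiles2graph` returns later in this notebook. The atoms and bonds are written out manually rather than parsed from SMILES, `toy_graph` is a hypothetical helper (not part of OGB or RDKit), and the node features are reduced to atomic numbers instead of the nine-feature vectors used in the real pipeline.

```python
def toy_graph(atomic_nums, bonds):
    """Build a graph dict with each bond stored in both directions."""
    sources, targets = [], []
    for i, j in bonds:
        sources += [i, j]
        targets += [j, i]
    return {
        "edge_index": [sources, targets],  # COO layout, shape [2, num_edges]
        "node_feat": [[z] for z in atomic_nums],
        "num_nodes": len(atomic_nums),
    }

# Ethanol heavy atoms: C(0) - C(1) - O(2)
graph = toy_graph(atomic_nums=[6, 6, 8], bonds=[(0, 1), (1, 2)])
print(graph["edge_index"])
```

Two bonds become four directed edges, which is why the dataset summaries in the training logs report the graphs as undirected.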


In [None]:
# Load the Drive helper and mount
from google.colab import drive

# This will prompt for authorization.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd /content/drive/MyDrive

/content/drive/MyDrive


In [None]:
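# NOTE: ${TORCH} is read from the environment. Run the cell below that sets
# os.environ['TORCH'] first; otherwise the wheel index URL is incomplete.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html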
!pip install torch-geometric==2.0.4
!pip install torchmetrics
!pip install performer-pytorch
!pip install ogb
!pip install tensorboardX
!pip install wandb

!pip install h5py
!pip install typing-extensions
!pip install wheel

!pip install git+https://github.com/PyTorchLightning/pytorch-lightning
!pip install yacs

In [None]:
!git clone https://github.com/rampasek/GraphGPS.git

Cloning into 'GraphGPS'...
remote: Enumerating objects: 373, done.
remote: Counting objects: 100% (118/118), done.
remote: Compressing objects: 100% (28/28), done.
remote: Total 373 (delta 103), reused 90 (delta 90), pack-reused 255
Receiving objects: 100% (373/373), 12.88 MiB | 23.51 MiB/s, done.
Resolving deltas: 100% (224/224), done.


In [None]:
import os
import torch
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

1.12.1+cu113


In [None]:
import argparse
import logging

import torch
import torch_geometric
import pytorch_lightning 

import torchmetrics
import performer_pytorch
import ogb
import tensorboardX
import wandb

import pandas as pd
import numpy as np 
import random
random.seed(2022)
np.random.seed(2022)

In [None]:
!unzip /content/GraphGPS.zip 

## TDC Data Loading 

In [None]:
!pip install PyTDC

In [None]:
# allowable multiple choice node and edge features 
allowable_features = {
    'possible_atomic_num_list' : list(range(1, 119)) + ['misc'],
    'possible_chirality_list' : [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER'
    ],
    'possible_degree_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
    'possible_formal_charge_list' : [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
    'possible_numH_list' : [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
    'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
    'possible_hybridization_list' : [
        'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
        ],
    'possible_is_aromatic_list': [False, True],
    'possible_is_in_ring_list': [False, True],
    'possible_bond_type_list' : [
        'SINGLE',
        'DOUBLE',
        'TRIPLE',
        'AROMATIC',
        'misc'
    ],
    'possible_bond_stereo_list': [
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ], 
    'possible_is_conjugated_list': [False, True],
}

def safe_index(l, e):
    """
    Return index of element e in list l. If e is not present, return the last index
    """
    try:
        return l.index(e)
    except ValueError:  # list.index raises ValueError when e is absent
        return len(l) - 1
# # miscellaneous case
# i = safe_index(allowable_features['possible_atomic_num_list'], 'asdf')
# assert allowable_features['possible_atomic_num_list'][i] == 'misc'
# # normal case
# i = safe_index(allowable_features['possible_atomic_num_list'], 2)
# assert allowable_features['possible_atomic_num_list'][i] == 2

def atom_to_feature_vector(atom):
    """
    Converts rdkit atom object to feature list of indices
    :param atom: rdkit atom object
    :return: list
    """
    atom_feature = [
            safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
            allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())),
            safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(allowable_features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(allowable_features['possible_hybridization_list'], str(atom.GetHybridization())),
            allowable_features['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            allowable_features['possible_is_in_ring_list'].index(atom.IsInRing()),
            ]
    return atom_feature
# from rdkit import Chem
# mol = Chem.MolFromSmiles('Cl[C@H](/C=C/C)Br')
# atom = mol.GetAtomWithIdx(1)  # chiral carbon
# atom_feature = atom_to_feature_vector(atom)
# assert atom_feature == [5, 2, 4, 5, 1, 0, 2, 0, 0]


def get_atom_feature_dims():
    return list(map(len, [
        allowable_features['possible_atomic_num_list'],
        allowable_features['possible_chirality_list'],
        allowable_features['possible_degree_list'],
        allowable_features['possible_formal_charge_list'],
        allowable_features['possible_numH_list'],
        allowable_features['possible_number_radical_e_list'],
        allowable_features['possible_hybridization_list'],
        allowable_features['possible_is_aromatic_list'],
        allowable_features['possible_is_in_ring_list']
        ]))

def bond_to_feature_vector(bond):
    """
    Converts rdkit bond object to feature list of indices
    :param bond: rdkit bond object
    :return: list
    """
    bond_feature = [
                safe_index(allowable_features['possible_bond_type_list'], str(bond.GetBondType())),
                allowable_features['possible_bond_stereo_list'].index(str(bond.GetStereo())),
                allowable_features['possible_is_conjugated_list'].index(bond.GetIsConjugated()),
            ]
    return bond_feature
# uses same molecule as atom_to_feature_vector test
# bond = mol.GetBondWithIdx(2)  # double bond with stereochem
# bond_feature = bond_to_feature_vector(bond)
# assert bond_feature == [1, 2, 0]

def get_bond_feature_dims():
    return list(map(len, [
        allowable_features['possible_bond_type_list'],
        allowable_features['possible_bond_stereo_list'],
        allowable_features['possible_is_conjugated_list']
        ]))

def atom_feature_vector_to_dict(atom_feature):
    [atomic_num_idx, 
    chirality_idx,
    degree_idx,
    formal_charge_idx,
    num_h_idx,
    number_radical_e_idx,
    hybridization_idx,
    is_aromatic_idx,
    is_in_ring_idx] = atom_feature

    feature_dict = {
        'atomic_num': allowable_features['possible_atomic_num_list'][atomic_num_idx],
        'chirality': allowable_features['possible_chirality_list'][chirality_idx],
        'degree': allowable_features['possible_degree_list'][degree_idx],
        'formal_charge': allowable_features['possible_formal_charge_list'][formal_charge_idx],
        'num_h': allowable_features['possible_numH_list'][num_h_idx],
        'num_rad_e': allowable_features['possible_number_radical_e_list'][number_radical_e_idx],
        'hybridization': allowable_features['possible_hybridization_list'][hybridization_idx],
        'is_aromatic': allowable_features['possible_is_aromatic_list'][is_aromatic_idx],
        'is_in_ring': allowable_features['possible_is_in_ring_list'][is_in_ring_idx]
    }

    return feature_dict
# # uses same atom_feature as atom_to_feature_vector test
# atom_feature_dict = atom_feature_vector_to_dict(atom_feature)
# assert atom_feature_dict['atomic_num'] == 6
# assert atom_feature_dict['chirality'] == 'CHI_TETRAHEDRAL_CCW'
# assert atom_feature_dict['degree'] == 4
# assert atom_feature_dict['formal_charge'] == 0
# assert atom_feature_dict['num_h'] == 1
# assert atom_feature_dict['num_rad_e'] == 0
# assert atom_feature_dict['hybridization'] == 'SP3'
# assert atom_feature_dict['is_aromatic'] == False
# assert atom_feature_dict['is_in_ring'] == False

def bond_feature_vector_to_dict(bond_feature):
    [bond_type_idx, 
    bond_stereo_idx,
    is_conjugated_idx] = bond_feature

    feature_dict = {
        'bond_type': allowable_features['possible_bond_type_list'][bond_type_idx],
        'bond_stereo': allowable_features['possible_bond_stereo_list'][bond_stereo_idx],
        'is_conjugated': allowable_features['possible_is_conjugated_list'][is_conjugated_idx]
    }

    return feature_dict
# # uses same bond as bond_to_feature_vector test
# bond_feature_dict = bond_feature_vector_to_dict(bond_feature)
# assert bond_feature_dict['bond_type'] == 'DOUBLE'
# assert bond_feature_dict['bond_stereo'] == 'STEREOE'
# assert bond_feature_dict['is_conjugated'] == False

In [None]:
from ogb.utils.features import (allowable_features, atom_to_feature_vector,
 bond_to_feature_vector, atom_feature_vector_to_dict, bond_feature_vector_to_dict) 
from rdkit import Chem
import numpy as np

def smiles2graph(smiles_string):
    """
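    Converts a SMILES string to a graph dictionary
    :input: SMILES string (str)
    :return: dict with keys 'edge_index', 'edge_feat', 'node_feat', 'num_nodes'
    """

    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:  # RDKit returns None when the SMILES cannot be parsed
        raise ValueError(f'Could not parse SMILES: {smiles_string!r}')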

    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append(atom_to_feature_vector(atom))
    x = np.array(atom_features_list, dtype = np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype = np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype = np.int64)

    else:   # mol has no bonds
        edge_index = np.empty((2, 0), dtype = np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype = np.int64)

    graph = dict()
    graph['edge_index'] = edge_index
    graph['edge_feat'] = edge_attr
    graph['node_feat'] = x
    graph['num_nodes'] = len(x)

    return graph 

In [None]:
from tdc.single_pred import ADME
data = ADME(name = 'Caco2_Wang')

# Copy the column subset so the assignment below does not write into a view.
df = data.get_data()[['Drug', 'Y']].copy()
df['Drug'] = df['Drug'].apply(smiles2graph)

Downloading...
100%|██████████| 82.5k/82.5k [00:00<00:00, 1.03MiB/s]
Loading...
Done!


In [None]:
import hashlib
import os.path as osp
import pickle
import shutil

import pandas as pd
import torch
from ogb.utils import smiles2graph
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import decide_download
from torch_geometric.data import Data, InMemoryDataset, download_url
from tqdm import tqdm

class TDC_DTI_data(InMemoryDataset):
    SEED = 42
    VAL_RATIO = 0.05
    TEST_RATIO = 0.05

    def __init__(self, root='datasets', subset='530k',smiles2graph=smiles2graph,transform=None, pre_transform=None):
        self.original_root = root
        self.subset = subset
        self.smiles2graph = smiles2graph
        self.folder = osp.join(root, 'TDC')
        self.generate_splits = True


        super().__init__(self.folder, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_file_names(self):
        return 'DTC_DTI_dataset.csv.gz'

    @property
    def processed_file_names(self):
        return 'DTC_DTI_data_processed.pt'

    def process(self):
        data = ADME(name = 'Caco2_Wang')   
        data_df=data.get_data()[['Drug','Y']]
        data_df['Drug']=data_df['Drug'].apply(lambda x: x.encode('utf-8'))


        smiles_list = data_df['Drug']

        print('Converting SMILES strings into graphs...')
        data_list = []
        for i in tqdm(range(len(smiles_list))):
            data = Data()
            smiles = smiles_list[i]
            graph = self.smiles2graph(smiles)

            assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
            assert (len(graph['node_feat']) == graph['num_nodes'])

            data.__num_nodes__ = int(graph['num_nodes'])
            data.edge_index = torch.from_numpy(graph['edge_index']).to(
                torch.int64)
            data.edge_attr = torch.from_numpy(graph['edge_feat']).to(
                torch.int64)
            data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
            data.y = torch.Tensor([data_df['Y'].iloc[i]])

            data_list.append(data)

        if self.generate_splits:
            # Random shuffle split of the molecules by 90/5/5 ratio.
            self.create_shuffle_split(len(data_list),
                                      self.VAL_RATIO, self.TEST_RATIO)

            # Create 90/5/5 split by the size of molecules.
            num_atoms_list = [d.num_nodes for d in data_list]
            self.create_numatoms_split(num_atoms_list,
                                       self.VAL_RATIO, self.TEST_RATIO)

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])

    def create_shuffle_split(self, N, val_ratio, test_ratio):
        """ Create a random shuffle split and saves it to disk.
        Args:
            N: Total size of the dataset to split.
        """
        rng = np.random.default_rng(seed=self.SEED)
        all_ind = rng.permutation(N)
        train_ratio = 1 - val_ratio - test_ratio
        val_ratio_rem = val_ratio / (val_ratio + test_ratio)

        # Random shuffle split into 90/5/5.
        train_ind = all_ind[:int(train_ratio * N)]
        tmp_ind = all_ind[int(train_ratio * N):]
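        # Cut the held-out indices at a single point so that val and test stay
        # complementary even when val_ratio != test_ratio.
        n_val = int(val_ratio_rem * len(tmp_ind))
        val_ind = tmp_ind[:n_val]
        test_ind = tmp_ind[n_val:]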
        assert self._check_splits(N, [train_ind, val_ind, test_ind],
                                  [train_ratio, val_ratio, test_ratio])

        shuffle_split = {'train': train_ind, 'val': val_ind, 'test': test_ind}
        torch.save(shuffle_split,
                   osp.join(self.folder, f'{self.subset}_shuffle_split_dict.pt'))

    def create_numatoms_split(self, num_atoms_list, val_ratio, test_ratio):
        """ Create split by the size of molecules, testing on the largest ones.
        Args:
            num_atoms_list: List with molecule size per each graph.
        """
        rng = np.random.default_rng(seed=self.SEED)
        all_ind = np.argsort(np.array(num_atoms_list))
        train_ratio = 1 - val_ratio - test_ratio
        val_ratio_rem = val_ratio / (val_ratio + test_ratio)

        # Split based on mol size into 90/5/5, but shuffle the top 10% randomly
        # before splitting to validation and test set.
        N = len(num_atoms_list)
        train_ind = all_ind[:int(train_ratio * N)]
        tmp_ind = all_ind[int(train_ratio * N):]
        rng.shuffle(tmp_ind)
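        # Cut the held-out indices at a single point so that val and test stay
        # complementary even when val_ratio != test_ratio.
        n_val = int(val_ratio_rem * len(tmp_ind))
        val_ind = tmp_ind[:n_val]
        test_ind = tmp_ind[n_val:]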
        assert len(train_ind) + len(val_ind) + len(test_ind) == N
        assert self._check_splits(N, [train_ind, val_ind, test_ind],
                                  [train_ratio, val_ratio, test_ratio])

        size_split = {'train': train_ind, 'val': val_ind, 'test': test_ind}
        torch.save(size_split,
                   osp.join(self.folder, f'{self.subset}_num-atoms_split_dict.pt'))

    def _check_splits(self, N, splits, ratios):
        """ Check whether splits intersect and raise error if so.
        """
        assert sum([len(split) for split in splits]) == N
        for ii, split in enumerate(splits):
            true_ratio = len(split) / N
            assert abs(true_ratio - ratios[ii]) < 3 / N
        for i in range(len(splits) - 1):
            for j in range(i + 1, len(splits)):
                n_intersect = len(set(splits[i]) & set(splits[j]))
                if n_intersect != 0:
                    raise ValueError(
                        f"Splits must not have intersecting indices: "
                        f"split #{i} (n = {len(splits[i])}) and "
                        f"split #{j} (n = {len(splits[j])}) have "
                        f"{n_intersect} intersecting indices"
                    )
        return True

    def get_idx_split(self, split_name):
        """ Get dataset splits.
        Args:
            split_name: Split type: 'shuffle', 'num-atoms'
        Returns:
            Dict with 'train', 'val', 'test', splits indices.
        """
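        # The file name must match what create_shuffle_split and
        # create_numatoms_split write, which keeps the hyphen in 'num-atoms'.
        split_file = osp.join(
            self.folder, f"{self.subset}_{split_name}_split_dict.pt"
        )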
        split_dict = replace_numpy_with_torchtensor(torch.load(split_file))
        return split_dict
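
The two split strategies used by the class above can be sketched without NumPy or PyG. This is a simplified illustration, not the notebook's implementation: `shuffle_split` and `size_split` are hypothetical standalone functions, they use Python's `random` module instead of `np.random.default_rng`, and they size the validation and test sets with `round` so the counts do not depend on floating-point truncation.

```python
import random

def _counts(n, val_ratio, test_ratio):
    """Number of train/val/test items for a dataset of size n."""
    n_val = round(val_ratio * n)
    n_test = round(test_ratio * n)
    return n - n_val - n_test, n_val, n_test

def shuffle_split(n, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Random train/val/test split of the indices 0..n-1."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    n_train, n_val, _ = _counts(n, val_ratio, test_ratio)
    return idx[:n_train], idx[n_train:n_train + n_val], idx[n_train + n_val:]

def size_split(num_atoms, val_ratio=0.05, test_ratio=0.05, seed=42):
    """Train on the smallest molecules; shuffle the largest into val/test."""
    order = sorted(range(len(num_atoms)), key=lambda i: num_atoms[i])
    n_train, n_val, _ = _counts(len(order), val_ratio, test_ratio)
    held_out = order[n_train:]
    random.Random(seed).shuffle(held_out)
    return order[:n_train], held_out[:n_val], held_out[n_val:]

train, val, test = shuffle_split(100)
print(len(train), len(val), len(test))

# Synthetic molecule sizes, only for demonstration.
sizes = [(7 * i) % 50 + 1 for i in range(100)]
s_train, s_val, s_test = size_split(sizes)
print(max(sizes[i] for i in s_train), min(sizes[i] for i in s_val + s_test))
```

In the size-based split every held-out molecule is at least as large as any training molecule, which tests extrapolation to larger graphs rather than in-distribution generalisation.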



In [None]:
def preformat_TDC_DTI(dataset_dir, name):
    """Load TDC Drug Target Interaction dataset.
    
    Args:
       dataset_dir: path where to store the cached dataset
       name: dataset name followed by the split type after the first hyphen,
             i.e. '<dataset>-shuffle' or '<dataset>-num-atoms'
    Returns:
       PyG dataset object
    """
 
    split_name = name.split('-', 1)[1]
    dataset = TDC_DTI_data(dataset_dir, subset='530k')
    # Inductive graph-level split (there is no train/test edge split).
    s_dict = dataset.get_idx_split(split_name)
    dataset.split_idxs = [s_dict[s] for s in ['train', 'val', 'test']]
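    # `cfg` and `structured_neg_sampling_transform` are not defined in this
    # notebook; they are expected to come from the GraphGPS code base.
    if cfg.dataset.resample_negative:
        dataset.transform = structured_neg_sampling_transform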
    return dataset
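
The loader above derives the split type from everything after the first hyphen in `name`. A small sketch of that parsing with an explicit error for names that carry no split suffix; `parse_split_name` is a hypothetical helper and the `TDC` prefix is only an illustrative dataset name:

```python
def parse_split_name(name):
    """Return the split type encoded after the first hyphen of name."""
    _prefix, sep, split_name = name.partition("-")
    if not sep:
        raise ValueError(f"expected '<dataset>-<split>', got {name!r}")
    return split_name

print(parse_split_name("TDC-shuffle"))
print(parse_split_name("TDC-num-atoms"))
```

Splitting only on the first hyphen is what keeps `num-atoms` intact as a single split name.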



## ADME Regression


### **Caco-2 (Cell Effective Permeability)**
`Dataset Description`: The human colon epithelial cancer cell line, Caco-2, is used as an in vitro model to simulate the human intestinal tissue. The experimental result on the rate of drug passing through the Caco-2 cells can approximate the rate at which the drug permeates through the human intestinal tissue.

`Task Description`: Regression. Given a drug SMILES string, predict the Caco-2 cell effective permeability.

`Dataset Statistics`: 906 drugs.



In [None]:
%cd ../
%cd /content/content/GraphGPS
#!conda activate graphgps

# Run GPS with RWSE using the Caco-2 config.
!python main.py --cfg configs/GPS/TDC-DTI-Caco-GPS+RWSE.yaml  wandb.use False



/content/content
/content/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-21 13:12:57.138133
Processing...
TDC data Caco2_Wang from ADME is being processed...
Found local copy...
Loading...
Done!
Converting SMILES strings into graphs...
100% 910/910 [00:01<00:00, 793.47it/s]
Saving...
Done!
[*] Loaded dataset 'PyG-TDC-Caco2_Wang' from 'PyG-TDC':
  Data(edge_index=[2, 57400], edge_attr=[57400, 3], x=[26722, 9], y=[910])
  undirected: True
  num graphs: 910
  avg num_nodes/graph: 29
  num node features: 9
  num edge features: 3
  num classes: (appears to be a regression task)
Parsed RWSE PE kernel times / steps: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
100% 910/910 [00:00<00:00, 1028.37it/s]
Done! Took 00:00:00.90
GPSModel(
  (encoder): FeatureEncoder(
    (node_e

### **Lipophilicity, AstraZeneca**
`Dataset Description`: Lipophilicity measures the ability of a drug to dissolve in a lipid (e.g. fats, oils) environment. High lipophilicity often leads to a high rate of metabolism, poor solubility, high turnover, and low absorption. From MoleculeNet.

`Task Description`: Regression. Given a drug SMILES string, predict its lipophilicity.

`Dataset Statistics`: 4,200 drugs.



In [None]:
%cd ../
%cd /content/content/GraphGPS
#!conda activate graphgps

# Run GPS with RWSE using the Lipophilicity (AstraZeneca) config.
!python main.py --cfg configs/GPS/TDC-DTI-LAZ-GPS+RWSE.yaml  wandb.use False

/content/content
/content/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-21 13:28:30.658216
Processing...
TDC data Lipophilicity_AstraZeneca from ADME is being processed...
Downloading...
100% 298k/298k [00:00<00:00, 1.16MiB/s]
Loading...
Done!
Converting SMILES strings into graphs...
100% 4200/4200 [00:05<00:00, 817.60it/s]
Saving...
Done!
[*] Loaded dataset 'PyG-TDC-Lipophilicity_AstraZeneca' from 'PyG-TDC':
  Data(edge_index=[2, 247798], edge_attr=[247798, 3], x=[113568, 9], y=[4200])
  undirected: True
  num graphs: 4200
  avg num_nodes/graph: 27
  num node features: 9
  num edge features: 3
  num classes: (appears to be a regression task)
Parsed RWSE PE kernel times / steps: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
100% 4200/4200 [00:04<00:00, 1000.59it/s]
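
The log above precomputes RWSE (random-walk structural encoding) for walk lengths 1 to 16. Assuming the usual definition, where the encoding of a node is its probability of returning to itself after k random-walk steps, a small pure-Python sketch looks as follows. `rwse` is a hypothetical function written for illustration; the GraphGPS implementation is vectorised and may differ in details.

```python
def rwse(adjacency, steps):
    """Return, per node, the k-step return probabilities for k = 1..steps."""
    n = len(adjacency)
    # Row-normalise the adjacency matrix into a random-walk transition matrix.
    transition = []
    for row in adjacency:
        degree = sum(row)
        transition.append([v / degree if degree else 0.0 for v in row])

    encodings = [[] for _ in range(n)]
    power = [row[:] for row in transition]  # transition to the power 1
    for _ in range(steps):
        for i in range(n):
            encodings[i].append(power[i][i])  # probability of being back at i
        # Advance to the next power: power = power @ transition
        power = [
            [sum(power[i][k] * transition[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)
        ]
    return encodings

# Path graph 0 - 1 - 2
path = [[0, 1, 0],
        [1, 0, 1],
        [0, 1, 0]]
print(rwse(path, steps=2))
```

On the path graph the middle node always returns after two steps, while each end node returns only half of the time, so the encoding separates nodes by their structural role.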

### **Solubility, AqSolDB**
`Dataset Description`: Aqueous solubility measures a drug's ability to dissolve in water. Poor water solubility can lead to slow drug absorption and inadequate bioavailability, and can even induce toxicity. More than 40% of new chemical entities are not soluble.

`Task Description`: Regression. Given a drug SMILES string, predict its aqueous solubility.

`Dataset Statistics`: 9,982 drugs.

In [None]:
%cd ../
%cd /content/content/GraphGPS
#!conda activate graphgps

# Run GPS with RWSE using the Solubility (AqSolDB) config.
!python main.py --cfg configs/GPS/TDC-DTI-Solubility-GPS+RWSE.yaml  wandb.use False

/content/content
/content/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-21 13:38:04.212406
Processing...
TDC data Solubility_AqSolDB from ADME is being processed...
Downloading...
100% 853k/853k [00:00<00:00, 1.99MiB/s]
Loading...
Done!
Converting SMILES strings into graphs...
100% 9982/9982 [00:09<00:00, 1005.01it/s]
Saving...
Done!
[*] Loaded dataset 'PyG-TDC-Solubility_AqSolDB' from 'PyG-TDC':
  Data(edge_index=[2, 352000], edge_attr=[352000, 3], x=[173500, 9], y=[9982])
  undirected: True
  num graphs: 9982
  avg num_nodes/graph: 17
  num node features: 9
  num edge features: 3
  num classes: (appears to be a regression task)
Parsed RWSE PE kernel times / steps: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
100% 9982/9982 [00:10<00:00, 924.53it/s]
Done! Took 00

### **Hydration Free Energy, FreeSolv**
`Dataset Description`: The Free Solvation Database, FreeSolv (SAMPL), provides experimental and calculated hydration free energies of small molecules in water. The calculated values are derived from alchemical free energy calculations using molecular dynamics simulations. From MoleculeNet.

`Task Description`: Regression. Given a drug SMILES string, predict its hydration free energy.

`Dataset Statistics`: 642 drugs.

In [None]:
%cd ../
%cd /content/content/GraphGPS
#!conda activate graphgps

# Run GPS with RWSE using the Hydration Free Energy (FreeSolv) config.
!python main.py --cfg configs/GPS/TDC-DTI-Hydra-GPS+RWSE.yaml  wandb.use False

/content/content
/content/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-21 13:49:45.614614
Processing...
TDC data HydrationFreeEnergy_FreeSolv from ADME is being processed...
Downloading...
100% 29.0k/29.0k [00:00<00:00, 18.2MiB/s]
Loading...
Done!
Converting SMILES strings into graphs...
100% 642/642 [00:00<00:00, 1188.40it/s]
Saving...
Done!
[*] Loaded dataset 'PyG-TDC-HydrationFreeEnergy_FreeSolv' from 'PyG-TDC':
  Data(edge_index=[2, 10770], edge_attr=[10770, 3], x=[5600, 9], y=[642])
  undirected: True
  num graphs: 642
  avg num_nodes/graph: 8
  num node features: 9
  num edge features: 3
  num classes: (appears to be a regression task)
Parsed RWSE PE kernel times / steps: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
100% 642/642 [00:00<00:00, 1159.73it/s]
D

## Protein Sequence Encoding

In [None]:
import os, sys, math

codes = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
         'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

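def one_hot_encode(seq):
    """One-hot encode a peptide over the 20 standard amino acids.

    Assumes seq only contains letters from `codes`. Returns a flat array of
    length 20 * len(seq), one block of 20 entries per residue.
    """
    o = list(set(codes) - set(seq))  # codes that do not occur in seq
    s = pd.DataFrame(list(seq))
    # Zero columns for the absent codes so every row spans the full alphabet.
    x = pd.DataFrame(np.zeros((len(seq),len(o)),dtype=int),columns=o)
    a = s[0].str.get_dummies(sep=',')  # indicator columns for codes present
    a = a.join(x)
    a = a.sort_index(axis=1)  # alphabetical order, same as `codes`
    e = a.values.flatten()
    return e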

pep='ALDFEQEMT'
e=one_hot_encode(pep)
e

array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 0, 0, 0])
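
The same flat encoding can be produced with a plain index lookup, which makes the layout of the vector easier to see. This is an alternative sketch in pure Python, not the notebook's implementation; `CODES`, `INDEX`, and `one_hot_flat` are names introduced here.

```python
CODES = "ACDEFGHIKLMNPQRSTVWY"  # same 20-letter alphabet as `codes`
INDEX = {aa: i for i, aa in enumerate(CODES)}

def one_hot_flat(seq):
    """Flat one-hot vector of length 20 * len(seq)."""
    vec = [0] * (len(CODES) * len(seq))
    for pos, aa in enumerate(seq):
        # Residue `pos` owns the block [20 * pos, 20 * pos + 20).
        vec[pos * len(CODES) + INDEX[aa]] = 1  # KeyError on non-standard codes
    return vec

vec = one_hot_flat("ALDFEQEMT")
print([i for i, v in enumerate(vec) if v])
```

The printed positions are the indices of the ones in the array shown above, one per residue.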

In [None]:
from tdc.multi_pred import PPI
data = PPI(name = 'HuRI')
data=data.get_data()
data

Found local copy...
Loading...
Done!


Unnamed: 0,Protein1_ID,Protein1,Protein2_ID,Protein2,Y
0,ENSG00000000005,MAKNPPENCEDCHILNAEAFKSKKICKSLKICGLVFGILALTLIVL...,ENSG00000061656,MRRSSRPGSASSSRKHTPNFFSENSSMSITSEDSKGLRSAEPGPGE...,1
1,ENSG00000000005,MAKNPPENCEDCHILNAEAFKSKKICKSLKICGLVFGILALTLIVL...,ENSG00000099968,MASSSTVPLGFHYETKYVVLSYLGLLSQEKLQEQHLSSPQGVQLDI...,1
2,ENSG00000000005,MAKNPPENCEDCHILNAEAFKSKKICKSLKICGLVFGILALTLIVL...,ENSG00000104765,MSSHLVEPPPPLHNNNNNCEENEQSLPPPAGLNSSWVELPMNSSNG...,1
3,ENSG00000000005,MAKNPPENCEDCHILNAEAFKSKKICKSLKICGLVFGILALTLIVL...,ENSG00000105383,MPLLLLLPLLWAGALAMDPNFWLQVQESVTVQEGLCVLVPCTFFHP...,1
4,ENSG00000000005,MAKNPPENCEDCHILNAEAFKSKKICKSLKICGLVFGILALTLIVL...,ENSG00000114455,MKAQTALSFFLILITSLSGSQGIFPLAFFIYVPMNEQIVIGRLDED...,1
...,...,...,...,...,...
52364,ENSG00000273899,MGRNKKKKRDGDDRRPRLVLSFDEEKRREYLTGFHKRKVERKKAAI...,ENSG00000273899,MGRNKKKKRDGDDRRPRLVLSFDEEKRREYLTGFHKRKVERKKAAI...,1
52365,ENSG00000275302,MKLCVTVLSLLMLVAAFCSPALSAPMGSDPPTACCFSYTARKLPRN...,ENSG00000278619,MEVMDVFSTDDLTGFLQTKAQQGWLVAGTVGCPSTEDPQSSEIPIM...,1
52366,ENSG00000275774,MASNVTNKMDPHSVNSRVFIGNLNTLVVKKSDVEAIFSKYGKIAGC...,ENSG00000275774,MASNVTNKMDPHSVNSRVFIGNLNTLVVKKSDVEAIFSKYGKIAGC...,1
52367,ENSG00000276070,MKLCVTVLSLLVLVAAFCSLALSAPMGSDPPTACCFSYTARKLPRN...,ENSG00000278619,MEVMDVFSTDDLTGFLQTKAQQGWLVAGTVGCPSTEDPQSSEIPIM...,1


In [None]:
codes = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
         'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

def create_dict(codes):
  char_dict = {}
  for index, val in enumerate(codes):
    char_dict[val] = index+1

  return char_dict

char_dict = create_dict(codes)

def integer_encoding(data):
  """
  - Encodes each amino-acid sequence as an array of integers.
  - The 20 standard amino acids map to 1-20;
    any other code maps to 0.
  """
  
  encode_list = []
  for row in data:
    row_encode = []
    for code in row:
      row_encode.append(char_dict.get(code, 0))
    encode_list.append(np.array(row_encode))
  
  return encode_list

e=integer_encoding(data['Protein1'].values)
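
For a quick look at what this encoding produces, here is a dependency-free restatement for a single sequence. `CHAR_TO_INT` and `encode` are names introduced for this sketch, and it returns plain lists instead of NumPy arrays.

```python
CODES = "ACDEFGHIKLMNPQRSTVWY"
CHAR_TO_INT = {aa: i + 1 for i, aa in enumerate(CODES)}  # 0 is reserved

def encode(seq):
    """Map each residue to 1..20, or 0 for any code outside CODES."""
    return [CHAR_TO_INT.get(aa, 0) for aa in seq]

print(encode("MAKNPP"))  # first residues of the first HuRI protein
print(encode("AXU"))     # non-standard codes collapse to 0
```

Reserving 0 for unknown residues also leaves it available as a padding value if sequences are later batched to a common length.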

In [None]:
#np.array(np.concatenate(e[0],None), dtype = np.int64)

array([11,  1,  9, 12, 13, 13,  4, 12,  2,  4,  3,  2,  7,  8, 10, 12,  1,
        4,  1,  5,  9, 16,  9,  9,  8,  2,  9, 16, 10,  9,  8,  2,  6, 10,
       18,  5,  6,  8, 10,  1, 10, 17, 10,  8, 18, 10,  5, 19,  6, 16,  9,
        7,  5, 19, 13,  4, 18, 13,  9,  9,  1, 20,  3, 11,  4,  7, 17,  5,
       20, 16, 12,  6,  4,  9,  9,  9,  8, 20, 11,  4,  8,  3, 13, 18, 17,
       15, 17,  4,  8,  5, 15, 16,  6, 12,  6, 17,  3,  4, 17, 10,  4, 18,
        7,  3,  5,  9, 12,  6, 20, 17,  6,  8, 20,  5, 18,  6, 10, 14,  9,
        2,  5,  8,  9, 17, 14,  8,  9, 18,  8, 13,  4,  5, 16,  4, 13,  4,
        4,  4,  8,  3,  4, 12,  4,  4,  8, 17, 17, 17,  5,  5,  4, 14, 16,
       18,  8, 19, 18, 13,  1,  4,  9, 13,  8,  4, 12, 15,  3,  5, 10,  9,
       12, 16,  9,  8, 10,  4,  8,  2,  3, 12, 18, 17, 11, 20, 19,  8, 12,
       13, 17, 10,  8, 16, 18, 16,  4, 10, 14,  3,  5,  4,  4,  4,  6,  4,
        3, 10,  7,  5, 13,  1, 12,  4,  9,  9,  6,  8,  4, 14, 12,  4, 14,
       19, 18, 18, 13, 14

In [None]:
#np.concatenate([np.array([[5,4,3,2,1],[9,8,7,6,5]],np.int64),[0,1]])
a=np.array([[5,4,3,2,1],[9,8,7,6,5]],np.int64)
seq=[0,1]

feats=np.zeros(shape=(int(a.shape[0])+len(seq),int(a.shape[1])+len(seq)))
for i in range(a.shape[0]):
  concate=np.concatenate([a[i],seq])
  feats[i]=concate
  print(concate)
  print('\n')

feats

[5 4 3 2 1 0 1]


[9 8 7 6 5 0 1]




array([[5., 4., 3., 2., 1., 0., 1.],
       [9., 8., 7., 6., 5., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0.]])
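
The cell above allocates `feats` with `len(seq)` extra rows, and those rows stay zero. If only the widened node rows are wanted, a simpler sketch is to append the sequence features to every row directly. `append_to_rows` is a hypothetical helper, shown here with plain lists.

```python
def append_to_rows(rows, extra):
    """Return a new matrix where `extra` is appended to every row."""
    return [list(row) + list(extra) for row in rows]

node_feats = [[5, 4, 3, 2, 1], [9, 8, 7, 6, 5]]
seq = [0, 1]
widened = append_to_rows(node_feats, seq)
print(widened)
```

Each row grows by `len(seq)` columns while the number of rows is unchanged, so the node count of the graph is preserved.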

In [None]:
from ogb.utils.features import (allowable_features, atom_to_feature_vector,
 bond_to_feature_vector, atom_feature_vector_to_dict, bond_feature_vector_to_dict) 
from rdkit import Chem
import numpy as np

def smiles2graph(smiles_string):
    """
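    Converts a SMILES string to a graph dictionary
    :input: SMILES string (str)
    :return: dict with keys 'edge_index', 'edge_feat', 'node_feat', 'num_nodes'
    """

    mol = Chem.MolFromSmiles(smiles_string)
    if mol is None:  # RDKit returns None when the SMILES cannot be parsed
        raise ValueError(f'Could not parse SMILES: {smiles_string!r}')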

    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append(atom_to_feature_vector(atom))
    x = np.array(atom_features_list, dtype = np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype = np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype = np.int64)

    else:   # mol has no bonds
        edge_index = np.empty((2, 0), dtype = np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype = np.int64)

    graph = dict()
    graph['edge_index'] = edge_index
    graph['edge_feat'] = edge_attr
    graph['node_feat'] = x
    graph['num_nodes'] = len(x)

    return graph 
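
To make the edge layout concrete without requiring RDKit, here is a minimal sketch of the same bidirectional COO construction from a hand-written bond list. The helper `bonds_to_coo`, the ethanol atom numbering and the placeholder bond features are assumptions for illustration, not output of `smiles2graph`:

```python
import numpy as np

def bonds_to_coo(bonds, bond_feats, num_bond_features=3):
    """Build edge_index [2, 2*B] and edge_attr [2*B, F] from B undirected bonds."""
    edges, feats = [], []
    for (i, j), f in zip(bonds, bond_feats):
        # Store every bond in both directions, sharing one feature vector.
        edges.append((i, j))
        feats.append(f)
        edges.append((j, i))
        feats.append(f)
    if not edges:  # molecule without bonds
        return (np.empty((2, 0), dtype=np.int64),
                np.empty((0, num_bond_features), dtype=np.int64))
    return np.array(edges, dtype=np.int64).T, np.array(feats, dtype=np.int64)

# Ethanol heavy atoms (assumed numbering): 0 = C, 1 = C, 2 = O, two bonds.
bonds = [(0, 1), (1, 2)]
bond_feats = [[0, 0, 0], [0, 0, 0]]  # placeholder feature values
edge_index, edge_attr = bonds_to_coo(bonds, bond_feats)
print(edge_index)
print(edge_attr.shape)
```

Two bonds yield four directed edges, so `edge_index` has shape `(2, 4)` and `edge_attr` has shape `(4, 3)`, which is the invariant the `assert` statements in the next cell check.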

In [None]:
import ast
import hashlib
import os.path as osp
import pickle
import shutil

import pandas as pd
import torch
from ogb.utils import smiles2graph  # NOTE: shadows the smiles2graph defined in the previous cell
from ogb.utils.torch_util import replace_numpy_with_torchtensor
from ogb.utils.url import decide_download
from torch_geometric.data import Data, InMemoryDataset, download_url
from tqdm import tqdm

def process():
    """Convert each row of the peptide CSV into a PyG Data object and return the list."""
    data_df = pd.read_csv("/content/peptide_multi_class_dataset.csv")
    smiles_list = data_df['smiles']
    seq_list = data_df['peptide_seq']

    print('Converting SMILES and sequence strings into graphs...')
    data_list = []
    for i in tqdm(range(len(smiles_list))):
        data = Data()

        smiles = smiles_list[i]
        graph = smiles2graph(smiles)
        print(graph['node_feat'])

        assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
        assert (len(graph['node_feat']) == graph['num_nodes'])

        # integer_encoding is assumed to be defined in an earlier cell.
        # The encoded sequence is only printed here; it is not yet attached to `data`.
        seq = seq_list[i]
        seq = integer_encoding(seq)
        print('sequence: \n', seq)

        data.__num_nodes__ = int(graph['num_nodes'])
        data.edge_index = torch.from_numpy(graph['edge_index']).to(
            torch.int64)
        data.edge_attr = torch.from_numpy(graph['edge_feat']).to(
            torch.int64)
        data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
        # Labels are assumed to be stored as Python literals (e.g. a list);
        # ast.literal_eval parses them without executing arbitrary code.
        data.y = torch.Tensor([ast.literal_eval(data_df['labels'].iloc[i])])

        data_list.append(data)
    return data_list

data_list = process()
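
`process()` calls `integer_encoding`, which is not defined in this part of the notebook. As a rough sketch of what such a helper might look like, the version below maps the 20 standard amino acids to the integers 1 to 20 and anything else to 0. The alphabet ordering, the name `integer_encoding_sketch` and the `unknown` fallback are assumptions and may differ from the helper actually used:

```python
# 20 standard residues; this ordering is an assumption, not taken from the notebook.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INT = {aa: idx for idx, aa in enumerate(AMINO_ACIDS, start=1)}

def integer_encoding_sketch(seq, unknown=0):
    """Map each residue letter to an integer code; unknown letters map to `unknown`."""
    return [AA_TO_INT.get(aa, unknown) for aa in seq.upper()]

codes = integer_encoding_sketch("ACDX")
print(codes)  # [1, 2, 3, 0]
```

Reserving 0 for unknown residues leaves that index free for padding if sequences are later batched to a common length.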

In [None]:
# Read the NLF matrix from a CSV file on GitHub
nlf = pd.read_csv('https://raw.githubusercontent.com/dmnfarrell/epitopepredict/master/epitopepredict/mhcdata/NLF.csv',index_col=0)

def nlf_encode(seq):    
    x = pd.DataFrame([nlf[i] for i in seq]).reset_index(drop=True)  
    e = x.values.flatten()
    return e

e=nlf_encode(pep)
e

array([0.42, 2.07, 0.67, 0.01, 1.1 , 0.32, 0.2 , 0.09, 0.2 , 0.09, 0.11,
       0.15, 0.01, 0.06, 0.02, 0.16, 0.07, 0.03, 1.29, 1.21, 0.25, 0.96,
       0.18, 0.06, 0.04, 0.  , 0.09, 0.26, 0.18, 0.05, 0.  , 0.11, 0.01,
       0.15, 0.14, 0.02, 0.81, 0.13, 1.36, 0.63, 0.15, 0.1 , 0.45, 0.31,
       0.1 , 0.03, 0.15, 0.02, 0.16, 0.12, 0.07, 0.11, 0.01, 0.05, 2.37,
       0.23, 0.09, 0.37, 0.19, 0.04, 0.03, 0.06, 0.14, 0.14, 0.1 , 0.03,
       0.21, 0.04, 0.09, 0.03, 0.06, 0.18, 1.56, 0.48, 0.87, 0.02, 0.07,
       0.13, 0.22, 0.15, 0.09, 0.1 , 0.04, 0.05, 0.12, 0.28, 0.03, 0.09,
       0.06, 0.02, 1.71, 1.11, 0.08, 0.15, 0.11, 0.45, 0.11, 0.08, 0.02,
       0.25, 0.12, 0.25, 0.2 , 0.16, 0.01, 0.07, 0.02, 0.03, 1.56, 0.48,
       0.87, 0.02, 0.07, 0.13, 0.22, 0.15, 0.09, 0.1 , 0.04, 0.05, 0.12,
       0.28, 0.03, 0.09, 0.06, 0.02, 1.72, 0.85, 0.34, 0.44, 0.01, 0.8 ,
       0.16, 0.05, 0.05, 0.3 , 0.29, 0.06, 0.02, 0.03, 0.03, 0.09, 0.14,
       0.  , 0.3 , 0.68, 0.88, 0.23, 0.1 , 0.23, 0.

In [None]:
!pip install epitopepredict
import epitopepredict as ep

blosum = ep.blosum62

def blosum_encode(seq):
    # Encode a peptide into BLOSUM62 features: one vector per residue, flattened to 1-D
    x = pd.DataFrame([blosum[i] for i in seq]).reset_index(drop=True)
    e = x.values.flatten()
    return e

def random_encode(p):
    # Random baseline: one integer in [0, 20) per residue of p
    return [np.random.randint(20) for _ in p]

e=blosum_encode(pep)
e


array([ 4, -1, -2, -2,  0, -1, -1,  0, -2, -1, -1, -1, -1, -2, -1,  1,  0,
       -3, -2,  0, -2, -1,  0, -4, -1, -2, -3, -4, -1, -2, -3, -4, -3,  2,
        4, -2,  2,  0, -3, -2, -1, -2, -1,  1, -4, -3, -1, -4, -2, -2,  1,
        6, -3,  0,  2, -1, -1, -3, -4, -1, -3, -3, -1,  0, -1, -4, -3, -3,
        4,  1, -1, -4, -2, -3, -3, -3, -2, -3, -3, -3, -1,  0,  0, -3,  0,
        6, -4, -2, -2,  1,  3, -1, -3, -3, -1, -4, -1,  0,  0,  2, -4,  2,
        5, -2,  0, -3, -3,  1, -2, -3, -1,  0, -1, -3, -2, -2,  1,  4, -1,
       -4, -1,  1,  0,  0, -3,  5,  2, -2,  0, -3, -2,  1,  0, -3, -1,  0,
       -1, -2, -1, -2,  0,  3, -1, -4, -1,  0,  0,  2, -4,  2,  5, -2,  0,
       -3, -3,  1, -2, -3, -1,  0, -1, -3, -2, -2,  1,  4, -1, -4, -1, -1,
       -2, -3, -1,  0, -2, -3, -2,  1,  2, -1,  5,  0, -2, -1, -1, -1, -1,
        1, -3, -1, -1, -4,  0, -1,  0, -1, -1, -1, -1, -2, -2, -1, -1, -1,
       -1, -2, -1,  1,  5, -2, -2,  0, -1, -1,  0, -4])
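
Both `nlf_encode` and `blosum_encode` follow the same pattern: look up one feature vector per residue, then flatten the result into a single 1-D array. A minimal sketch of that pattern with a made-up lookup table (the values in `TOY_TABLE` are invented for illustration and are not real NLF or BLOSUM62 entries):

```python
import numpy as np

# Toy 3-letter alphabet with invented 2-D feature vectors.
TOY_TABLE = {"A": [1.0, 0.0], "C": [0.0, 1.0], "D": [0.5, 0.5]}

def table_encode(seq, table):
    """Look up one feature vector per residue and flatten to 1-D."""
    rows = np.array([table[aa] for aa in seq], dtype=float)
    return rows.flatten()

e_toy = table_encode("ACDA", TOY_TABLE)
print(e_toy.shape)  # (8,)
```

The flattened length is the peptide length times the per-residue feature size, so peptides of different lengths produce vectors of different lengths and need padding or truncation before they can share a fixed-size input.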

In [None]:
%cd ../
%cd /content/content/GraphGPS

# Run GPS on Peptides-func (this config uses LapPE positional encodings).
!python main.py --cfg configs/GPS/peptides-func-GPS.yaml wandb.use False

/content/content
/content/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-21 14:48:58.638062
[*] Loaded dataset 'peptides-functional' from 'OGB':
  Data(edge_index=[2, 4773974], edge_attr=[4773974, 3], x=[2344859, 9], y=[15535, 10])
  undirected: True
  num graphs: 15535
  avg num_nodes/graph: 150
  num node features: 9
  num edge features: 3
  num classes: 10
Precomputing Positional Encoding statistics: ['LapPE'] for all graphs...
  ...estimated to be undirected: True
100% 15535/15535 [01:53<00:00, 136.51it/s]
Done! Took 00:01:54.49
GPSModel(
  (encoder): FeatureEncoder(
    (node_encoder): Concat2NodeEncoder(
      (encoder1): AtomEncoder(
        (atom_embedding_list): ModuleList(
          (0): Embedding(119, 80)
          (1): Embedding(4, 80)
          (2): Embedding(12, 80)
          (3): Embedding(12, 80)
          (4): Embedding(10, 80)
          (5): Embedding(6, 80)
       

In [None]:
%cd ../
%cd GraphGPS

# Run GPS with RWSE positional encodings on the TDC regression dataset.
!python main.py --cfg configs/GPS/TDC-DTI-GPS+RWSE.yaml wandb.use False

/content
/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-19 18:05:30.848411
Processing...
Found local copy...
Loading...
Done!
Converting SMILES strings into graphs...
100% 910/910 [00:02<00:00, 443.69it/s]
Saving...
Done!
[*] Loaded dataset 'PyG-TDC' from 'PyG-TDC':
  Data(edge_index=[2, 57400], edge_attr=[57400, 3], x=[26722, 9], y=[910])
  undirected: True
  num graphs: 910
  avg num_nodes/graph: 29
  num node features: 9
  num edge features: 3
  num classes: (appears to be a regression task)
Parsed RWSE PE kernel times / steps: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
Precomputing Positional Encoding statistics: ['RWSE'] for all graphs...
  ...estimated to be undirected: True
100% 910/910 [00:01<00:00, 730.52it/s]
Done! Took 00:00:01.27
GPSModel(
  (encoder): FeatureEncoder(
    (node_encoder): Concat2NodeEncoder(
      (encoder1): AtomEncoder(
        (atom_embedd

In [None]:
%cd ../
%cd GraphGPS

# Run GPS on Peptides-func (this config uses LapPE positional encodings).
!python main.py --cfg configs/GPS/peptides-func-GPS.yaml wandb.use False

/content
/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-19 22:38:41.612841
Downloading https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1
Downloading https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1
Processing...
Converting SMILES strings into graphs...
100% 15535/15535 [01:07<00:00, 228.86it/s]
Saving...
Done!
[*] Loaded dataset 'peptides-functional' from 'OGB':
  Data(edge_index=[2, 4773974], edge_attr=[4773974, 3], x=[2344859, 9], y=[15535, 10])
  undirected: True
  num graphs: 15535
  avg num_nodes/graph: 150
  num node features: 9
  num edge features: 3
  num classes: 10
Precomputing Positional Encoding statistics: ['LapPE'] for all graphs...
  ...estimated to be undirected: True
100% 15535/15535 [01:55<00:00, 134.72it/s]
Done! Took 00:01:56.02
GPSModel(
  (encoder): FeatureEncoder(
    (node_encoder): Concat2

In [None]:
%cd ../
%cd GraphGPS

# Run GPS on Peptides-struct (this config uses LapPE positional encodings).
!python main.py --cfg configs/GPS/peptides-struct-GPS.yaml wandb.use False

/content
/content/GraphGPS
GPU Mem: [2]
GPU Prob: [1.]
Random select GPU, select GPU 0 with mem: 2
[*] Run ID 0: seed=0, split_index=0
    Starting now: 2022-09-20 00:03:47.307937
Downloading https://www.dropbox.com/s/0d4aalmq4b4e2nh/peptide_structure_normalized_dataset.csv.gz?dl=1
Downloading https://www.dropbox.com/s/9dfifzft1hqgow6/splits_random_stratified_peptide_structure.pickle?dl=1
Processing...
Converting SMILES strings into graphs...
100% 15535/15535 [01:17<00:00, 199.72it/s]
Saving...
Done!
[*] Loaded dataset 'peptides-structural' from 'OGB':
  Data(edge_index=[2, 4773974], edge_attr=[4773974, 3], x=[2344859, 9], y=[15535, 11])
  undirected: True
  num graphs: 15535
  avg num_nodes/graph: 150
  num node features: 9
  num edge features: 3
  num classes: 11
Precomputing Positional Encoding statistics: ['LapPE'] for all graphs...
  ...estimated to be undirected: True
100% 15535/15535 [01:55<00:00, 134.01it/s]
Done! Took 00:01:56.62
GPSModel(
  (encoder): FeatureEncoder(
    (nod