In [21]:
import numpy as np
import dgl
from dgllife.utils import (
    smiles_to_bigraph,
    CanonicalAtomFeaturizer,
    CanonicalBondFeaturizer)
from dgllife.model import load_pretrained

import pandas as pd
from celldreamer.paths import PERT_DATA_DIR

In [34]:
pwd

'/home/icb/alessandro.palma/celldreamer/embeddings/MPNN'

Read sciplex3 data

In [22]:
smiles = pd.read_csv(PERT_DATA_DIR / 'sciplex' / 'sciplex.smiles')

Get unique SMILES 

In [23]:
unique_SMILES = np.unique(smiles.SMILES)
unique_SMILES

array(['*.Cl.N[C@@H]1C[C@H]1c1ccccc1',
       'C#Cc1cccc(Nc2ncnc3cc(OC)c(OCCCCCCC(=O)NO)cc23)c1',
       'C/C(=C\\c1csc(C)n1)C1CC2OC2(C)CCCC(C)C(O)C(C)C(=O)C(C)(C)C(O)CC(=O)O1',
       'C/C(=C\\c1csc(C)n1)C1CC2OC2CCCC(C)C(O)C(C)C(=O)C(C)(C)C(O)CC(=O)O1',
       'C=C1C(=O)OC2/C=C(\\C)CC/C=C(\\C)CCC12',
       'C=CCNC1=C2CC(C)CC(OC)C(O)C(C)/C=C(\\C)C(OC(N)=O)C(OC)/C=C\\C=C(/C)C(=O)NC(=CC1=O)C2=O',
       'CC(/C=C/C(=O)NO)=C\\[C@@H](C)C(=O)c1ccc(N(C)C)cc1',
       'CC(=O)Nc1ccc(C(=O)Nc2ccccc2N)cc1',
       'CC(=O)Nc1ccc(OCC(C)(O)C(=O)Nc2ccc([N+](=O)[O-])c(C(F)(F)F)c2)cc1',
       'CC(=O)Nc1cccc(-n2c(=O)n(C3CC3)c(=O)c3c(Nc4ccc(I)cc4F)n(C)c(=O)c(C)c32)c1',
       'CC(C)(CNC(=O)c1cccc(-c2noc(C(F)(F)F)n2)c1)c1coc(-c2ccccc2)n1',
       'CC(C)[C@H](C(=O)Nc1ccc(C(=O)NO)cc1)c1ccccc1',
       'CC(NC(=O)c1ccccc1/N=C/c1c(O)ccc2ccccc12)c1ccccc1',
       'CC(Oc1cc(-c2cnn(C3CCNCC3)c2)cnc1N)c1c(Cl)ccc(F)c1Cl',
       'CC(c1cc2ccccc2s1)N(O)C(N)=O',
       'CC1(C(=O)Nc2ccc(F)nc2)CCCN1c1nc(Nc2cc(C3CC3)[nH]

Construct molecular graph

In [24]:
node_feats = CanonicalAtomFeaturizer(atom_data_field="h")
edge_feats = CanonicalBondFeaturizer(bond_data_field="h", self_loop=True)

In [25]:
mol_graphs = []

for smiles in unique_SMILES: 
    mol_graphs.append(smiles_to_bigraph(
        smiles=smiles,
        add_self_loop=True,
        node_featurizer=node_feats,
        edge_featurizer=edge_feats,
    ))

In [26]:
mol_graphs

[Graph(num_nodes=12, num_edges=34,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'h': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=32, num_edges=100,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'h': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=35, num_edges=109,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'h': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=34, num_edges=106,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'h': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=17, num_edges=53,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'h': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=42, num_edges=128,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_sche

In [27]:
print(f'Number of molecular graphs: {len(mol_graphs)}')

Number of molecular graphs: 188


Batch the molecular graphs for prediction 

In [28]:
mol_batch = dgl.batch(mol_graphs)

In [29]:
model_name = 'GCN_canonical_PCBA'

Load pre-trained model and predict

In [30]:
model = load_pretrained(model_name)

Downloading GCN_canonical_PCBA_pre_trained.pth from https://data.dgl.ai/dgllife/pre_trained/gcn_canonical_pcba.pth...
Pretrained model loaded


In [31]:
model.eval()

GCNPredictor(
  (gnn): GCN(
    (gnn_layers): ModuleList(
      (0): GCNLayer(
        (graph_conv): GraphConv(in=74, out=128, normalization=none, activation=<function relu at 0x7f4c7b7fa950>)
        (dropout): Dropout(p=0.053320999462421345, inplace=False)
        (res_connection): Linear(in_features=74, out_features=128, bias=True)
        (bn_layer): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): GCNLayer(
        (graph_conv): GraphConv(in=128, out=128, normalization=none, activation=<function relu at 0x7f4c7b7fa950>)
        (dropout): Dropout(p=0.053320999462421345, inplace=False)
        (res_connection): Linear(in_features=128, out_features=128, bias=True)
        (bn_layer): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (readout): WeightedSumAndMax(
    (weight_and_sum): WeightAndSum(
      (atom_weighting): Sequential(
        (0): Linear(in_features=128, out_features=

In [32]:
prediction = model(mol_batch, mol_batch.ndata['h'])

Convert to data frame and save

In [35]:
df = pd.DataFrame(prediction.detach().numpy())
df.index = unique_SMILES
df.to_csv('./data/embeddings.csv')