## ML and Molecules

In [50]:
## !pip install --pre deepchem


In [51]:

import torch
import deepchem as dc
import numpy as np

from rdkit import Chem

## Featurize Molecules

In [52]:

smiles = ['C1CCCCC1', 'O1CCOCC1']   # cyclohexane and dioxane
mols   = [Chem.MolFromSmiles(smile) for smile in smiles]
feat = dc.feat.CircularFingerprint(size=1024)
arr = feat.featurize(mols)



In [53]:
arr.shape

(2, 1024)

## Physiochemical featurization

In [54]:

feat = dc.feat.RDKitDescriptors()
arr  = feat.featurize(mols)


In [55]:
arr.shape

(2, 208)

## Training a model to predict solubility

In [56]:

## tasks, datasets, transformers = dc.molnet.load_delaney(featurizer='GraphConv')
tasks, datasets, transformers = dc.molnet.load_delaney(featurizer="ECFP", splitter="scaffold")
train_dataset, valid_dataset, test_dataset = datasets




In [57]:
print(train_dataset.X.shape)
print(train_dataset.y.shape)

(902, 1024)
(902, 1)


In [58]:
print(valid_dataset.X.shape)
print(valid_dataset.y.shape)

(113, 1024)
(113, 1)


In [59]:
print(test_dataset.X.shape)
print(test_dataset.y.shape)

(113, 1024)
(113, 1)


In [60]:

## the graph conv model (GraphConvModel) seems to be better
## but currently only supported for tensorflow

model = dc.models.MultitaskRegressor(n_tasks=1, n_features=1024, layer_sizes =[4000], dropout=0.5)



In [65]:
model.fit(train_dataset, nb_epoch=1000)

0.02957241773605347

In [66]:
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)

In [67]:

print(model.evaluate(train_dataset, [metric], transformers))
print(model.evaluate(valid_dataset, [metric], transformers))
print(model.evaluate(test_dataset, [metric], transformers))


{'pearson_r2_score': 0.9774312806434012}
{'pearson_r2_score': 0.3031890710094726}
{'pearson_r2_score': 0.5043251356440803}


## Predict

In [71]:

mols = [Chem.MolFromSmiles(s) for s in smiles]
featurizer = dc.feat.CircularFingerprint(size=1024)
x = featurizer.featurize(mols)



In [72]:

predict_solubility = model.predict_on_batch(x)


In [73]:
print(predict_solubility )

[[[-0.27176958]]

 [[ 0.7965225 ]]]
