In [11]:
import pandas as pd
import numpy as np
# import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize
from deepchem.utils.vina_utils import prepare_inputs
import deepchem as dc
from rdkit import Chem
import scipy
import scipy.stats  # `import scipy` alone does not load the stats submodule on older SciPy releases

from deepchem.utils.evaluate import Evaluator
from sklearn.ensemble import RandomForestRegressor

In [3]:
# Read the SMILES strings that are featurized further down.
smiles_path = "../data/structures.csv"
structure_df = pd.read_csv(smiles_path)
structures = np.array(structure_df.loc[:, 'smiles'])

In [4]:
# Measured affinities, one value per row of test.csv. They are matched to
# structures.csv purely by row position.
# NOTE(review): the column header is literally ' G' (leading space); it may be a
# mangled 'ΔG' header — confirm against the CSV.
affinity_df = pd.read_csv("../data/test.csv")
affinities = np.array(affinity_df[' G'])

# One missing value would make the L2 norm NaN and silently turn every target into NaN.
if pd.isna(affinities).any():
    raise ValueError("test.csv contains missing affinity values")
affinity_norm = np.linalg.norm(affinities)
if affinity_norm == 0:
    raise ValueError("all affinities are zero; cannot normalize")

# Scale to unit L2 norm. This divides the whole vector by a single positive constant,
# so correlation-based metrics (Pearson r^2) are unaffected by it.
affinities_normalized = affinities / affinity_norm

In [5]:
# DeepChem circular (Morgan-style) fingerprints with 2048 bits per molecule.
fp_featurizer = dc.feat.CircularFingerprint(size=2048)
mols = [Chem.MolFromSmiles(smiles) for smiles in structures]

# RDKit returns None (rather than raising) for SMILES it cannot parse. Fail here
# with the offending rows instead of letting a None reach the featurizer.
unparsable_rows = [row for row, mol in enumerate(mols) if mol is None]
if unparsable_rows:
    raise ValueError(f"Unparsable SMILES at rows {unparsable_rows}")

features = fp_featurizer.featurize(mols)

In [6]:
# Every molecule is tagged with the same target id.
pdbid = '1O7S'
n_samples = len(features)

# Features (structures.csv) and labels (test.csv) are paired only by row order,
# so a length mismatch means they are out of step.
if len(affinities_normalized) != n_samples:
    raise ValueError(
        f"{n_samples} featurized molecules but {len(affinities_normalized)} affinities"
    )

dataset = dc.data.NumpyDataset(X=features, y=affinities_normalized, ids=[pdbid] * n_samples)
# DeepChem's default train fraction; the seed is fixed so the split is reproducible.
train_dataset, test_dataset = dc.splits.RandomSplitter().train_test_split(dataset, seed=42)

In [7]:
# Random-forest baseline, wrapped so it can be fit directly on DeepChem datasets.
# random_state is passed at construction (the standard sklearn idiom) rather than
# assigned to the estimator afterwards.
seed = 42
sklearn_model = RandomForestRegressor(n_estimators=100, max_features='sqrt', random_state=seed)
model = dc.models.SklearnModel(sklearn_model)
model.fit(train_dataset)

In [8]:
# Pearson r^2 is a squared correlation coefficient, so it is always >= 0,
# unlike sklearn's coefficient-of-determination R^2, which can go negative.
# The "R^2" in the printed labels below refers to this Pearson r^2.
metric = dc.metrics.Metric(dc.metrics.pearson_r2_score)

# Score both splits with the same code path instead of a copy-pasted block per split.
r2_scores = {}
for split_label, split_dataset in (("Train", train_dataset), ("Test", test_dataset)):
    # Third argument: transformers to undo on predictions — none were applied here.
    evaluator = Evaluator(model, split_dataset, [])
    r2_scores[split_label] = evaluator.compute_model_performance([metric])
    print("RF %s set R^2 %f" % (split_label, r2_scores[split_label]["pearson_r2_score"]))

train_r2score = r2_scores["Train"]
test_r2score = r2_scores["Test"]

RF Train set R^2 0.938019
RF Test set R^2 0.230368


In [12]:
# Cross-check the test-set Pearson r^2 reported above by regressing the measured
# (normalized) affinities on the predictions; r_value**2 should match it.
# ravel() guarantees the 1-D inputs linregress expects, even if predictions or
# labels come back as (n, 1) columns.
x = np.ravel(model.predict(test_dataset))
y = np.ravel(test_dataset.y)
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(x, y)
print(r_value**2)
# Two-sided p-value for the null hypothesis that the slope is zero.
print(p_value)

0.23036782003583386
0.03755134148687876


In [22]:
# Same cross-check on the training split: r_value**2 should match the train-set
# Pearson r^2 printed by the evaluation cell.
# ravel() guarantees the 1-D inputs linregress expects, even if predictions or
# labels come back as (n, 1) columns.
x = np.ravel(model.predict(train_dataset))
y = np.ravel(train_dataset.y)
slope, intercept, r_value, p_value, std_err = scipy.stats.linregress(x, y)
print(r_value**2)
# Two-sided p-value for the null hypothesis that the slope is zero.
print(p_value)

0.9380193900606827
7.936737616213138e-46
