In [None]:
import json
import random

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.stats

from molecule.molecule import Group, Molecule, ATOMTYPES
from pysidt import Datum, MultiEvalSubgraphIsomorphicDecisionTreeRegressor
from pysidt.decomposition import atom_decomposition_noH
from pysidt.plotting import plot_tree
%matplotlib inline

In [None]:
# Raw records; the next cell reads each one as record[0] (adjacency list)
# and record[1] (target value).
data_path = "../data/rmgdbH298CHOstablenoringnoads.json"
with open(data_path) as fh:
    data = json.load(fh)

In [None]:
# Parse every record into a Datum and keep only the first occurrence of each
# molecule: a record whose molecule is isomorphic to one already kept is
# dropped. The pairwise isomorphism scan is O(n^2) in the number of kept
# molecules, but no cheaper canonical key is available in this notebook.
training_data = []
for record in data:
    candidate = Datum(
        Molecule().from_adjacency_list(record[0], check_consistency=True),
        record[1],
    )
    # any() short-circuits on the first isomorphic match, like the old break.
    if not any(kept.mol.is_isomorphic(candidate.mol) for kept in training_data):
        training_data.append(candidate)


In [None]:
# The whole deduplicated set is used for fitting; no hold-out split is made,
# so `train` is simply another name for the same list.
train = training_data

In [None]:
# Root node of the tree: a single generic atom "R" with no unpaired electrons
# (u0); px / cx presumably leave lone pairs and charge unconstrained (RMG group
# syntax) — confirm against the molecule package docs.
root = Group().from_adjacency_list("""1 * R u0 px cx""")

# Presumably the atom types, bond orders and unpaired-electron counts the
# regressor may use when extending nodes — TODO confirm against pysidt docs.
extension_atomtypes = [ATOMTYPES[symbol] for symbol in ("C", "O")]
extension_bond_orders = [1, 2, 3, 1.5]

sidt = MultiEvalSubgraphIsomorphicDecisionTreeRegressor(
    atom_decomposition_noH,
    root_group=root,
    r=extension_atomtypes,
    r_bonds=extension_bond_orders,
    r_un=[0],
    fract_nodes_expand_per_iter=0.1,
)

In [None]:
# Grow the decision tree on the training set, capped at 120 nodes.
sidt.generate_tree(max_nodes=120, data=train)

In [None]:
# Regularize the grown tree against the same training data.
sidt.regularize(check_data=True, data=train)

In [None]:
# Inspect the nodes of the regularized tree.
sidt.nodes

In [None]:
# Render the tree; images=True presumably draws each node's group structure.
plot_tree(sidt, images=True)

In [None]:
# Evaluate the tree on every training datum and record the prediction, its
# estimated standard deviation, the reference value, and the signed error
# (prediction - reference). With LEAVE_ONE_OUT = False these are training
# errors; set it to True to refit the tree without each datum before
# predicting it (leave-one-out errors, much slower).
LEAVE_ONE_OUT = False

errors = []
pred = []
actual = []
uncs = []
for d in train:
    if LEAVE_ONE_OUT:
        ds = train[:]
        ds.remove(d)
        sidt.fit_tree(data=ds, check_data=True)
        sidt.estimate_uncertainty()
    pval, unc = sidt.evaluate(d.mol, estimate_uncertainty=True)
    pred.append(pval)
    actual.append(d.value)
    uncs.append(unc)
    # Reuse the prediction above rather than evaluating the tree a second
    # time; this also guarantees errors[i] == pred[i] - actual[i].
    errors.append(pval - d.value)

if LEAVE_ONE_OUT:
    # The last fit in the loop excluded the final datum; refit on the full
    # training set so `sidt` is not left in that state.
    sidt.fit_tree(data=train, check_data=True)
    sidt.estimate_uncertainty()

In [None]:
# Mean absolute error, converted from J/mol to kcal/mol (1 kcal = 4184 J).
np.abs(np.asarray(errors)).mean() / 4184.0

In [None]:
# Population variance of the reference values (same units as d.value, squared).
data_var = np.array([datum.value for datum in train]).var()

In [None]:
# Display the variance of the reference values.
data_var

In [None]:
# Population variance of the signed errors (mean error is subtracted out).
error_var = np.asarray(errors).var()

In [None]:
# Fraction of the target variance left unexplained. Because np.var removes the
# mean error, this is a bias-insensitive analogue of 1 - R^2.
error_var / data_var

In [None]:
# Distribution of the signed errors (prediction - reference), in kcal/mol.
plt.hist(np.array(errors) / 4184.0, bins=22)
plt.xlabel("SIDT Hf298 error [kcal/mol]")
plt.ylabel("Count");  # trailing ';' hides the (counts, bins, patches) repr

In [None]:
# Parity plot of predicted vs. reference Hf298 (J/mol -> kcal/mol), coloured
# by the tree's estimated standard deviation. The diagonal marks perfect
# agreement.
pred_kcal = np.array(pred) / 4184.0
actual_kcal = np.array(actual) / 4184.0
unc_kcal = np.array(uncs) / 4184.0

plt.scatter(pred_kcal, actual_kcal, c=unc_kcal, cmap="viridis")
# Span the parity line over the data instead of a hard-coded range, so it
# always covers every point.
parity_lo = min(pred_kcal.min(), actual_kcal.min())
parity_hi = max(pred_kcal.max(), actual_kcal.max())
plt.plot([parity_lo, parity_hi], [parity_lo, parity_hi])
cbar = plt.colorbar()
plt.xlabel("SIDT Hf298 [kcal/mol]")
plt.ylabel("Actual Hf298 [kcal/mol]")
cbar.set_label("Predicted Standard Deviation in Hf298 [kcal/mol]");

In [None]:
# Uncertainty calibration: for each nominal confidence level c, compute the
# fraction of data whose absolute error is strictly smaller than the Gaussian
# half-width z_c * sigma, where z_c = norm.ppf((1 + c) / 2). Well-calibrated
# uncertainties give fractions close to c.
confidences = np.linspace(0, 1, 100)
abs_errors = np.abs(np.asarray(pred) - np.asarray(actual))   # one per datum
z_scores = scipy.stats.norm.ppf((1 + confidences) / 2.0)     # one per level
# half_widths[j, i]: interval half-width for datum i at confidence level j.
# At c = 1, z is inf, so a zero sigma gives NaN, which never counts as covered.
half_widths = np.outer(z_scores, np.asarray(uncs))
fracts = (half_widths > abs_errors).mean(axis=1)

In [None]:
# Observed coverage vs. nominal confidence; the diagonal is perfect calibration.
plt.plot(confidences, confidences, label="Ideal calibration")
plt.plot(confidences, fracts, label="SIDT")
plt.xlabel("Confidence Level")
plt.ylabel("Proportion Correct")
plt.legend();