In [None]:
from pysidt.sidt import read_nodes, write_nodes, MultiEvalSubgraphIsomorphicDecisionTreeBinaryClassifier, Datum
from pysidt.plotting import plot_tree
from pysidt.decomposition import atom_decomposition_noH
from molecule.molecule import Molecule, Group
from molecule.molecule.atomtype import ATOMTYPES
import numpy as np

In [None]:
# Labelled training molecules, written as SMILES strings.
# Species of the form [R.]OOH are generally not stable: they decompose to
# R=O + OH. CH2OOH ("[CH2]OO") is an exception and is therefore listed as stable.
# NOTE(review): "C[CH]OC" occurs twice in stable_smiles -- confirm whether the
# duplicate entry is intentional.
stable_smiles = [
    "CC", "C", "O", "CO", "CCC",
    "C[CH]C", "C[CH]CC", "C[CH]OC", "C[CH]CO", "C=C",
    "C=CC", "CCCC", "CCCO", "COC", "CCOC",
    "[OH]", "[CH3]", "[CH2]OO", "C[CH2]", "COO",
    "CCOO", "CCCOO", "[CH2]CCC", "C[CH]OC", "C[CH]O",
    "CC[CH]CC", "OC[CH]CC", "C=CCC", "O[CH]CC", "CO[CH]CC",
    "CO[CH]OC", "O=CC", "C=CCCC", "O=CCCC", "CCCCCC",
    "CCCCCCC", "[CH2]OCO[CH]C", "O[CH]CCCO[CH]CC", "CCC[CH]C",
]
unstable_smiles = [
    "C[CH]OO",
    "CC[CH]OO",
    "O=CC[CH]OO",
    "CCC[CH]OO",
]

In [None]:
# Pair every SMILES string with its label (True = stable, False = unstable),
# keeping the stable entries first and the unstable entries after them.
labelled_smiles = [(smiles, True) for smiles in stable_smiles]
labelled_smiles += [(smiles, False) for smiles in unstable_smiles]

# Parse each SMILES string into a Molecule and wrap it with its label in a Datum.
data = [
    Datum(Molecule().from_smiles(smiles), is_stable)
    for smiles, is_stable in labelled_smiles
]

In [None]:
# Adjacency list for the root group of the tree: a single atom line, labelled
# "*", with atom type R and the ux/px/cx fields.
root_adjacency_list = """
1 * R ux px cx
"""
root = Group().from_adjacency_list(root_adjacency_list)

In [None]:
# Build the classifier on top of the root group, using the no-H atom
# decomposition. The r* keywords are passed through unchanged.
# NOTE(review): r / r_bonds / r_un / r_site presumably restrict the atom types,
# bond orders, unpaired-electron counts and sites that tree extensions may
# use -- confirm against the pysidt documentation.
tree = MultiEvalSubgraphIsomorphicDecisionTreeBinaryClassifier(
    atom_decomposition_noH,
    root_group=root,
    r=[ATOMTYPES[symbol] for symbol in ["C", "O"]],
    r_bonds=[1, 2, 3],
    r_un=[0, 1],
    r_site=[],
)

In [None]:
# Grow the tree from the labelled data, passing max_nodes=100.
tree.generate_tree(
    data=data,
    max_nodes=100,
)

In [None]:
# The initial tree is much larger than it needs to be, because a
# "good split of data" != "change in classification".
plot_tree(tree)

In [None]:
# We then merge nodes where possible (trim_tree) and regularize the result.
# Both calls act on `tree` itself; trimming is run before regularization.
tree.trim_tree()
tree.regularize()

In [None]:
# After trimming and regularizing, we have a much simpler tree that is easy to
# evaluate and analyze. Plot it again to compare with the initial tree.
plot_tree(tree)

In [None]:
# Run the classifier's error analysis on the trimmed, regularized tree.
# NOTE(review): presumably this compares predictions with the training labels
# -- confirm what analyze_error reports against the pysidt documentation.
tree.analyze_error()

In [None]:
# Dump every node of the final tree: its name, its rule, and the adjacency
# list of the group that defines it.
for node_name, tree_node in tree.nodes.items():
    print(node_name)
    print(tree_node.rule)
    group_adjacency = tree_node.group.to_adjacency_list()
    print(group_adjacency)