Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-evaluation Binary Classifier #15

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ jobs:
run: |
pytest --nbmake IPython/multi_eval_SIDT_example.ipynb
pytest --nbmake IPython/Surface_Diffusion_single_eval_SIDT_example.ipynb
pytest --nbmake IPython/Unstable_QOOH_Binary_Classification.ipynb
158 changes: 158 additions & 0 deletions IPython/Unstable_QOOH_Binary_Classification.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "2e602b7f-c603-4de2-a05d-9f0c73d41b0b",
"metadata": {},
"outputs": [],
"source": [
"from pysidt.sidt import read_nodes, write_nodes, MultiEvalSubgraphIsomorphicDecisionTreeBinaryClassifier, Datum\n",
"from pysidt.plotting import plot_tree\n",
"from pysidt.decomposition import atom_decomposition_noH\n",
"from molecule.molecule import Molecule, Group\n",
"from molecule.molecule.atomtype import ATOMTYPES\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "87a03282-d729-499a-9479-5001a1b1d49e",
"metadata": {},
"outputs": [],
"source": [
"#In general species of the form [R.]OOH are usually not stable and decompose to R=O + OH, an exception to this is CH2OOH\n",
"stable_smiles = [\"CC\",\"C\",\"O\",\"CO\",\"CCC\",\"C[CH]C\",\"C[CH]CC\",\"C[CH]OC\",\"C[CH]CO\",\"C=C\",\"C=CC\",\"CCCC\",\"CCCO\",\"COC\",\"CCOC\",\n",
" \"[OH]\",\"[CH3]\",\"[CH2]OO\", \"C[CH2]\", \"COO\", \"CCOO\",\"CCCOO\",\"[CH2]CCC\",\"C[CH]OC\",\"C[CH]O\", \"CC[CH]CC\", \"OC[CH]CC\",\n",
" \"C=CCC\", \"O[CH]CC\", \"CO[CH]CC\",\"CO[CH]OC\", \"O=CC\", \"C=CCCC\", \"O=CCCC\", \"CCCCCC\", \"CCCCCCC\", \"[CH2]OCO[CH]C\",\n",
" \"O[CH]CCCO[CH]CC\", \"CCC[CH]C\",]\n",
"unstable_smiles = [\"C[CH]OO\",\"CC[CH]OO\",\"O=CC[CH]OO\",\"CCC[CH]OO\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd624b7e-f9aa-4fba-bf09-2898d747f441",
"metadata": {},
"outputs": [],
"source": [
"data = []\n",
"for sm in stable_smiles:\n",
" data.append(Datum(Molecule().from_smiles(sm),True))\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
" data.append(Datum(Molecule().from_smiles(sm),True))\n",
" data.append(Datum(Molecule().from_smiles(sm), True))\n",

Can you run black formatter through the notebook too

"for sm in unstable_smiles:\n",
" data.append(Datum(Molecule().from_smiles(sm),False))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e49e1fc1-6ea3-43a3-b76f-166e3b951079",
"metadata": {},
"outputs": [],
"source": [
"root = Group().from_adjacency_list(\"\"\"\n",
"1 * R ux px cx\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1284bbee-a3ca-4f43-9068-2456e7d15a57",
"metadata": {},
"outputs": [],
"source": [
"tree = MultiEvalSubgraphIsomorphicDecisionTreeBinaryClassifier(atom_decomposition_noH,root_group=root,\n",
" r=[ATOMTYPES[x] for x in [\"C\",\"O\"]],\n",
" r_bonds=[1,2,3],\n",
" r_un=[0,1],\n",
" r_site=[], \n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84c4e20f-c6e6-4208-b876-d016a372f387",
"metadata": {},
"outputs": [],
"source": [
"tree.generate_tree(data=data,max_nodes=100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2dc98a30-4cd7-4f5d-9d9b-84720886155e",
"metadata": {},
"outputs": [],
"source": [
"#initial trees are much larger than it needs to be because a \"good split of data\" != \"change in classification\"\n",
"plot_tree(tree)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7145d6f0-8e07-4804-a565-8bfadf09955c",
"metadata": {},
"outputs": [],
"source": [
"#We then merge nodes when possible and regularize\n",
"tree.trim_tree()\n",
"tree.regularize()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f726167-e1ac-40e1-af71-edf53105cf66",
"metadata": {},
"outputs": [],
"source": [
"#After trimming and regularizing we have a much simpler tree that is easy to evaluate and analyze\n",
"plot_tree(tree)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24a88f4a-04a0-427d-804a-860d840aec6d",
"metadata": {},
"outputs": [],
"source": [
"tree.analyze_error()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "437dbc01-e1af-4c2c-9b34-bebdc1588838",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
40 changes: 40 additions & 0 deletions pysidt/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,3 +1125,43 @@ def specify_bond_extensions(grp, i, j, basename, r_bonds):
)
)
return grps

def generate_extensions_reverse(grp, structs):
    """
    Generate candidate extensions of ``grp`` by pruning the structures being split.

    Rather than extending the original group, each structure in ``structs`` is
    converted to a group and its unlabeled atoms are removed one at a time
    (highest index first). A removal is kept only if the pruned group is still
    subgraph isomorphic to ``grp``. Pruned groups are then classified by how
    many of the *other* structures match them:

    * all of them (or there are no other structures): the removal is undone,
      because the group no longer separates anything;
    * some of them: the group splits the data and is recorded as a candidate;
    * none of them: the group is remembered only as a fallback (the smallest
      such group is kept) and pruning continues from it.

    This should be a reliable fallback when traditional extension generation
    becomes too expensive.

    Args:
        grp: the group being extended; every returned group was checked with
            ``is_subgraph_isomorphic(grp, ...)``.
        structs: iterable of structures exposing ``atoms``, ``to_group()`` and
            ``is_subgraph_isomorphic()``.

    Returns:
        list: for each structure, its smallest splitting candidates or, when
        there are none, its smallest non-matching fallback. Structures that
        produce neither contribute nothing (``None`` is never included).
    """
    exts = []
    for st in structs:
        new_struct = st.to_group()
        # The comparison set does not change while pruning this structure.
        others = [item for item in structs if item != st]
        aexts = []
        smallest_not_matching_group = None
        # Indices are visited from the end; presumably so that removing an atom
        # does not shift the indices that are still to be visited.
        for ind in reversed(range(len(st.atoms))):
            if new_struct.atoms[ind].label:  # don't remove any labeled atoms
                continue
            old_struct = new_struct
            new_struct = old_struct.copy(deep=True)
            new_struct.remove_atom(new_struct.atoms[ind])
            if not new_struct.is_subgraph_isomorphic(
                grp, generate_initial_map=True, save_order=True
            ):
                # removing that atom broke isomorphism with the original group
                new_struct = old_struct
                continue
            matches = [
                item.is_subgraph_isomorphic(
                    new_struct, generate_initial_map=True, save_order=True
                )
                for item in others
            ]
            if all(matches):  # suddenly matches every structure: undo the removal
                new_struct = old_struct
            elif any(matches):  # splits the structures
                aexts.append(new_struct)
            elif smallest_not_matching_group is None or len(
                smallest_not_matching_group.atoms
            ) > len(new_struct.atoms):
                smallest_not_matching_group = new_struct
        if aexts:
            minlen = min(len(g.atoms) for g in aexts)
            exts.extend(g for g in aexts if len(g.atoms) == minlen)
        elif smallest_not_matching_group is not None:
            # Previously None was appended here when no candidate was found.
            exts.append(smallest_not_matching_group)

    return exts
32 changes: 19 additions & 13 deletions pysidt/plotting.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
from IPython.display import Image, display
import pydot
import os
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import numpy as np
import numpy as np
from pathlib import Path



def plot_tree(sidt, images=True):
def plot_tree(sidt, images=True, depth=np.inf):
graph = pydot.Dot("treestruct", graph_type="digraph", overlap="false")
graph.set_fontname("sans")
graph.set_fontsize("10")
if not os.path.exists("./tree"):
os.makedirs("./tree")
Comment on lines 9 to 11
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change all uses of os to pathlib? pathlib has been the more modern approach since Python 3.4

Suggested change
graph.set_fontsize("10")
if not os.path.exists("./tree"):
os.makedirs("./tree")
graph.set_fontsize("10")
save_dir = Path("./tree")
save_dir.mkdir(exist_ok=True)

out_nodes = dict()
index = -1
for name, node in sidt.nodes.items():
n = pydot.Node(name=name, label=name, fontname="Helvetica", fontsize="16")
if images:
img = node.group.draw("png")
with open("./tree/" + node.name + ".png", "wb") as f:
f.write(img)
n.set_image(os.path.abspath("./tree/" + node.name + ".png"))
n.set_label(" ")
graph.add_node(n)
for name, node in sidt.nodes.items():
index += 1
if node.depth <= depth:
n = pydot.Node(name=name, label=name, fontname="Helvetica", fontsize="16")
if images and node.group is not None:
img = node.group.draw("png")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
img = node.group.draw("png")
img = node.group.draw("pdf")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better quality of figures when writing paper and making slides. The figures don't have fuzzy edges when you enlarge them. Additionally, overleaf compiles faster if you include all your figures as pdf.

with open("./tree/" + str(index) + ".png", "wb") as f:
f.write(img)
n.set_image(os.path.abspath("./tree/" + str(index) + ".png"))
n.set_label(" ")
graph.add_node(n)
out_nodes[name] = node
for name, node in out_nodes.items():
for nod in node.children:
edge = pydot.Edge(node.name, nod.name)
graph.add_edge(edge)
if nod.name in out_nodes.keys():
edge = pydot.Edge(node.name, nod.name)
graph.add_edge(edge)
graph.write_dot("./tree/tree.dot")
graph.write_png("./tree/tree.png")
Comment on lines 31 to 32
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
graph.write_dot("./tree/tree.dot")
graph.write_png("./tree/tree.png")
graph.write_dot(save_dir / "tree.dot")
graph.write_png(save_dir / "tree.png")

plt = Image("./tree/tree.png")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
plt = Image("./tree/tree.png")
plt = Image(save_dir / "tree.pdf")

Expand Down
3 changes: 3 additions & 0 deletions pysidt/regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def simple_regularization(node, Rx, Rbonds, Run, Rsite, Rmorph, test=True):

grp = node.group
data = node.items

if grp is None:
return

R = Rx[:]
if ATOMTYPES["X"] in R:
Expand Down
Loading