Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi-evaluation Binary Classifier #15

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open

Conversation

mjohnson541
Copy link
Collaborator

@mjohnson541 mjohnson541 commented May 6, 2024

This adds an SIDT algorithm for Multi-Evaluation binary classification.

It also adds some smaller improvements:
Allows plotting only to specified depth
Saves rules as well as nodes in postpruning
allows specification of an initial set of splits from the root node

An example notebook for unstable Q.OOH classification is provided.

Copy link
Collaborator

@hwpang hwpang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

if node.depth <= depth:
n = pydot.Node(name=name, label=name, fontname="Helvetica", fontsize="16")
if images:
img = node.group.draw("png")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
img = node.group.draw("png")
img = node.group.draw("pdf")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better quality of figures when writing paper and making slides. The figures don't have fuzzy edges when you enlarge them. Additionally, overleaf compiles faster if you include all your figures as pdf.

@@ -1,27 +1,31 @@
from IPython.display import Image, display
import pydot
import os
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import numpy as np
import numpy as np
from pathlib import Path

Comment on lines 9 to 11
graph.set_fontsize("10")
if not os.path.exists("./tree"):
os.makedirs("./tree")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we change all the use of os to pathlib? pathlib is the more modern way since Python 3.4

Suggested change
graph.set_fontsize("10")
if not os.path.exists("./tree"):
os.makedirs("./tree")
graph.set_fontsize("10")
save_dir = Path("./tree")
save_dir.mkdir(exist_ok=True)

Comment on lines +18 to +20
with open("./tree/" + node.name + ".png", "wb") as f:
f.write(img)
n.set_image(os.path.abspath("./tree/" + node.name + ".png"))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
with open("./tree/" + node.name + ".png", "wb") as f:
f.write(img)
n.set_image(os.path.abspath("./tree/" + node.name + ".png"))
node_save_path = (save_dir / node.name + ".pdf").resolve()
with open(node_save_path, "wb") as f:
f.write(img)
n.set_image(node_save_path))

Comment on lines 29 to 30
graph.write_dot("./tree/tree.dot")
graph.write_png("./tree/tree.png")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
graph.write_dot("./tree/tree.dot")
graph.write_png("./tree/tree.png")
graph.write_dot(save_dir / "tree.dot")
graph.write_png(save_dir / "tree.png")

sidt_val_values = [self.evaluate(d.mol) for d in self.validation_set]
true_val_values = [d.value for d in self.validation_set]

P,N,PP,PN,TP,FN,FP,TN = analyze_binary_classification(sidt_train_values,true_train_values)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure there are metrics functions to compute accuracy, recall, precision, etc. we can just import from sklearn to make the code cleaner. Can you do that?

Suggested change
P,N,PP,PN,TP,FN,FP,TN = analyze_binary_classification(sidt_train_values,true_train_values)
P, N, PP, PN, TP, FN, FP, TN = analyze_binary_classification(sidt_train_values, true_train_values)

}
self.best_rule_map = {name:self.nodes[name].rule for name in self.best_tree_nodes}

logging.info("# nodes: {}".format(len(self.nodes)))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
logging.info("# nodes: {}".format(len(self.nodes)))
logging.info(f"# nodes: {len(self.nodes)}")

2) merges nodes with their parents if they do not result in different predictions
"""

self.datum_truth_map = {datum:[getattr(n,"rule") for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.datum_truth_map = {datum:[getattr(n,"rule") for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}
self.datum_truth_map = {datum: [getattr(n,"rule") for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}

"""

self.datum_truth_map = {datum:[getattr(n,"rule") for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}
self.datum_node_map = {datum:[n for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.datum_node_map = {datum:[n for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}
self.datum_node_map = {datum: [n for n in self.mol_node_maps[datum]["nodes"]] for datum in self.datums}


assert len(new) == 0
assert len(comp) == 0
pnew = new_class_true/Nnew
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pnew = new_class_true/Nnew
pnew = new_class_true / Nnew

"source": [
"data = []\n",
"for sm in stable_smiles:\n",
" data.append(Datum(Molecule().from_smiles(sm),True))\n",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
" data.append(Datum(Molecule().from_smiles(sm),True))\n",
" data.append(Datum(Molecule().from_smiles(sm), True))\n",

Can you run black formatter through the notebook too

@hwpang
Copy link
Collaborator

hwpang commented May 8, 2024

Can you also add a pytest for this new type of tree? I don't think it's best practice to only rely on notebook test. The --nbmake option doesn't provide a very comprehensive report.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants