Browse files

final commit before going public

  • Loading branch information...
1 parent 09bc037 commit 901451fedf549531b8ef0320b1d32f3f2db2aa34 @sanity committed Sep 20, 2011
View
6 .settings/com.chronon.sb.browser.launcher.prefs
@@ -0,0 +1,6 @@
+#Tue Sep 20 09:28:22 CDT 2011
+auto.defaultPkgs=
+auto.excludePkgs=
+auto.partialPkgs=
+eclipse.preferences.version=1
+isManual=false
View
6 src/main/java/com/moboscope/quickdt/Branch.java
@@ -62,5 +62,11 @@ public void dump(final int indent, final PrintStream ps) {
}
public abstract String toNotString();
+
+ @Override
+ protected void calcMeanDepth(final LeafDepthStats stats) {
+ trueChild.calcMeanDepth(stats);
+ falseChild.calcMeanDepth(stats);
+ }
}
View
6 src/main/java/com/moboscope/quickdt/Leaf.java
@@ -49,4 +49,10 @@ public int size() {
return 1;
}
+ @Override
+ protected void calcMeanDepth(final LeafDepthStats stats) {
+ stats.ttlDepth += label.depth * label.exampleCount;
+ stats.ttlSamples += label.exampleCount;
+ }
+
}
View
18 src/main/java/com/moboscope/quickdt/MoboTest.java
@@ -6,7 +6,9 @@
import org.json.simple.*;
-import com.google.common.collect.Lists;
+import com.google.common.collect.*;
+import com.moboscope.quickdt.TreeBuilder.Scorer;
+import com.moboscope.quickdt.scorers.Scorer1;
public class MoboTest {
@@ -34,13 +36,15 @@ public static void main(final String[] args) throws Exception {
System.out.println("Read " + instances.size() + " instances");
- final TreeBuilder tb = new TreeBuilder();
+ for (final Scorer scorer : Sets.newHashSet(new Scorer1())) {
+ final TreeBuilder tb = new TreeBuilder(scorer);
- final long startTime = System.currentTimeMillis();
- final Node tree = tb.buildTree(instances, 100, 1.0);
- System.out.println("Build time: " + (System.currentTimeMillis() - startTime));
-
- tree.dump(System.out);
+ final long startTime = System.currentTimeMillis();
+ final Node tree = tb.buildTree(instances, 100, 1.0);
+ System.out.println(scorer.getClass().getSimpleName() + " build time "
+ + (System.currentTimeMillis() - startTime) + ", size: " + tree.size() + " mean depth: "
+ + tree.meanDepth());
+ }
}
}
View
74 src/main/java/com/moboscope/quickdt/Node.java
@@ -3,34 +3,79 @@
import java.io.*;
public abstract class Node implements Serializable {
- public abstract Label getLabel(Attributes attributes);
-
- public abstract int size();
+ public abstract void dump(int indent, PrintStream ps);
+ /**
+ * Writes a textual representation of this tree to a PrintStream
+ *
+ * @param ps
+ */
public void dump(final PrintStream ps) {
dump(0, ps);
}
- public abstract void dump(int indent, PrintStream ps);
+ /**
+ * Get a label for a given set of Attributes
+ *
+ * @param attributes
+ * @return
+ */
+ public abstract Label getLabel(Attributes attributes);
+
+ /**
+ * Return the mean depth of leaves in the tree. A lower number generally
+ * indicates that the decision tree learner has done a better job.
+ *
+ * @return
+ */
+ public double meanDepth() {
+ final LeafDepthStats stats = new LeafDepthStats();
+ calcMeanDepth(stats);
+ return (double) stats.ttlDepth / stats.ttlSamples;
+ }
+
+ /**
+ * Return the number of nodes in this decision tree.
+ *
+ * @return
+ */
+ public abstract int size();
+
+ protected abstract void calcMeanDepth(LeafDepthStats stats);
public static class Label implements Serializable {
private static final long serialVersionUID = -4063175796311033721L;
- public Label(final Serializable output, final int depth, final int exampleCount, final double probability) {
- this.output = output;
- this.depth = depth;
- this.exampleCount = exampleCount;
- this.probability = probability;
- }
+ /**
+ * How deep in the tree is this label? A lower number typically
+ * indicates a more confident classification.
+ */
+ public int depth;
+ /**
+ * How many training examples matched this leaf? A higher number
+ * indicates a more confident classification.
+ */
public final int exampleCount;
+ /**
+ * What label was assigned by this leaf?
+ */
public Serializable output;
- public int depth;
-
+ /**
+ * What is the probability that this classification is correct based on
+ * the training data?
+ */
public double probability;
+ public Label(final Serializable output, final int depth, final int exampleCount, final double probability) {
+ this.output = output;
+ this.depth = depth;
+ this.exampleCount = exampleCount;
+ this.probability = probability;
+ }
+
@Override
public String toString() {
final StringBuilder builder = new StringBuilder();
@@ -47,4 +92,9 @@ public String toString() {
}
}
+
+ protected static class LeafDepthStats {
+ int ttlDepth = 0;
+ int ttlSamples = 0;
+ }
}
View
19 src/main/java/com/moboscope/quickdt/TreeBuilder.java
@@ -69,11 +69,21 @@ public Node buildTree(final Iterable<Instance> trainingData, final int depth, fi
final Instance sampleInstance = Iterables.get(trainingData, 0);
+ boolean smallTrainingSet = true;
+ int tsCount = 0;
+ for (final Instance i : trainingData) {
+ tsCount++;
+ if (tsCount > 20) {
+ smallTrainingSet = false;
+ break;
+ }
+ }
+
Branch bestNode = null;
double bestScore = 0;
for (final Entry<String, Serializable> e : sampleInstance.attributes.entrySet()) {
Pair<? extends Branch, Double> thisPair;
- if (e.getValue() instanceof Number) {
+ if (!smallTrainingSet && e.getValue() instanceof Number) {
thisPair = createOrdinalNode(e.getKey(), trainingData, splits.get(e.getKey()));
} else {
thisPair = createNominalNode(e.getKey(), trainingData);
@@ -84,13 +94,8 @@ public Node buildTree(final Iterable<Instance> trainingData, final int depth, fi
}
}
- if (bestNode == null) {
- final StringBuilder sb = new StringBuilder();
- for (final Instance i : trainingData) {
- sb.append(i.toString());
- }
+ if (bestNode == null)
return thisLeaf;
- }
bestNode.trueChild = buildTree(Lists.newLinkedList(Iterables.filter(trainingData, bestNode.getInPredicate())),
depth + 1, maxDepth, minProbability, splits);
View
4 src/main/java/com/moboscope/quickdt/scorers/Scorer1.java
@@ -8,6 +8,10 @@
public class Scorer1 implements Scorer {
+ /*
+ * The best scorer so far, fast with small trees
+ */
+
public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, final int bTtl,
final Map<Serializable, Integer> b) {
double score = 0;
View
10 src/main/java/com/moboscope/quickdt/scorers/Scorer2.java
@@ -8,6 +8,10 @@
public class Scorer2 implements Scorer {
+ /*
+ * The best scorer so far, fast with small trees
+ */
+
public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, final int bTtl,
final Map<Serializable, Integer> b) {
double score = 0;
@@ -21,10 +25,10 @@ public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, fin
bCount = 0;
}
- final double aProp = aCount;
- final double bProp = bCount;
+ final double aProp = (double) aCount / aTtl;
+ final double bProp = (double) bCount / bTtl;
- score += Math.abs(aProp - bProp) * Math.min(aTtl, bTtl);
+ score += Math.abs(aProp - bProp) * Math.pow(Math.min(aTtl, bTtl), 2);
}
return score;
}
View
6 src/main/java/com/moboscope/quickdt/scorers/Scorer3.java
@@ -8,6 +8,10 @@
public class Scorer3 implements Scorer {
+ /*
+ * The best scorer so far, fast with small trees
+ */
+
public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, final int bTtl,
final Map<Serializable, Integer> b) {
double score = 0;
@@ -24,7 +28,7 @@ public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, fin
final double aProp = (double) aCount / aTtl;
final double bProp = (double) bCount / bTtl;
- score += Math.abs(aProp - bProp);
+ score += Math.abs(aProp - bProp) * Math.sqrt(Math.min(aTtl, bTtl));
}
return score;
}
View
34 src/main/java/com/moboscope/quickdt/scorers/Scorer4.java
@@ -1,34 +0,0 @@
-package com.moboscope.quickdt.scorers;
-
-import java.io.Serializable;
-import java.util.Map;
-
-import com.google.common.collect.Sets;
-import com.moboscope.quickdt.TreeBuilder.Scorer;
-
-public class Scorer4 implements Scorer {
-
- public double scoreSplit(final int aTtl, final Map<Serializable, Integer> a, final int bTtl,
- final Map<Serializable, Integer> b) {
- if (aTtl == 0 || bTtl == 0)
- return 0;
- double score = 0;
- for (final Serializable value : Sets.union(a.keySet(), b.keySet())) {
- Integer aCount = a.get(value);
- if (aCount == null) {
- aCount = 0;
- }
- Integer bCount = b.get(value);
- if (bCount == null) {
- bCount = 0;
- }
-
- final double aProp = aCount;
- final double bProp = bCount;
-
- score += Math.abs(aProp - bProp);
- }
- return score;
- }
-
-}
View
7 src/test/java/com/uprizer/quickdt/TreeBuilderTest.java
@@ -25,7 +25,7 @@ public void simpleBmiTest() {
}
- @Test(enabled = false)
+ @Test()
public void multiScorerBmiTest() {
final Set<Instance> instances = Sets.newHashSet();
@@ -50,11 +50,6 @@ public void multiScorerBmiTest() {
System.out.println("Scorer3 tree size: " + tree.size());
}
{
- final TreeBuilder tb = new TreeBuilder(new Scorer4());
- final Node tree = tb.buildTree(instances, 100, 1.0);
- System.out.println("Scorer4 tree size: " + tree.size());
- }
- {
final TreeBuilder tb = new TreeBuilder(new CorrectClassificationProbScorer());
final Node tree = tb.buildTree(instances, 100, 1.0);
System.out.println("CorrectClassificationProbScorer tree size: " + tree.size());

0 comments on commit 901451f

Please sign in to comment.