diff --git a/pom.xml b/pom.xml index 5e87d209..309388c1 100644 --- a/pom.xml +++ b/pom.xml @@ -32,7 +32,7 @@ be accompanied by a bump in version number, regardless of how minor the change. --> - 0.7.13 + 0.7.14 sanity-maven-repo diff --git a/src/main/java/quickml/supervised/classifier/Classifiers.java b/src/main/java/quickml/supervised/classifier/Classifiers.java index 460dd7c8..39756d2b 100644 --- a/src/main/java/quickml/supervised/classifier/Classifiers.java +++ b/src/main/java/quickml/supervised/classifier/Classifiers.java @@ -24,8 +24,10 @@ import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizerBuilder; import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static quickml.supervised.classifier.decisionTree.TreeBuilder.*; @@ -64,8 +66,8 @@ public static Pair, Downsampl return new Pair, DownsamplingClassifier>(bestParams, downsamplingClassifier); } - public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder modelBuilder) { - Map config = createConfig(); + public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder modelBuilder, Set exemptAttributes) { + Map config = createConfig(exemptAttributes); return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config); } @@ -74,13 +76,12 @@ public static Pair, Downsampl return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config); } - public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) { - Map config = createConfig(); + public static Pair, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Set exemptAttributes) { + Map config = createConfig(exemptAttributes); return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config); } private static int getTimeSliceHours(List trainingData, int rebuildsPerValidation, DateTimeExtractor dateTimeExtractor) { - Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor); DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size()-1)); int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1; @@ -91,7 +92,7 @@ private static int getTimeSliceHours(List training } - private static Map createConfig() { + private static Map createConfig(Set exemptAttributes) { Map config = Maps.newHashMap(); config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 12));//Integer.MAX_VALUE, 2, 3, 5, 6, 9)); config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(7, 10)); @@ -99,6 +100,8 @@ private static Map createConfig() { config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .25)); config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75)); config.put(ORDINAL_TEST_SPLITS, new FixedOrderRecommender(5, 7)); + config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 )); + config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes)); return config; } diff --git a/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java b/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java index 6bb1cfa0..8df12f79 100644 --- a/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java +++ b/src/main/java/quickml/supervised/classifier/decisionTree/TreeBuilder.java @@ -41,11 +41,15 @@ public final class TreeBuilder implements Predicti public static final int RESERVOIR_SIZE = 100; public static final Serializable MISSING_VALUE = "%missingVALUE%83257"; private static final int HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE = 10; + public static final String MIN_SPLIT_FRACTION = "minSplitFraction"; + public static final String EXEMPT_ATTRIBUTES = "exemptAttributes"; private Scorer scorer; private int maxDepth = 5; private double minimumScore = 0.00000000000001; private int minDiscreteAttributeValueOccurances = 0; + private double minSplitFraction = .005; + private Set exemptAttributes = Sets.newHashSet(); private int minLeafInstances = 0; @@ -74,6 +78,16 @@ public TreeBuilder attributeIgnoringStrategy(AttributeIgnoringStrategy attribute return this; } + public TreeBuilder exemptAttributes(Set exemptAttributes) { + this.exemptAttributes = exemptAttributes; + return this; + } + + public TreeBuilder minSplitFraction(double minSplitFraction) { + this.minSplitFraction = minSplitFraction; + return this; + } + @Deprecated public TreeBuilder ignoreAttributeAtNodeProbability(double ignoreAttributeAtNodeProbability) { attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(ignoreAttributeAtNodeProbability)); @@ -92,6 +106,8 @@ public TreeBuilder copy() { copy.ordinalTestSpilts = ordinalTestSpilts; copy.attributeIgnoringStrategy = attributeIgnoringStrategy.copy(); copy.fractionOfDataToUseInHoldOutSet = fractionOfDataToUseInHoldOutSet; + copy.minSplitFraction = minSplitFraction; + copy.exemptAttributes = exemptAttributes; return copy; } @@ -110,8 +126,12 @@ public void updateBuilderConfig(final Map cfg) { minCategoricalAttributeValueOccurances((Integer) cfg.get(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE)); if (cfg.containsKey(MIN_LEAF_INSTANCES)) minLeafInstances((Integer) cfg.get(MIN_LEAF_INSTANCES)); + if (cfg.containsKey(MIN_SPLIT_FRACTION)) + minSplitFraction((Double) cfg.get(MIN_SPLIT_FRACTION)); if (cfg.containsKey(ORDINAL_TEST_SPLITS)) ordinalTestSplits((Integer) cfg.get(ORDINAL_TEST_SPLITS)); + if (cfg.containsKey(EXEMPT_ATTRIBUTES)) + exemptAttributes((Set) cfg.get(EXEMPT_ATTRIBUTES)); if (cfg.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY)) degreeOfGainRatioPenalty((Double) cfg.get(DEGREE_OF_GAIN_RATIO_PENALTY)); if (cfg.containsKey(ATTRIBUTE_IGNORING_STRATEGY)) @@ -441,6 +461,12 @@ private Pair createTwoClassCategoricalNode(Node parent inCounts = inCounts.add(testValCounts); outCounts = outCounts.subtract(testValCounts); + double numInstances = inCounts.getTotal() + outCounts.getTotal(); + if (!exemptAttributes.contains(attribute) && (inCounts.getTotal()/ numInstances createNumericBranch(Node parent, final St ClassificationCounter inClassificationCounts = ClassificationCounter.countAll(inSet); ClassificationCounter outClassificationCounts = ClassificationCounter.countAll(outSet); + double numInstances = inClassificationCounts.getTotal() + outClassificationCounts.getTotal(); + if (!exemptAttributes.contains(attribute) && (inClassificationCounts.getTotal()/ numInstances createNumericBranch(Node parent, final St if (thisScore > bestScore) { bestScore = thisScore; bestThreshold = threshold; - probabilityOfBeingInInset = inClassificationCounts.getTotal() / (inClassificationCounts.getTotal() + outClassificationCounts.getTotal()); + probabilityOfBeingInInset = inClassificationCounts.getTotal() / numInstances; } } if (bestScore == 0) { diff --git a/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java b/src/test/java/quickml/supervised/classifier/ClassifiersTest.java similarity index 71% rename from src/test/java/quickml/supervised/classifier/StaticBuildersTest.java rename to src/test/java/quickml/supervised/classifier/ClassifiersTest.java index 8b4f8a33..b206c45e 100644 --- a/src/test/java/quickml/supervised/classifier/StaticBuildersTest.java +++ b/src/test/java/quickml/supervised/classifier/ClassifiersTest.java @@ -1,5 +1,6 @@ package quickml.supervised.classifier; +import com.beust.jcommander.internal.Sets; import org.javatuples.Pair; import org.junit.Test; import org.slf4j.Logger; @@ -9,23 +10,29 @@ import quickml.supervised.classifier.downsampling.DownsamplingClassifier; import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.Set; import static quickml.supervised.InstanceLoader.getAdvertisingInstances; -public class StaticBuildersTest { - private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class); +public class ClassifiersTest { + private static final Logger logger = LoggerFactory.getLogger(ClassifiersTest.class); + @Test public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception { double fractionOfDataForValidation = .2; int rebuildsPerValidation = 1; - List trainingData = getAdvertisingInstances().subList(0, 3000); + Set exemptAttributes = Sets.newHashSet(); + exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick")); + + List trainingData = getAdvertisingInstances().subList(0, 3000); OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor(); Pair, DownsamplingClassifier> downsamplingClassifierPair = - Classifiers.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor); + Classifiers.getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor, exemptAttributes); logger.info("logged weighted auc loss should be between 0.25 and 0.28"); } } \ No newline at end of file diff --git a/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java b/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java index bbdf0ee1..5749379e 100644 --- a/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java +++ b/src/test/java/quickml/supervised/predictiveModelOptimizer/PredictiveModelOptimizerIntegrationTest.java @@ -54,6 +54,9 @@ public void testOptimizer() throws Exception { private Map createConfig() { Map config = Maps.newHashMap(); Set attributesToIgnore = Sets.newHashSet(); + Set exemptAttributes = Sets.newHashSet(); + exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick")); + attributesToIgnore.addAll(Arrays.asList("browser", "eap", "destinationId", "seenPixel", "internalCreativeId")); double probabilityOfDiscardingFromAttributesToIgnore = 0.3; CompositeAttributeIgnoringStrategy compositeAttributeIgnoringStrategy = new CompositeAttributeIgnoringStrategy(Arrays.asList( @@ -67,6 +70,9 @@ private Map createConfig() { config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 20, 40)); config.put(SCORER, new FixedOrderRecommender(new InformationGainScorer(), new GiniImpurityScorer())); config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75, .5 )); + config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 )); + config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes)); + return config; }