Skip to content

Commit

Permalink
Merge f9654a3 into 4c3ccf1
Browse files Browse the repository at this point in the history
  • Loading branch information
athawk81 committed May 21, 2015
2 parents 4c3ccf1 + f9654a3 commit 70d52fa
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Expand Up @@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
-->

<version>0.7.13</version>
<version>0.7.14</version>
<repositories>
<repository>
<id>sanity-maven-repo</id>
Expand Down
15 changes: 9 additions & 6 deletions src/main/java/quickml/supervised/classifier/Classifiers.java
Expand Up @@ -24,8 +24,10 @@
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static quickml.supervised.classifier.decisionTree.TreeBuilder.*;

Expand Down Expand Up @@ -64,8 +66,8 @@ public static <T extends ClassifierInstance> Pair<Map<String, Object>, Downsampl
return new Pair<Map<String, Object>, DownsamplingClassifier>(bestParams, downsamplingClassifier);
}

public static <T extends ClassifierInstance> Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder) {
Map<String, FieldValueRecommender> config = createConfig();
public static <T extends ClassifierInstance> Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder, Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = createConfig(exemptAttributes);
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);
}

Expand All @@ -74,13 +76,12 @@ public static <T extends ClassifierInstance> Pair<Map<String, Object>, Downsampl
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);

}
public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
Map<String, FieldValueRecommender> config = createConfig();
public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = createConfig(exemptAttributes);
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config);
}

private static int getTimeSliceHours(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {

Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size()-1));
int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1;
Expand All @@ -91,14 +92,16 @@ private static int getTimeSliceHours(List<? extends ClassifierInstance> training
}


private static Map<String, FieldValueRecommender> createConfig() {
private static Map<String, FieldValueRecommender> createConfig(Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 12));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(7, 10));
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 15));
config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .25));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75));
config.put(ORDINAL_TEST_SPLITS, new FixedOrderRecommender(5, 7));
config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 ));
config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes));
return config;
}

Expand Down
Expand Up @@ -41,11 +41,15 @@ public final class TreeBuilder<T extends ClassifierInstance> implements Predicti
public static final int RESERVOIR_SIZE = 100;
public static final Serializable MISSING_VALUE = "%missingVALUE%83257";
private static final int HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE = 10;
public static final String MIN_SPLIT_FRACTION = "minSplitFraction";
public static final String EXEMPT_ATTRIBUTES = "exemptAttributes";

private Scorer scorer;
private int maxDepth = 5;
private double minimumScore = 0.00000000000001;
private int minDiscreteAttributeValueOccurances = 0;
private double minSplitFraction = .005;
private Set<String> exemptAttributes = Sets.newHashSet();

private int minLeafInstances = 0;

Expand Down Expand Up @@ -74,6 +78,16 @@ public TreeBuilder attributeIgnoringStrategy(AttributeIgnoringStrategy attribute
return this;
}

// Builder setter: attributes exempt from the min-split-fraction pruning checks in
// createTwoClassCategoricalNode/createNumericBranch. Returns this for chaining.
// NOTE(review): the set is stored by reference (copy() also shares it) — later external
// mutation would be visible to this builder; confirm callers treat it as immutable.
public TreeBuilder exemptAttributes(Set<String> exemptAttributes) {
this.exemptAttributes = exemptAttributes;
return this;
}

// Builder setter: minimum fraction of a node's instances that each side of a candidate
// split must retain (checked against inCounts/outCounts unless the attribute is exempt).
// Defaults to .005. Returns this for chaining.
public TreeBuilder minSplitFraction(double minSplitFraction) {
this.minSplitFraction = minSplitFraction;
return this;
}

@Deprecated
public TreeBuilder ignoreAttributeAtNodeProbability(double ignoreAttributeAtNodeProbability) {
attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(ignoreAttributeAtNodeProbability));
Expand All @@ -92,6 +106,8 @@ public TreeBuilder copy() {
copy.ordinalTestSpilts = ordinalTestSpilts;
copy.attributeIgnoringStrategy = attributeIgnoringStrategy.copy();
copy.fractionOfDataToUseInHoldOutSet = fractionOfDataToUseInHoldOutSet;
copy.minSplitFraction = minSplitFraction;
copy.exemptAttributes = exemptAttributes;
return copy;
}

Expand All @@ -110,8 +126,12 @@ public void updateBuilderConfig(final Map<String, Object> cfg) {
minCategoricalAttributeValueOccurances((Integer) cfg.get(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE));
if (cfg.containsKey(MIN_LEAF_INSTANCES))
minLeafInstances((Integer) cfg.get(MIN_LEAF_INSTANCES));
if (cfg.containsKey(MIN_SPLIT_FRACTION))
minSplitFraction((Double) cfg.get(MIN_SPLIT_FRACTION));
if (cfg.containsKey(ORDINAL_TEST_SPLITS))
ordinalTestSplits((Integer) cfg.get(ORDINAL_TEST_SPLITS));
if (cfg.containsKey(EXEMPT_ATTRIBUTES))
exemptAttributes((Set<String>) cfg.get(EXEMPT_ATTRIBUTES));
if (cfg.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY))
degreeOfGainRatioPenalty((Double) cfg.get(DEGREE_OF_GAIN_RATIO_PENALTY));
if (cfg.containsKey(ATTRIBUTE_IGNORING_STRATEGY))
Expand Down Expand Up @@ -441,6 +461,12 @@ private Pair<? extends Branch, Double> createTwoClassCategoricalNode(Node parent
inCounts = inCounts.add(testValCounts);
outCounts = outCounts.subtract(testValCounts);

double numInstances = inCounts.getTotal() + outCounts.getTotal();
if (!exemptAttributes.contains(attribute) && (inCounts.getTotal()/ numInstances <minSplitFraction ||
outCounts.getTotal()/ numInstances < minSplitFraction)) {
continue;
}

if (inCounts.getTotal() < minLeafInstances || outCounts.getTotal() < minLeafInstances) {
continue;
}
Expand Down Expand Up @@ -666,6 +692,12 @@ private Pair<? extends Branch, Double> createNumericBranch(Node parent, final St
ClassificationCounter inClassificationCounts = ClassificationCounter.countAll(inSet);
ClassificationCounter outClassificationCounts = ClassificationCounter.countAll(outSet);

double numInstances = inClassificationCounts.getTotal() + outClassificationCounts.getTotal();
if (!exemptAttributes.contains(attribute) && (inClassificationCounts.getTotal()/ numInstances <minSplitFraction ||
outClassificationCounts.getTotal()/ numInstances < minSplitFraction)) {
continue;
}

if (binaryClassifications) {
if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(inClassificationCounts)
|| inClassificationCounts.getTotal() < minLeafInstances
Expand All @@ -683,7 +715,7 @@ private Pair<? extends Branch, Double> createNumericBranch(Node parent, final St
if (thisScore > bestScore) {
bestScore = thisScore;
bestThreshold = threshold;
probabilityOfBeingInInset = inClassificationCounts.getTotal() / (inClassificationCounts.getTotal() + outClassificationCounts.getTotal());
probabilityOfBeingInInset = inClassificationCounts.getTotal() / numInstances;
}
}
if (bestScore == 0) {
Expand Down
@@ -1,5 +1,6 @@
package quickml.supervised.classifier;

import com.beust.jcommander.internal.Sets;
import org.javatuples.Pair;
import org.junit.Test;
import org.slf4j.Logger;
Expand All @@ -9,23 +10,29 @@
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static quickml.supervised.InstanceLoader.getAdvertisingInstances;


public class StaticBuildersTest {
private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class);
public class ClassifiersTest {
private static final Logger logger = LoggerFactory.getLogger(ClassifiersTest.class);


@Test
public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception {
double fractionOfDataForValidation = .2;
int rebuildsPerValidation = 1;
List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
Set<String> exemptAttributes = Sets.newHashSet();
exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick"));

List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor();
Pair<Map<String, Object>, DownsamplingClassifier> downsamplingClassifierPair =
Classifiers.<ClassifierInstance>getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor);
Classifiers.<ClassifierInstance>getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor, exemptAttributes);
logger.info("logged weighted auc loss should be between 0.25 and 0.28");
}
}
Expand Up @@ -54,6 +54,9 @@ public void testOptimizer() throws Exception {
private Map<String, FieldValueRecommender> createConfig() {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
Set<String> attributesToIgnore = Sets.newHashSet();
Set<String> exemptAttributes = Sets.newHashSet();
exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick"));

attributesToIgnore.addAll(Arrays.asList("browser", "eap", "destinationId", "seenPixel", "internalCreativeId"));
double probabilityOfDiscardingFromAttributesToIgnore = 0.3;
CompositeAttributeIgnoringStrategy compositeAttributeIgnoringStrategy = new CompositeAttributeIgnoringStrategy(Arrays.asList(
Expand All @@ -67,6 +70,9 @@ private Map<String, FieldValueRecommender> createConfig() {
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 20, 40));
config.put(SCORER, new FixedOrderRecommender(new InformationGainScorer(), new GiniImpurityScorer()));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75, .5 ));
config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 ));
config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes));

return config;
}

Expand Down

0 comments on commit 70d52fa

Please sign in to comment.