Permalink
Browse files

Merge f9654a3 into 4c3ccf1

  • Loading branch information...
athawk81 committed May 21, 2015
2 parents 4c3ccf1 + f9654a3 commit 70d52fae1e2e879efbcc6f4b7d512c659cdb3e22
View
@@ -32,7 +32,7 @@
be accompanied by a bump in version number, regardless of how minor the change.
-->
<version>0.7.13</version>
<version>0.7.14</version>
<repositories>
<repository>
<id>sanity-maven-repo</id>
@@ -24,8 +24,10 @@
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static quickml.supervised.classifier.decisionTree.TreeBuilder.*;
@@ -64,8 +66,8 @@
return new Pair<Map<String, Object>, DownsamplingClassifier>(bestParams, downsamplingClassifier);
}
public static <T extends ClassifierInstance> Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder) {
Map<String, FieldValueRecommender> config = createConfig();
public static <T extends ClassifierInstance> Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<T> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, DownsamplingClassifierBuilder<T> modelBuilder, Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = createConfig(exemptAttributes);
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);
}
@@ -74,13 +76,12 @@
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, modelBuilder, config);
}
public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor) {
Map<String, FieldValueRecommender> config = createConfig();
public static Pair<Map<String, Object>, DownsamplingClassifier> getOptimizedDownsampledRandomForest(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, double fractionOfDataForValidation, ClassifierLossFunction lossFunction, DateTimeExtractor dateTimeExtractor, Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = createConfig(exemptAttributes);
return getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, lossFunction, dateTimeExtractor, config);
}
private static int getTimeSliceHours(List<? extends ClassifierInstance> trainingData, int rebuildsPerValidation, DateTimeExtractor<ClassifierInstance> dateTimeExtractor) {
Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size()-1));
int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1;
@@ -91,14 +92,16 @@ private static int getTimeSliceHours(List<? extends ClassifierInstance> training
}
private static Map<String, FieldValueRecommender> createConfig() {
private static Map<String, FieldValueRecommender> createConfig(Set<String> exemptAttributes) {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
config.put(MAX_DEPTH, new FixedOrderRecommender(4, 8, 12));//Integer.MAX_VALUE, 2, 3, 5, 6, 9));
config.put(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE, new FixedOrderRecommender(7, 10));
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 15));
config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .25));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75));
config.put(ORDINAL_TEST_SPLITS, new FixedOrderRecommender(5, 7));
config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 ));
config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes));
return config;
}
@@ -41,11 +41,15 @@
public static final int RESERVOIR_SIZE = 100;
public static final Serializable MISSING_VALUE = "%missingVALUE%83257";
private static final int HARD_MINIMUM_INSTANCES_PER_CATEGORICAL_VALUE = 10;
public static final String MIN_SPLIT_FRACTION = "minSplitFraction";
public static final String EXEMPT_ATTRIBUTES = "exemptAttributes";
private Scorer scorer;
private int maxDepth = 5;
private double minimumScore = 0.00000000000001;
private int minDiscreteAttributeValueOccurances = 0;
private double minSplitFraction = .005;
private Set<String> exemptAttributes = Sets.newHashSet();
private int minLeafInstances = 0;
@@ -74,6 +78,16 @@ public TreeBuilder attributeIgnoringStrategy(AttributeIgnoringStrategy attribute
return this;
}
/**
 * Sets the attributes exempt from the minimum-split-fraction filter.
 * Elsewhere in this builder, candidate splits whose in/out branch holds less than
 * {@code minSplitFraction} of the instances are discarded unless the attribute is
 * in this set, in which case the fraction check is bypassed.
 * NOTE(review): the set is stored by reference, not copied — later external
 * mutation of the caller's set would affect this builder; confirm that is intended.
 *
 * @param exemptAttributes attribute names to exclude from the min-split-fraction check
 * @return this builder, for fluent chaining
 */
public TreeBuilder exemptAttributes(Set<String> exemptAttributes) {
this.exemptAttributes = exemptAttributes;
return this;
}
/**
 * Sets the minimum fraction of instances each side of a candidate split must
 * receive. Splits where either branch's share of the total instances falls below
 * this threshold are skipped (unless the attribute is in the exempt set).
 * NOTE(review): no range validation is performed here — values outside [0, 0.5]
 * would make every split pass or fail the check; confirm callers guard this.
 *
 * @param minSplitFraction minimum per-branch instance fraction required for a split
 * @return this builder, for fluent chaining
 */
public TreeBuilder minSplitFraction(double minSplitFraction) {
this.minSplitFraction = minSplitFraction;
return this;
}
@Deprecated
public TreeBuilder ignoreAttributeAtNodeProbability(double ignoreAttributeAtNodeProbability) {
attributeIgnoringStrategy(new IgnoreAttributesWithConstantProbability(ignoreAttributeAtNodeProbability));
@@ -92,6 +106,8 @@ public TreeBuilder copy() {
copy.ordinalTestSpilts = ordinalTestSpilts;
copy.attributeIgnoringStrategy = attributeIgnoringStrategy.copy();
copy.fractionOfDataToUseInHoldOutSet = fractionOfDataToUseInHoldOutSet;
copy.minSplitFraction = minSplitFraction;
copy.exemptAttributes = exemptAttributes;
return copy;
}
@@ -110,8 +126,12 @@ public void updateBuilderConfig(final Map<String, Object> cfg) {
minCategoricalAttributeValueOccurances((Integer) cfg.get(MIN_OCCURRENCES_OF_ATTRIBUTE_VALUE));
if (cfg.containsKey(MIN_LEAF_INSTANCES))
minLeafInstances((Integer) cfg.get(MIN_LEAF_INSTANCES));
if (cfg.containsKey(MIN_SPLIT_FRACTION))
minSplitFraction((Double) cfg.get(MIN_SPLIT_FRACTION));
if (cfg.containsKey(ORDINAL_TEST_SPLITS))
ordinalTestSplits((Integer) cfg.get(ORDINAL_TEST_SPLITS));
if (cfg.containsKey(EXEMPT_ATTRIBUTES))
exemptAttributes((Set<String>) cfg.get(EXEMPT_ATTRIBUTES));
if (cfg.containsKey(DEGREE_OF_GAIN_RATIO_PENALTY))
degreeOfGainRatioPenalty((Double) cfg.get(DEGREE_OF_GAIN_RATIO_PENALTY));
if (cfg.containsKey(ATTRIBUTE_IGNORING_STRATEGY))
@@ -441,6 +461,12 @@ private boolean isSmallTrainingSet(Iterable<T> trainingData) {
inCounts = inCounts.add(testValCounts);
outCounts = outCounts.subtract(testValCounts);
double numInstances = inCounts.getTotal() + outCounts.getTotal();
if (!exemptAttributes.contains(attribute) && (inCounts.getTotal()/ numInstances <minSplitFraction ||
outCounts.getTotal()/ numInstances < minSplitFraction)) {
continue;
}
if (inCounts.getTotal() < minLeafInstances || outCounts.getTotal() < minLeafInstances) {
continue;
}
@@ -666,6 +692,12 @@ private boolean hasBothMinorityAndMajorityClassifications(Map<Serializable, Doub
ClassificationCounter inClassificationCounts = ClassificationCounter.countAll(inSet);
ClassificationCounter outClassificationCounts = ClassificationCounter.countAll(outSet);
double numInstances = inClassificationCounts.getTotal() + outClassificationCounts.getTotal();
if (!exemptAttributes.contains(attribute) && (inClassificationCounts.getTotal()/ numInstances <minSplitFraction ||
outClassificationCounts.getTotal()/ numInstances < minSplitFraction)) {
continue;
}
if (binaryClassifications) {
if (attributeValueOrIntervalOfValuesHasInsufficientStatistics(inClassificationCounts)
|| inClassificationCounts.getTotal() < minLeafInstances
@@ -683,7 +715,7 @@ private boolean hasBothMinorityAndMajorityClassifications(Map<Serializable, Doub
if (thisScore > bestScore) {
bestScore = thisScore;
bestThreshold = threshold;
probabilityOfBeingInInset = inClassificationCounts.getTotal() / (inClassificationCounts.getTotal() + outClassificationCounts.getTotal());
probabilityOfBeingInInset = inClassificationCounts.getTotal() / numInstances;
}
}
if (bestScore == 0) {
@@ -1,5 +1,6 @@
package quickml.supervised.classifier;
import com.beust.jcommander.internal.Sets;
import org.javatuples.Pair;
import org.junit.Test;
import org.slf4j.Logger;
@@ -9,23 +10,29 @@
import quickml.supervised.classifier.downsampling.DownsamplingClassifier;
import quickml.supervised.crossValidation.lossfunctions.WeightedAUCCrossValLossFunction;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static quickml.supervised.InstanceLoader.getAdvertisingInstances;
public class StaticBuildersTest {
private static final Logger logger = LoggerFactory.getLogger(StaticBuildersTest.class);
public class ClassifiersTest {
private static final Logger logger = LoggerFactory.getLogger(ClassifiersTest.class);
@Test
public void getOptimizedDownsampledRandomForestIntegrationTest() throws Exception {
double fractionOfDataForValidation = .2;
int rebuildsPerValidation = 1;
List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
Set<String> exemptAttributes = Sets.newHashSet();
exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick"));
List<ClassifierInstance> trainingData = getAdvertisingInstances().subList(0, 3000);
OnespotDateTimeExtractor dateTimeExtractor = new OnespotDateTimeExtractor();
Pair<Map<String, Object>, DownsamplingClassifier> downsamplingClassifierPair =
Classifiers.<ClassifierInstance>getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor);
Classifiers.<ClassifierInstance>getOptimizedDownsampledRandomForest(trainingData, rebuildsPerValidation, fractionOfDataForValidation, new WeightedAUCCrossValLossFunction(1.0), dateTimeExtractor, exemptAttributes);
logger.info("logged weighted auc loss should be between 0.25 and 0.28");
}
}
@@ -54,6 +54,9 @@ public void testOptimizer() throws Exception {
private Map<String, FieldValueRecommender> createConfig() {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
Set<String> attributesToIgnore = Sets.newHashSet();
Set<String> exemptAttributes = Sets.newHashSet();
exemptAttributes.addAll(Arrays.asList("seenClick", "seenCampaignClick", "seenPixel", "seenCampaignPixel", "seenCreativeClick", "seenCampaignClick"));
attributesToIgnore.addAll(Arrays.asList("browser", "eap", "destinationId", "seenPixel", "internalCreativeId"));
double probabilityOfDiscardingFromAttributesToIgnore = 0.3;
CompositeAttributeIgnoringStrategy compositeAttributeIgnoringStrategy = new CompositeAttributeIgnoringStrategy(Arrays.asList(
@@ -67,6 +70,9 @@ public void testOptimizer() throws Exception {
config.put(MIN_LEAF_INSTANCES, new FixedOrderRecommender(0, 20, 40));
config.put(SCORER, new FixedOrderRecommender(new InformationGainScorer(), new GiniImpurityScorer()));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY, new FixedOrderRecommender(1.0, 0.75, .5 ));
config.put(MIN_SPLIT_FRACTION, new FixedOrderRecommender(0.01, 0.25, .5 ));
config.put(EXEMPT_ATTRIBUTES, new FixedOrderRecommender(exemptAttributes));
return config;
}

0 comments on commit 70d52fa

Please sign in to comment.