Skip to content

Commit

Permalink
add code to read in training instances from a sparse input format, an…
Browse files Browse the repository at this point in the history
…d some optimization code
  • Loading branch information
athawk81 committed May 17, 2017
1 parent 2b348d5 commit 8779b5b
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/main/java/quickml/data/instances/ClassifierInstance.java
@@ -1,5 +1,6 @@
package quickml.data.instances;

import org.joda.time.DateTime;
import quickml.data.AttributesMap;

import java.io.Serializable;
Expand All @@ -8,12 +9,17 @@
* Created by alexanderhawk on 4/14/15.
*/
public class ClassifierInstance extends InstanceWithAttributesMap<Serializable> {
public DateTime timeStamp;
public ClassifierInstance(AttributesMap attributes, Serializable label) {
super(attributes, label, 1.0);
}
public ClassifierInstance(AttributesMap attributes, Serializable label, double weight) {
super(attributes, label, weight);
}
public ClassifierInstance(AttributesMap attributes, Serializable label, DateTime timeStamp) {
super(attributes, label, 1.0);
this.timeStamp = timeStamp;
}

}

@@ -0,0 +1,105 @@
package quickml.supervised.tree.decisionTree;

import com.google.common.collect.Maps;
import org.javatuples.Pair;
import org.joda.time.DateTime;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.Utils;
import quickml.supervised.crossValidation.ClassifierLossChecker;
import quickml.supervised.crossValidation.LossChecker;
import quickml.supervised.crossValidation.RegressionLossChecker;
import quickml.supervised.crossValidation.data.FoldedData;
import quickml.supervised.crossValidation.data.TrainingDataCycler;
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;

import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest;
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder;
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender;
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer;
import quickml.supervised.predictiveModelOptimizer.SimplePredictiveModelOptimizerBuilder;
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender;
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability;

import java.io.Serializable;
import java.util.List;
import java.util.Map;

import static quickml.supervised.tree.constants.ForestOptions.*;
import static quickml.supervised.tree.constants.ForestOptions.NUM_TREES;


/**
* Created by alexanderhawk on 3/5/15.
*/

/* FIXME: This is unnecessarily specialized to out-of-time cross-validation, should be generalized
* so that it can support alternate ways to separate training from test set (for example,
* any Comparable class can be used to sort the training instances, not just DateTime).
*/
public class OptimizedDecisionForest {
private static final Logger logger = LoggerFactory.getLogger(quickml.supervised.tree.regressionTree.OptimizedRegressionForests.class);

public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config) {
TrainingDataCycler<T> dataCycler = new FoldedData<>(trainingData, 6, 2);
return getOptimizedRandomForest(trainingData, config, dataCycler);
}

public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config, TrainingDataCycler<T> trainingDataCycler) {
ClassifierLossChecker<T, RandomDecisionForest> lossChecker = new ClassifierLossChecker<>(new WeightedAUCCrossValLossFunction(1.0));
RandomDecisionForestBuilder<T> modelBuilder = new RandomDecisionForestBuilder<T>();
PredictiveModelOptimizer optimizer = new SimplePredictiveModelOptimizerBuilder<RandomDecisionForest, T>()
.modelBuilder(modelBuilder)
.dataCycler(trainingDataCycler)
.lossChecker(lossChecker)
.valuesToTest(config)
.iterations(2).build();

Map<String, Serializable> optimalConfig = optimizer.determineOptimalConfig();

modelBuilder.updateBuilderConfig(optimalConfig);
return Pair.with(optimalConfig, modelBuilder.buildPredictiveModel(trainingData));
}


public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData) {
Map<String, FieldValueRecommender> config = createConfig();
return getOptimizedRandomForest(trainingData, config, new FoldedData<>(trainingData, 6, 2));
}


private static <I extends ClassifierInstance> int getTimeSliceHours(List<I> trainingData, int rebuildsPerValidation, DateTimeExtractor<I> dateTimeExtractor) {
Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor);
DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size() - 1));
int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1;
DateTime earliestValidationTime = dateTimeExtractor.extractDateTime(trainingData.get(indexOfEarliestValidationInstance));
Duration duration = new Duration(earliestValidationTime, latestDateTime);
int validationPeriodHours = (int) duration.getStandardHours();
return validationPeriodHours / rebuildsPerValidation;
}


// FIXME: Since most users of QuickML will be content with a default set of hyperparameters, we shouldn't force

private static Map<String, FieldValueRecommender> createConfig() {
Map<String, FieldValueRecommender> config = Maps.newHashMap();
config.put(MAX_DEPTH.name(), new FixedOrderRecommender(2, 5, 12));
config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), new FixedOrderRecommender(7, 14));
config.put(MIN_LEAF_INSTANCES.name(), new FixedOrderRecommender(5));// 10));
config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new FixedOrderRecommender(
new IgnoreAttributesWithConstantProbability(0.75),
new IgnoreAttributesWithConstantProbability(0.85),
new IgnoreAttributesWithConstantProbability(0.9)
));
config.put(MIN_SLPIT_FRACTION.name(), new FixedOrderRecommender(0.0));//, 0.05, 0.2));
config.put(NUM_NUMERIC_BINS.name(), new FixedOrderRecommender(2));//, 5, 8));
config.put(NUM_SAMPLES_PER_NUMERIC_BIN.name(), new FixedOrderRecommender(25));
// config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2));
config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), new FixedOrderRecommender(1.0, 0.75));
config.put(NUM_TREES.name(), new FixedOrderRecommender(8));
return config;
}
}
58 changes: 58 additions & 0 deletions src/main/java/quickml/utlities/LibSVMFormatReader.java
@@ -0,0 +1,58 @@
package quickml.utlities;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import org.joda.time.DateTime;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
import quickml.data.AttributesMap;
import quickml.data.instances.ClassifierInstance;
import quickml.supervised.crossValidation.utils.DateTimeExtractor;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

/**
* Created by alexanderhawk on 5/16/17.
*/
public class LibSVMFormatReader {

public List<ClassifierInstance> readLibSVMFormattedInstances(String path, String dateAttribute) {
List<ClassifierInstance> instances = Lists.newArrayList();
try (BufferedReader br = new BufferedReader(new FileReader(path))) {
for (String line; (line = br.readLine()) != null; ) {
List<String> rawInstance = Arrays.asList(line.split(" "));
Double label = Double.valueOf(rawInstance.get(0));
AttributesMap map = AttributesMap.newHashMap();
DateTime instanceTimeStamp = null;
for (String rawAttributeAndValue : rawInstance.subList(1, rawInstance.size())) {
String[] attributeAndValue = rawAttributeAndValue.split(":");
String attribute = attributeAndValue[0];
String value = attributeAndValue[1];
if (attribute.equals(dateAttribute)) {
DateTimeFormatter dateTimeFormatter = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSS"); //format of T may be wrong
instanceTimeStamp = new DateTime(dateTimeFormatter.parseMillis((String) value));
} else {
try {
//add numeric variable as Double
map.put(attribute, Double.parseDouble(value));
} catch (NumberFormatException e) {
//add categorical variable as String
map.put(attribute, value);
}
}
}
instances.add(new ClassifierInstance(map, label, instanceTimeStamp));
}
} catch (IOException e) {
throw new RuntimeException(e.getMessage());
}
return instances;
}
}


0 comments on commit 8779b5b

Please sign in to comment.