Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add code to read in training instances from a sparse input format, an…
…d some optimization code
- Loading branch information
Showing
3 changed files
with
169 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
105 changes: 105 additions & 0 deletions
105
src/main/java/quickml/supervised/tree/decisionTree/OptimizedDecisionForest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
package quickml.supervised.tree.decisionTree; | ||
|
||
import com.google.common.collect.Maps; | ||
import org.javatuples.Pair; | ||
import org.joda.time.DateTime; | ||
import org.joda.time.Duration; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import quickml.data.instances.ClassifierInstance; | ||
import quickml.supervised.Utils; | ||
import quickml.supervised.crossValidation.ClassifierLossChecker; | ||
import quickml.supervised.crossValidation.LossChecker; | ||
import quickml.supervised.crossValidation.RegressionLossChecker; | ||
import quickml.supervised.crossValidation.data.FoldedData; | ||
import quickml.supervised.crossValidation.data.TrainingDataCycler; | ||
import quickml.supervised.crossValidation.lossfunctions.classifierLossFunctions.WeightedAUCCrossValLossFunction; | ||
import quickml.supervised.crossValidation.utils.DateTimeExtractor; | ||
|
||
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForest; | ||
import quickml.supervised.ensembles.randomForest.randomDecisionForest.RandomDecisionForestBuilder; | ||
import quickml.supervised.predictiveModelOptimizer.FieldValueRecommender; | ||
import quickml.supervised.predictiveModelOptimizer.PredictiveModelOptimizer; | ||
import quickml.supervised.predictiveModelOptimizer.SimplePredictiveModelOptimizerBuilder; | ||
import quickml.supervised.predictiveModelOptimizer.fieldValueRecommenders.FixedOrderRecommender; | ||
import quickml.supervised.tree.attributeIgnoringStrategies.IgnoreAttributesWithConstantProbability; | ||
|
||
import java.io.Serializable; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static quickml.supervised.tree.constants.ForestOptions.*; | ||
import static quickml.supervised.tree.constants.ForestOptions.NUM_TREES; | ||
|
||
|
||
/** | ||
* Created by alexanderhawk on 3/5/15. | ||
*/ | ||
|
||
/* FIXME: This is unnecessarily specialized to out-of-time cross-validation, should be generalized | ||
* so that it can support alternate ways to separate training from test set (for example, | ||
* any Comparable class can be used to sort the training instances, not just DateTime). | ||
*/ | ||
public class OptimizedDecisionForest { | ||
private static final Logger logger = LoggerFactory.getLogger(quickml.supervised.tree.regressionTree.OptimizedRegressionForests.class); | ||
|
||
public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config) { | ||
TrainingDataCycler<T> dataCycler = new FoldedData<>(trainingData, 6, 2); | ||
return getOptimizedRandomForest(trainingData, config, dataCycler); | ||
} | ||
|
||
public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData, Map<String, FieldValueRecommender> config, TrainingDataCycler<T> trainingDataCycler) { | ||
ClassifierLossChecker<T, RandomDecisionForest> lossChecker = new ClassifierLossChecker<>(new WeightedAUCCrossValLossFunction(1.0)); | ||
RandomDecisionForestBuilder<T> modelBuilder = new RandomDecisionForestBuilder<T>(); | ||
PredictiveModelOptimizer optimizer = new SimplePredictiveModelOptimizerBuilder<RandomDecisionForest, T>() | ||
.modelBuilder(modelBuilder) | ||
.dataCycler(trainingDataCycler) | ||
.lossChecker(lossChecker) | ||
.valuesToTest(config) | ||
.iterations(2).build(); | ||
|
||
Map<String, Serializable> optimalConfig = optimizer.determineOptimalConfig(); | ||
|
||
modelBuilder.updateBuilderConfig(optimalConfig); | ||
return Pair.with(optimalConfig, modelBuilder.buildPredictiveModel(trainingData)); | ||
} | ||
|
||
|
||
public static <T extends ClassifierInstance> Pair<Map<String, Serializable>, RandomDecisionForest> getOptimizedRandomForest(List<T> trainingData) { | ||
Map<String, FieldValueRecommender> config = createConfig(); | ||
return getOptimizedRandomForest(trainingData, config, new FoldedData<>(trainingData, 6, 2)); | ||
} | ||
|
||
|
||
private static <I extends ClassifierInstance> int getTimeSliceHours(List<I> trainingData, int rebuildsPerValidation, DateTimeExtractor<I> dateTimeExtractor) { | ||
Utils.sortTrainingInstancesByTime(trainingData, dateTimeExtractor); | ||
DateTime latestDateTime = dateTimeExtractor.extractDateTime(trainingData.get(trainingData.size() - 1)); | ||
int indexOfEarliestValidationInstance = (int) (0.8 * trainingData.size()) - 1; | ||
DateTime earliestValidationTime = dateTimeExtractor.extractDateTime(trainingData.get(indexOfEarliestValidationInstance)); | ||
Duration duration = new Duration(earliestValidationTime, latestDateTime); | ||
int validationPeriodHours = (int) duration.getStandardHours(); | ||
return validationPeriodHours / rebuildsPerValidation; | ||
} | ||
|
||
|
||
// FIXME: Since most users of QuickML will be content with a default set of hyperparameters, we shouldn't force | ||
|
||
private static Map<String, FieldValueRecommender> createConfig() { | ||
Map<String, FieldValueRecommender> config = Maps.newHashMap(); | ||
config.put(MAX_DEPTH.name(), new FixedOrderRecommender(2, 5, 12)); | ||
config.put(MIN_ATTRIBUTE_VALUE_OCCURRENCES.name(), new FixedOrderRecommender(7, 14)); | ||
config.put(MIN_LEAF_INSTANCES.name(), new FixedOrderRecommender(5));// 10)); | ||
config.put(ATTRIBUTE_IGNORING_STRATEGY.name(), new FixedOrderRecommender( | ||
new IgnoreAttributesWithConstantProbability(0.75), | ||
new IgnoreAttributesWithConstantProbability(0.85), | ||
new IgnoreAttributesWithConstantProbability(0.9) | ||
)); | ||
config.put(MIN_SLPIT_FRACTION.name(), new FixedOrderRecommender(0.0));//, 0.05, 0.2)); | ||
config.put(NUM_NUMERIC_BINS.name(), new FixedOrderRecommender(2));//, 5, 8)); | ||
config.put(NUM_SAMPLES_PER_NUMERIC_BIN.name(), new FixedOrderRecommender(25)); | ||
// config.put(DownsamplingClassifierBuilder.MINORITY_INSTANCE_PROPORTION, new FixedOrderRecommender(.1, .2)); | ||
config.put(DEGREE_OF_GAIN_RATIO_PENALTY.name(), new FixedOrderRecommender(1.0, 0.75)); | ||
config.put(NUM_TREES.name(), new FixedOrderRecommender(8)); | ||
return config; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
package quickml.utlities; | ||
|
||
import com.google.common.collect.Lists; | ||
import com.google.common.collect.Maps; | ||
import org.joda.time.DateTime; | ||
import org.joda.time.format.DateTimeFormat; | ||
import org.joda.time.format.DateTimeFormatter; | ||
import quickml.data.AttributesMap; | ||
import quickml.data.instances.ClassifierInstance; | ||
import quickml.supervised.crossValidation.utils.DateTimeExtractor; | ||
|
||
import java.io.BufferedReader; | ||
import java.io.FileReader; | ||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* Created by alexanderhawk on 5/16/17. | ||
*/ | ||
public class LibSVMFormatReader { | ||
|
||
public List<ClassifierInstance> readLibSVMFormattedInstances(String path, String dateAttribute) { | ||
List<ClassifierInstance> instances = Lists.newArrayList(); | ||
try (BufferedReader br = new BufferedReader(new FileReader(path))) { | ||
for (String line; (line = br.readLine()) != null; ) { | ||
List<String> rawInstance = Arrays.asList(line.split(" ")); | ||
Double label = Double.valueOf(rawInstance.get(0)); | ||
AttributesMap map = AttributesMap.newHashMap(); | ||
DateTime instanceTimeStamp = null; | ||
for (String rawAttributeAndValue : rawInstance.subList(1, rawInstance.size())) { | ||
String[] attributeAndValue = rawAttributeAndValue.split(":"); | ||
String attribute = attributeAndValue[0]; | ||
String value = attributeAndValue[1]; | ||
if (attribute.equals(dateAttribute)) { | ||
DateTimeFormatter dateTimeFormatter = DateTimeFormat.forPattern("yyyy-MM-dd'T'HH:mm:ss.SSS"); //format of T may be wrong | ||
instanceTimeStamp = new DateTime(dateTimeFormatter.parseMillis((String) value)); | ||
} else { | ||
try { | ||
//add numeric variable as Double | ||
map.put(attribute, Double.parseDouble(value)); | ||
} catch (NumberFormatException e) { | ||
//add categorical variable as String | ||
map.put(attribute, value); | ||
} | ||
} | ||
} | ||
instances.add(new ClassifierInstance(map, label, instanceTimeStamp)); | ||
} | ||
} catch (IOException e) { | ||
throw new RuntimeException(e.getMessage()); | ||
} | ||
return instances; | ||
} | ||
} | ||
|
||
|