@@ -2,6 +2,7 @@

import com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.joda.time.DateTimeConstants;
import org.joda.time.Hours;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -19,130 +20,93 @@
*/
public class TemporallyReweightedPMBuilder implements UpdatablePredictiveModelBuilder<TemporallyReweightedPM> {
private static final Logger logger = LoggerFactory.getLogger(TemporallyReweightedPMBuilder.class);

// Classification value that selects the "positive" decay constant during reweighting.
private static final double POSITIVE_CLASSIFICATION = 1.0;
// Decay constant in hours; 173 ~= (5 days * 24 h) / ln(2), i.e. a half-life of approximately 5 days.
private static final double DEFAULT_DECAY_CONSTANT = 173; //approximately 5 days

// Decay constants (in hours) applied to positive / negative instances; overridden by the halfLifeOf* setters.
private double decayConstantOfPositive = DEFAULT_DECAY_CONSTANT;
private double decayConstantOfNegative = DEFAULT_DECAY_CONSTANT;

// NOTE(review): the three fields below look like leftovers from an earlier revision
// (superseded by decayConstantOfPositive / decayConstantOfNegative) -- confirm and remove once nothing reads them.
private double halfLifeOfPositiveInHours;
private double halfLifeOfNegativeInHours;
private double decayConstantNegative;

// Builder that receives the reweighted training data.
private PredictiveModelBuilder<?> wrappedBuilder;
// Supplies the timestamp of each training instance.
private DateTimeExtractor dateTimeExtractor;
private Serializable iD;


/**
 * Creates a builder that reweights training instances by age before handing them to another builder.
 *
 * @param wrappedBuilder    builder that receives the reweighted instances and builds the underlying model
 * @param dateTimeExtractor used to read the timestamp of each training instance
 */
public TemporallyReweightedPMBuilder(PredictiveModelBuilder<?> wrappedBuilder, DateTimeExtractor dateTimeExtractor) {
    this.wrappedBuilder = wrappedBuilder;
    this.dateTimeExtractor = dateTimeExtractor;
}

/**
 * Sets the half-life used when down-weighting instances whose classification equals the positive value.
 * The half-life is converted from days to hours and then to an exponential decay constant
 * (half-life / ln 2), which is what the reweighting formula consumes.
 *
 * @param halfLifeOfPositiveInDays half-life, in days
 * @return this builder, for chaining
 */
public TemporallyReweightedPMBuilder halfLifeOfPositive(double halfLifeOfPositiveInDays) {
    // Single assignment: days -> hours -> decay constant.
    this.decayConstantOfPositive = halfLifeOfPositiveInDays * DateTimeConstants.HOURS_PER_DAY / Math.log(2);
    return this;
}

/**
 * Sets the half-life used when down-weighting instances whose classification is not the positive value.
 * The half-life is converted from days to hours and then to an exponential decay constant
 * (half-life / ln 2), mirroring {@code halfLifeOfPositive}.
 *
 * @param halfLifeOfNegativeInDays half-life, in days
 * @return this builder, for chaining
 */
public TemporallyReweightedPMBuilder halfLifeOfNegative(double halfLifeOfNegativeInDays) {
    // The days -> hours conversion is required here too; without it the decay would be 24x too fast.
    this.decayConstantOfNegative = halfLifeOfNegativeInDays * DateTimeConstants.HOURS_PER_DAY / Math.log(2);
    return this;
}

/**
 * Records the given id on this builder and forwards it to the wrapped builder.
 * <p>
 * Design note: this method arguably should not be in the PMB interface, since it is specific to us.
 *
 * @param iD identifier to store and propagate
 */
@Override
public void setID(Serializable iD) {
    this.iD = iD;
    // Keep the delegate in sync with the id held here.
    this.wrappedBuilder.setID(iD);
}

/**
 * Builds a model from training data whose instance weights have been decayed by age.
 * <p>
 * The newest timestamp in {@code trainingData} is used as the reference time; each instance is
 * reweighted relative to it and the reweighted list (not the original iterable) is passed to the
 * wrapped builder. The input does not need to be ordered by time.
 * <p>
 * Note that {@code trainingData} is iterated twice (once to find the newest timestamp, once to reweight).
 *
 * @param trainingData instances to train on
 * @return the wrapped builder's model, wrapped in a {@link TemporallyReweightedPM}
 */
@Override
public TemporallyReweightedPM buildPredictiveModel(Iterable<? extends AbstractInstance> trainingData) {
    // Scan for the newest timestamp instead of assuming the last element is the most recent one.
    DateTime mostRecent = getMostRecentInstance(trainingData);
    List<AbstractInstance> reweightedTrainingData = reweightTrainingData(trainingData, mostRecent);
    final PredictiveModel predictiveModel = wrappedBuilder.buildPredictiveModel(reweightedTrainingData);
    return new TemporallyReweightedPM(predictiveModel);
}

private ArrayList<AbstractInstance> reweightTrainingData(ArrayList<AbstractInstance> sortedData, DateTime mostRecentInstance) {
applyTemporalReweighting(sortedData, mostRecentInstance);
return sortedData;
}


//refactor to adjust positives and negatives in same pass
private void applyTemporalReweighting(List<? extends AbstractInstance> sortedData, DateTime mostRecentInstance) {
private List<AbstractInstance> reweightTrainingData(Iterable<? extends AbstractInstance> sortedData, DateTime mostRecentInstance) {
ArrayList<AbstractInstance> trainingDataList = Lists.newArrayList();
for (AbstractInstance instance : sortedData) {
double decayConstant = (instance.getClassification()==1.0) ? decayConstantOfPositive : decayConstantNegative;
DateTime timOfInstance = dateTimeExtractor.extractDateTime(instance);
double hoursBack = Hours.hoursBetween(mostRecentInstance, timOfInstance).getHours();
double decayConstant = (instance.getClassification().equals(POSITIVE_CLASSIFICATION)) ? decayConstantOfPositive : decayConstantOfNegative;
DateTime timeOfInstance = dateTimeExtractor.extractDateTime(instance);
double hoursBack = Hours.hoursBetween(mostRecentInstance, timeOfInstance).getHours();
double newWeight = Math.exp(-1.0 * hoursBack / decayConstant);
instance.setWeight(newWeight);
trainingDataList.add(instance.reweight(newWeight));
}
return trainingDataList;
}

/**
 * Copies the given instances into a new list and sorts it by extracted timestamp, earliest first.
 *
 * @param trainingData instances to copy and order
 * @return a new list containing every instance, ordered by ascending timestamp
 */
private ArrayList<AbstractInstance> sortTrainingData(Iterable<? extends AbstractInstance> trainingData) {
    final ArrayList<AbstractInstance> chronological = Lists.newArrayList();
    for (AbstractInstance candidate : trainingData) {
        chronological.add(candidate);
    }

    Collections.sort(chronological, new Comparator<AbstractInstance>() {
        @Override
        public int compare(AbstractInstance left, AbstractInstance right) {
            final DateTime leftTime = dateTimeExtractor.extractDateTime(left);
            final DateTime rightTime = dateTimeExtractor.extractDateTime(right);
            if (leftTime.isAfter(rightTime)) {
                return 1;
            }
            return leftTime.isEqual(rightTime) ? 0 : -1;
        }
    });
    return chronological;
}

/**
 * Forwards the updatable flag to the wrapped builder.
 *
 * @param updatable value passed through to the wrapped builder
 * @return this builder, for chaining
 */
@Override
public PredictiveModelBuilder<TemporallyReweightedPM> updatable(final boolean updatable) {
    // This class keeps no flag of its own; the delegate owns the setting.
    wrappedBuilder.updatable(updatable);
    return this;
}


/**
 * Updates the wrapped model with new data after reweighting both data sets by age.
 * <p>
 * The newest timestamp in {@code newData} is the reference time. If {@code newData} is empty, the
 * newest timestamp in {@code trainingData} is used instead so that reweighting still has a reference.
 *
 * @param predictiveModel model whose wrapped model is updated
 * @param newData         instances being added
 * @param trainingData    previously seen training instances
 * @param splitNodes      passed through to the wrapped builder
 * @throws RuntimeException if the wrapped builder is not an {@code UpdatablePredictiveModelBuilder}
 */
@Override
public void updatePredictiveModel(TemporallyReweightedPM predictiveModel, Iterable<? extends AbstractInstance> newData, List<? extends AbstractInstance> trainingData, boolean splitNodes) {
    if (!(wrappedBuilder instanceof UpdatablePredictiveModelBuilder)) {
        throw new RuntimeException("Cannot update predictive model without UpdatablePredictiveModelBuilder");
    }

    DateTime mostRecentInstance = getMostRecentInstance(newData);
    if (mostRecentInstance == null) {
        // No new data: fall back to the training set so reweighting has a non-null reference time.
        mostRecentInstance = getMostRecentInstance(trainingData);
    }

    //Reweighting the 'original' training set might be a problem when splitting nodes on update
    List<AbstractInstance> trainingDataList = reweightTrainingData(trainingData, mostRecentInstance);
    List<AbstractInstance> newDataList = reweightTrainingData(newData, mostRecentInstance);

    PredictiveModel pm = predictiveModel.getWrappedModel();
    ((UpdatablePredictiveModelBuilder) wrappedBuilder).updatePredictiveModel(pm, newDataList, trainingDataList, splitNodes);
    logger.info("Updating default predictive model");
}

private ArrayList<AbstractInstance> IterableToArrayList(List<? extends AbstractInstance> trainingData) {
ArrayList<AbstractInstance> trainingDataList = Lists.newArrayList();
for (AbstractInstance instance : trainingData)
trainingDataList.add(instance);
return trainingDataList;
private DateTime getMostRecentInstance(Iterable<? extends AbstractInstance> newData) {
DateTime mostRecent = null;
for(AbstractInstance instance : newData) {
DateTime instanceTime = dateTimeExtractor.extractDateTime(instance);
if (mostRecent == null || instanceTime.isAfter(mostRecent)) {
mostRecent = instanceTime;
}
}
return mostRecent;
}

@Override
public void stripData(TemporallyReweightedPM predictiveModel) {
if (wrappedBuilder instanceof UpdatablePredictiveModelBuilder) {
((UpdatablePredictiveModelBuilder) wrappedBuilder).stripData(predictiveModel.getWrappedModel());
}
else {
} else {
throw new RuntimeException("Cannot strip data without UpdatablePredictiveModelBuilder");
}

@@ -11,6 +11,8 @@

public class TemporallyReweightedPMBuilderBuilder implements PredictiveModelBuilderBuilder<TemporallyReweightedPM, TemporallyReweightedPMBuilder> {

public static final String HALF_LIFE_OF_NEGATIVE = "halfLifeOfNegative";
public static final String HALF_LIFE_OF_POSITIVE = "halfLifeOfPositive";
private final PredictiveModelBuilderBuilder<?, ?> wrappedBuilderBuilder;

public TemporallyReweightedPMBuilderBuilder(PredictiveModelBuilderBuilder<?, ?> wrappedBuilderBuilder) {
@@ -21,15 +23,16 @@ public TemporallyReweightedPMBuilderBuilder(PredictiveModelBuilderBuilder<?, ?>
public Map<String, FieldValueRecommender> createDefaultParametersToOptimize() {
    // Start from the wrapped builder-builder's defaults, then add the two half-life parameters.
    Map<String, FieldValueRecommender> parametersToOptimize = Maps.newHashMap();
    parametersToOptimize.putAll(wrappedBuilderBuilder.createDefaultParametersToOptimize());

    // Candidates are written as doubles because buildBuilder casts the configured values to Double.
    // Each key is registered exactly once, via the shared constants that buildBuilder also reads.
    parametersToOptimize.put(HALF_LIFE_OF_NEGATIVE, new FixedOrderRecommender(5.0, 10.0, 20.0));
    parametersToOptimize.put(HALF_LIFE_OF_POSITIVE, new FixedOrderRecommender(5.0, 10.0, 20.0));
    return parametersToOptimize;
}

@Override //set date time extractor to be correct
public TemporallyReweightedPMBuilder buildBuilder(final Map<String, Object> predictiveModelConfig) {
final double halfLifeOfPositive = (Double) predictiveModelConfig.get("halfLifeOfPositive");
final double halfLifeOfNegative = (Double) predictiveModelConfig.get("halfLifeOfNegative");
final double halfLifeOfPositive = (Double) predictiveModelConfig.get(HALF_LIFE_OF_POSITIVE);
final double halfLifeOfNegative = (Double) predictiveModelConfig.get(HALF_LIFE_OF_NEGATIVE);
return new TemporallyReweightedPMBuilder(wrappedBuilderBuilder.buildBuilder(predictiveModelConfig), new SimpleDateFormatExtractor())
.halfLifeOfNegative(halfLifeOfNegative)
.halfLifeOfPositive(halfLifeOfPositive);
@@ -1,15 +1,14 @@
package quickdt.predictiveModels;

import com.beust.jcommander.internal.Lists;
import quickdt.Misc;
import quickdt.data.Attributes;
import quickdt.data.HashMapAttributes;
import quickdt.data.Instance;

import java.io.*;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/**
* Created by Chris on 5/14/2014.
@@ -31,9 +30,18 @@ public static List<Instance> getIntegerInstances(int numInstances) {
for (int x = 0; x < numInstances; x++) {
final double height = (4 * 12) + Misc.random.nextInt(3 * 12);
final double weight = 120 + Misc.random.nextInt(110);
Calendar calendar = Calendar.getInstance();
final int year = calendar.get(Calendar.YEAR);
final int month = calendar.get(Calendar.MONTH);
final int day = Misc.random.nextInt(28)+1;
final int hour = Misc.random.nextInt(24);
final Attributes attributes = new HashMapAttributes();
attributes.put("weight", weight);
attributes.put("height", height);
attributes.put("timeOfArrival-year", year);
attributes.put("timeOfArrival-monthOfYear", month);
attributes.put("timeOfArrival-dayOfMonth", day);
attributes.put("timeOfArrival-hourOfDay", hour);
instances.add(new Instance(attributes, bmiHealthyInteger(weight, height)));
}
return instances;
@@ -0,0 +1,30 @@
package quickdt.predictiveModels.temporallyWeightPredictiveModel;

import org.testng.Assert;
import org.testng.annotations.Test;
import quickdt.crossValidation.SampleDateTimeExtractor;
import quickdt.data.Instance;
import quickdt.predictiveModels.TreeBuilderTestUtils;
import quickdt.predictiveModels.decisionTree.TreeBuilder;
import quickdt.predictiveModels.decisionTree.scorers.SplitDiffScorer;

import java.util.List;

/**
* Created by chrisreeves on 6/23/14.
*/
public class TemporallyReweightedPMBuilderTest {
    /**
     * Smoke test: builds a temporally reweighted model over generated instances, round-trips it through
     * {@code TreeBuilderTestUtils.serializeDeserialize}, and checks that both steps together finish
     * within the time budget.
     */
    @Test
    public void simpleBmiTest() throws Exception {
        final List<Instance> trainingInstances = TreeBuilderTestUtils.getIntegerInstances(10000);
        final TemporallyReweightedPMBuilder reweightingBuilder =
                new TemporallyReweightedPMBuilder(new TreeBuilder(new SplitDiffScorer()), new SampleDateTimeExtractor());

        // The clock starts after data generation and builder construction.
        final long startedAtMillis = System.currentTimeMillis();
        final TemporallyReweightedPM builtModel = reweightingBuilder.buildPredictiveModel(trainingInstances);
        TreeBuilderTestUtils.serializeDeserialize(builtModel);
        final long elapsedMillis = System.currentTimeMillis() - startedAtMillis;

        Assert.assertTrue(elapsedMillis < 20000,"Building this node should take far less than 20 seconds");
    }

}