This repository has been archived by the owner on May 3, 2022. It is now read-only.
/
MLUpdate.java
378 lines (339 loc) · 15.3 KB
/
MLUpdate.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
/*
* Copyright (c) 2014, Cloudera and Intel, Inc. All Rights Reserved.
*
* Cloudera, Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"). You may not use this file except in
* compliance with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
* CONDITIONS OF ANY KIND, either express or implied. See the License for
* the specific language governing permissions and limitations under the
* License.
*/
package com.cloudera.oryx.ml;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import com.google.common.base.Preconditions;
import com.typesafe.config.Config;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
import org.dmg.pmml.PMML;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.cloudera.oryx.api.TopicProducer;
import com.cloudera.oryx.api.batch.BatchLayerUpdate;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.lang.ExecUtils;
import com.cloudera.oryx.common.pmml.PMMLUtils;
import com.cloudera.oryx.common.random.RandomManager;
import com.cloudera.oryx.common.settings.ConfigUtils;
import com.cloudera.oryx.ml.param.HyperParamValues;
import com.cloudera.oryx.ml.param.HyperParams;
/**
 * A specialization of {@link BatchLayerUpdate} for machine learning-oriented
 * update processes. This implementation contains the framework for common concerns
 * such as test/train splitting, hyperparameter optimization, and so on. Subclasses
 * instead implement
 * methods like {@link #buildModel(JavaSparkContext,JavaRDD,List,Path)} to create a PMML model and
 * {@link #evaluate(JavaSparkContext,PMML,Path,JavaRDD,JavaRDD)} to evaluate a model from
 * held-out test data.
 *
 * @param <M> type of message to read from the input topic
 */
public abstract class MLUpdate<M> implements BatchLayerUpdate<Object,M,String> {

  private static final Logger log = LoggerFactory.getLogger(MLUpdate.class);

  /** File name under which each candidate's PMML model is written in its directory. */
  public static final String MODEL_FILE_NAME = "model.pmml";

  // Fraction of input data held out as a test set, in [0,1]; 0.0 disables evaluation.
  private final double testFraction;
  // Number of candidate models built (and evaluated) per batch; forced to 1 when eval is off.
  private final int candidates;
  // Hyperparameter search strategy name; passed through to HyperParams.chooseHyperParameterCombos.
  private final String hyperParamSearch;
  // Maximum number of candidate models built/evaluated concurrently.
  private final int evalParallelism;
  // Minimum eval score a best model must reach to be kept, or null for no threshold.
  private final Double threshold;
  // Maximum size in bytes of a message on the update topic; larger models are sent by reference.
  private final int maxMessageSize;

  /**
   * @param config app configuration; reads the {@code oryx.ml.eval.*} settings and
   *  {@code oryx.update-topic.message.max-size}
   */
  protected MLUpdate(Config config) {
    this.testFraction = config.getDouble("oryx.ml.eval.test-fraction");
    int candidates = config.getInt("oryx.ml.eval.candidates");
    this.evalParallelism = config.getInt("oryx.ml.eval.parallelism");
    this.threshold = ConfigUtils.getOptionalDouble(config, "oryx.ml.eval.threshold");
    this.maxMessageSize = config.getInt("oryx.update-topic.message.max-size");
    Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0);
    Preconditions.checkArgument(candidates > 0);
    Preconditions.checkArgument(evalParallelism > 0);
    Preconditions.checkArgument(maxMessageSize > 0);
    if (testFraction == 0.0 && candidates > 1) {
      // With no held-out data, candidates cannot be compared, so only one is worth building
      log.info("Eval is disabled (test fraction = 0) so candidates is overridden to 1");
      candidates = 1;
    }
    this.candidates = candidates;
    this.hyperParamSearch = config.getString("oryx.ml.eval.hyperparam-search");
  }

  /**
   * @return fraction of input data held out for testing, in [0,1]
   */
  protected final double getTestFraction() {
    return testFraction;
  }

  /**
   * @return a list of hyperparameter value ranges to try, one {@link HyperParamValues} per
   *  hyperparameter. Different combinations of the values derived from the list will be
   *  passed back into {@link #buildModel(JavaSparkContext,JavaRDD,List,Path)}
   */
  public List<HyperParamValues<?>> getHyperParameterValues() {
    // Default: no hyperparameters to search over
    return Collections.emptyList();
  }

  /**
   * @param sparkContext active Spark Context
   * @param trainData training data on which to build a model
   * @param hyperParameters ordered list of hyper parameter values to use in building model
   * @param candidatePath directory where additional model files can be written
   * @return a {@link PMML} representation of a model trained on the given data
   */
  public abstract PMML buildModel(JavaSparkContext sparkContext,
                                  JavaRDD<M> trainData,
                                  List<?> hyperParameters,
                                  Path candidatePath);

  /**
   * @return {@code true} iff additional updates must be published along with the model; if
   *  {@link #publishAdditionalModelData(JavaSparkContext, PMML, JavaRDD, JavaRDD, Path, TopicProducer)} must
   *  be called. This is only applicable for special model types.
   */
  public boolean canPublishAdditionalModelData() {
    return false;
  }

  /**
   * Optionally, publish additional model-related information to the update topic,
   * after the model has been written. This is needed only in specific cases, like the
   * ALS algorithm, where the model serialization in PMML can't contain all of the info.
   *
   * @param sparkContext active Spark Context
   * @param pmml model for which extra data should be written
   * @param newData data that has arrived in current interval
   * @param pastData all previously-known data (may be {@code null})
   * @param modelParentPath directory containing model files, if applicable
   * @param modelUpdateTopic message topic to write to
   */
  public void publishAdditionalModelData(JavaSparkContext sparkContext,
                                         PMML pmml,
                                         JavaRDD<M> newData,
                                         JavaRDD<M> pastData,
                                         Path modelParentPath,
                                         TopicProducer<String, String> modelUpdateTopic) {
    // Do nothing by default
  }

  /**
   * @param sparkContext active Spark Context
   * @param model model to evaluate
   * @param modelParentPath directory containing model files, if applicable
   * @param testData data on which to test the model performance
   * @param trainData data on which model was trained, which can also be useful in evaluating
   *  unsupervised learning problems
   * @return an evaluation of the model on the test data. Higher should mean "better"
   */
  public abstract double evaluate(JavaSparkContext sparkContext,
                                  PMML model,
                                  Path modelParentPath,
                                  JavaRDD<M> testData,
                                  JavaRDD<M> trainData);

  /**
   * Runs one batch update: builds candidate models over the new and past data, selects the
   * best-evaluating one, moves it into the model directory, and (if configured) publishes it
   * to the update topic — inline as PMML when it fits in a message, else by path reference.
   *
   * @param sparkContext active Spark Context
   * @param timestamp timestamp of this batch interval (unused in this implementation)
   * @param newKeyMessageData key-message data that arrived in the current interval; must be non-null
   * @param pastKeyMessageData all previously-known key-message data (may be {@code null})
   * @param modelDirString directory under which model output is written
   * @param modelUpdateTopic topic to publish models to, or {@code null} to skip publishing
   * @throws IOException if model files cannot be read or written
   * @throws InterruptedException if the update is interrupted
   */
  @Override
  public void runUpdate(JavaSparkContext sparkContext,
                        long timestamp,
                        JavaPairRDD<Object,M> newKeyMessageData,
                        JavaPairRDD<Object,M> pastKeyMessageData,
                        String modelDirString,
                        TopicProducer<String,String> modelUpdateTopic)
      throws IOException, InterruptedException {

    Objects.requireNonNull(newKeyMessageData);

    // Only the message values matter for model building; keys are dropped
    JavaRDD<M> newData = newKeyMessageData.values();
    JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

    if (newData != null) {
      newData.cache();
      // This forces caching of the RDD. This shouldn't be necessary but we see some freezes
      // when many workers try to materialize the RDDs at once. Hence the workaround.
      newData.foreachPartition(p -> {});
    }
    if (pastData != null) {
      pastData.cache();
      // Same eager-materialization workaround as above
      pastData.foreachPartition(p -> {});
    }

    // Choose the hyperparameter combinations to try for the candidate models
    List<List<?>> hyperParameterCombos = HyperParams.chooseHyperParameterCombos(
        getHyperParameterValues(), hyperParamSearch, candidates);

    Path modelDir = new Path(modelDirString);
    // Candidates are built under a hidden ".temporary" subdirectory, keyed by current time,
    // so that incomplete models are never visible under the model directory itself
    Path tempModelPath = new Path(modelDir, ".temporary");
    Path candidatesPath = new Path(tempModelPath, Long.toString(System.currentTimeMillis()));

    FileSystem fs = FileSystem.get(modelDir.toUri(), sparkContext.hadoopConfiguration());
    fs.mkdirs(candidatesPath);

    Path bestCandidatePath = findBestCandidatePath(
        sparkContext, newData, pastData, hyperParameterCombos, candidatesPath);

    // Final model location is a timestamp-named directory under the model dir
    Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
    if (bestCandidatePath == null) {
      log.info("Unable to build any model");
    } else {
      // Move best model into place
      fs.rename(bestCandidatePath, finalPath);
    }
    // Then delete everything else
    fs.delete(candidatesPath, true);

    if (modelUpdateTopic == null) {
      log.info("No update topic configured, not publishing models to a topic");
    } else {
      // Push PMML model onto update topic, if it exists
      Path bestModelPath = new Path(finalPath, MODEL_FILE_NAME);
      if (fs.exists(bestModelPath)) {
        FileStatus bestModelPathFS = fs.getFileStatus(bestModelPath);
        PMML bestModel = null;
        boolean modelNeededForUpdates = canPublishAdditionalModelData();
        boolean modelNotTooLarge = bestModelPathFS.getLen() <= maxMessageSize;
        if (modelNeededForUpdates || modelNotTooLarge) {
          // Either the model is required for publishAdditionalModelData, or required because it's going to
          // be serialized to Kafka
          try (InputStream in = fs.open(bestModelPath)) {
            bestModel = PMMLUtils.read(in);
          }
        }

        if (modelNotTooLarge) {
          // Model fits in a message: publish the PMML inline
          modelUpdateTopic.send("MODEL", PMMLUtils.toString(bestModel));
        } else {
          // Model too large for a message: publish a reference to its location instead
          modelUpdateTopic.send("MODEL-REF", fs.makeQualified(bestModelPath).toString());
        }

        if (modelNeededForUpdates) {
          publishAdditionalModelData(
              sparkContext, bestModel, newData, pastData, finalPath, modelUpdateTopic);
        }
      }
    }

    // Release the cached RDDs now that the update is complete
    if (newData != null) {
      newData.unpersist();
    }
    if (pastData != null) {
      pastData.unpersist();
    }
  }

  /**
   * Builds and evaluates all candidate models, up to {@code evalParallelism} at a time,
   * and chooses the one with the highest evaluation score.
   *
   * @return path of the best candidate's directory, or {@code null} if no usable model was
   *  built or the best evaluation did not meet the configured threshold
   */
  private Path findBestCandidatePath(JavaSparkContext sparkContext,
                                     JavaRDD<M> newData,
                                     JavaRDD<M> pastData,
                                     List<List<?>> hyperParameterCombos,
                                     Path candidatesPath) throws IOException {
    // Build/evaluate candidates in parallel; result maps each candidate path to its eval score
    Map<Path,Double> pathToEval = ExecUtils.collectInParallel(
        candidates,
        Math.min(evalParallelism, candidates),
        true,
        i -> buildAndEval(i, hyperParameterCombos, sparkContext, newData, pastData, candidatesPath),
        Collectors.toMap(Pair::getFirst, Pair::getSecond));
    FileSystem fs = null;
    Path bestCandidatePath = null;
    double bestEval = Double.NEGATIVE_INFINITY;
    for (Map.Entry<Path,Double> pathEval : pathToEval.entrySet()) {
      Path path = pathEval.getKey();
      if (path != null) {
        if (fs == null) {
          // Lazily resolve the FileSystem from the first candidate path seen
          fs = FileSystem.get(path.toUri(), sparkContext.hadoopConfiguration());
        }
        if (fs.exists(path)) {
          Double eval = pathEval.getValue();
          if (!Double.isNaN(eval)) {
            // Valid evaluation; if it's the best so far, keep it
            if (eval > bestEval) {
              log.info("Best eval / model path is now {} / {}", eval, path);
              bestEval = eval;
              bestCandidatePath = path;
            }
          } else if (bestCandidatePath == null && testFraction == 0.0) {
            // Normal case when eval is disabled; no eval is possible, but keep the one model
            // that was built
            bestCandidatePath = path;
          }
        } // else can't do anything; no model at all
      }
    }
    if (threshold != null && bestEval < threshold) {
      log.info("Best model at {} had eval {}, but did not exceed threshold {}; discarding model",
               bestCandidatePath, bestEval, threshold);
      bestCandidatePath = null;
    }
    return bestCandidatePath;
  }

  /**
   * Builds candidate model {@code i}, writes its PMML to the candidate directory, and
   * evaluates it on the held-out test split.
   *
   * @return pair of (candidate directory path, eval score); score is {@code NaN} when no
   *  model could be built or no test data was available to evaluate it
   */
  private Pair<Path,Double> buildAndEval(int i,
                                         List<? extends List<?>> hyperParameterCombos,
                                         JavaSparkContext sparkContext,
                                         JavaRDD<M> newData,
                                         JavaRDD<M> pastData,
                                         Path candidatesPath) {
    // % = cycle through combinations if needed
    List<?> hyperParameters = hyperParameterCombos.get(i % hyperParameterCombos.size());
    Path candidatePath = new Path(candidatesPath, Integer.toString(i));
    log.info("Building candidate {} with params {}", i, hyperParameters);

    Pair<JavaRDD<M>,JavaRDD<M>> trainTestData = splitTrainTest(newData, pastData);
    JavaRDD<M> allTrainData = trainTestData.getFirst();
    JavaRDD<M> testData = trainTestData.getSecond();

    // NaN marks "no evaluation"; findBestCandidatePath treats it specially
    double eval = Double.NaN;
    if (empty(allTrainData)) {
      log.info("No train data to build a model");
    } else {
      PMML model = buildModel(sparkContext, allTrainData, hyperParameters, candidatePath);
      if (model == null) {
        log.info("Unable to build a model");
      } else {
        Path modelPath = new Path(candidatePath, MODEL_FILE_NAME);
        log.info("Writing model to {}", modelPath);
        try {
          FileSystem fs = FileSystem.get(candidatePath.toUri(), sparkContext.hadoopConfiguration());
          fs.mkdirs(candidatePath);
          try (OutputStream out = fs.create(modelPath)) {
            PMMLUtils.write(model, out);
          }
        } catch (IOException ioe) {
          // Surface I/O failure as unchecked; this runs inside a parallel collector
          throw new IllegalStateException(ioe);
        }
        if (empty(testData)) {
          log.info("No test data available to evaluate model");
        } else {
          log.info("Evaluating model");
          eval = evaluate(sparkContext, model, candidatePath, testData, allTrainData);
        }
      }
    }

    log.info("Model eval for params {}: {} ({})", hyperParameters, eval, candidatePath);
    return new Pair<>(candidatePath, eval);
  }

  /**
   * Splits available data into train and test sets according to {@code testFraction}.
   * Past data, when present, is always included in the training set, never the test set.
   *
   * @param newData data from the current interval; must be non-null
   * @param pastData previously-known data (may be {@code null})
   * @return pair of (train data, test data); test data is {@code null} when eval is disabled
   *  or there is no new data to split
   */
  private Pair<JavaRDD<M>,JavaRDD<M>> splitTrainTest(JavaRDD<M> newData, JavaRDD<M> pastData) {
    Objects.requireNonNull(newData);
    if (testFraction <= 0.0) {
      // Eval disabled: everything is training data
      return new Pair<>(pastData == null ? newData : newData.union(pastData), null);
    }
    if (testFraction >= 1.0) {
      // All new data is test data; only past data (possibly null) can train
      return new Pair<>(pastData, newData);
    }
    if (empty(newData)) {
      // Nothing new to split; train on past data only
      return new Pair<>(pastData, null);
    }
    Pair<JavaRDD<M>,JavaRDD<M>> newTrainTest = splitNewDataToTrainTest(newData);
    JavaRDD<M> newTrainData = newTrainTest.getFirst();
    return new Pair<>(pastData == null ? newTrainData : newTrainData.union(pastData),
                      newTrainTest.getSecond());
  }

  /** @return {@code true} iff the RDD is {@code null} or has no elements */
  private static boolean empty(JavaRDD<?> rdd) {
    return rdd == null || rdd.isEmpty();
  }

  /**
   * Default implementation which randomly splits new data into train/test sets.
   * This handles the case where {@link #getTestFraction()} is not 0 or 1.
   *
   * @param newData data that has arrived in the current input batch
   * @return a {@link Pair} of train, test {@link RDD}s.
   */
  protected Pair<JavaRDD<M>,JavaRDD<M>> splitNewDataToTrainTest(JavaRDD<M> newData) {
    // randomSplit weights: (1 - testFraction) to train, testFraction to test
    RDD<M>[] testTrainRDDs = newData.rdd().randomSplit(
        new double[]{1.0 - testFraction, testFraction},
        RandomManager.getRandom().nextLong());
    return new Pair<>(newData.wrapRDD(testTrainRDDs[0]),
                      newData.wrapRDD(testTrainRDDs[1]));
  }

}