-
Notifications
You must be signed in to change notification settings - Fork 32
/
PredictNewsClassDemo.scala
63 lines (52 loc) · 1.96 KB
/
PredictNewsClassDemo.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package applications.mining
import algorithms.evaluation.MultiClassEvaluation
import config.paramconf.ClassParams
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.{Row, SparkSession}
/**
 * News multi-class classification model test.
 *
 * Loads a previously trained pipeline model, applies it to a labelled news file and
 * prints the precision, recall and F1 returned by `MultiClassEvaluation.multiClassEvaluate`.
 *
 * Usage: `PredictNewsClassDemo [filePath] [modelType]`
 *   - filePath:  input text file; each line holds label, title, time and content separated
 *                by the character U+00EF. Lines with fewer than four fields are skipped.
 *   - modelType: "lr" or "dt"; selects which model path from [[ClassParams]] is loaded.
 * Missing arguments fall back to `DefaultFilePath` / `DefaultModelType`.
 *
 * Created by yhao on 2017/3/15.
 */
object PredictNewsClassDemo extends Serializable {

  /** Input file used when no path is passed on the command line. */
  private val DefaultFilePath = "ckooc-ml/data/classnews/predict"

  /** Model type used when none is passed on the command line. */
  private val DefaultModelType = "lr"

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.WARN)

    // Honour command-line arguments. Previously a local `args` shadowed the parameter,
    // so anything passed on the command line was silently ignored.
    val filePath = if (args.length > 0) args(0) else DefaultFilePath
    val modelType = if (args.length > 1) args(1) else DefaultModelType

    val spark = SparkSession
      .builder
      .master("local[2]")
      .appName("predict news multi class demo")
      .getOrCreate()

    val params = new ClassParams
    val modelPath = modelType match {
      case "lr" => params.LRModelPath
      case "dt" => params.DTModelPath
      case _ =>
        println("模型类型错误!")
        // Stop the session explicitly: exiting the JVM bypasses the finally blocks below.
        spark.stop()
        sys.exit(1)
    }

    import spark.implicits._

    try {
      // Each line: label, title, time, content separated by U+00EF;
      // lines with fewer than four fields are dropped.
      val data = spark.sparkContext.textFile(filePath).flatMap { line =>
        val tokens: Array[String] = line.split("\u00ef")
        if (tokens.length > 3) Some((tokens(0), tokens(1), tokens(2), tokens(3))) else None
      }.toDF("label", "title", "time", "content")
      data.persist()

      try {
        // Load the trained pipeline and transform the data.
        val model = PipelineModel.load(modelPath)
        val predictions = model.transform(data)

        // === Model evaluation
        // NOTE(review): assumes the loaded pipeline emits an "indexedLabel" column of
        // Doubles alongside "prediction" — confirm against the training code.
        val resultRDD = predictions.select("prediction", "indexedLabel").rdd.map {
          case Row(prediction: Double, label: Double) => (prediction, label)
        }
        val (precision, recall, f1) = MultiClassEvaluation.multiClassEvaluate(resultRDD)

        println("\n\n========= 评估结果 ==========")
        println(s"\n加权准确率:$precision")
        println(s"加权召回率:$recall")
        println(s"F1值:$f1")
        // predictions.select("label", "predictedLabel", "content").show(100, truncate = false)
      } finally {
        // Release the cached input even if loading/evaluation fails.
        data.unpersist()
      }
    } finally {
      spark.stop()
    }
  }
}