forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
/
RegressionEvaluator.scala
141 lines (119 loc) · 4.84 KB
/
RegressionEvaluator.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.Since
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType}
/**
 * Evaluator for regression, which expects input columns prediction, label and
 * an optional weight column.
 *
 * The metric to compute is chosen with [[metricName]]; the computation itself is delegated to
 * `RegressionMetrics`.
 */
@Since("1.4.0")
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
  extends Evaluator with HasPredictionCol with HasLabelCol
    with HasWeightCol with DefaultParamsWritable {

  /** Creates an evaluator whose uid is obtained from `Identifiable.randomUID("regEval")`. */
  @Since("1.4.0")
  def this() = this(Identifiable.randomUID("regEval"))

  /**
   * Param for metric name in evaluation. Supports:
   *  - `"rmse"` (default): root mean squared error
   *  - `"mse"`: mean squared error
   *  - `"r2"`: R^2^ metric
   *  - `"mae"`: mean absolute error
   *  - `"var"`: explained variance
   *
   * @group param
   */
  @Since("1.4.0")
  val metricName: Param[String] = {
    val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae", "var"))
    new Param(this, "metricName", "metric name in evaluation (mse|rmse|r2|mae|var)", allowedParams)
  }

  /** @group getParam */
  @Since("1.4.0")
  def getMetricName: String = $(metricName)

  /** @group setParam */
  @Since("1.4.0")
  def setMetricName(value: String): this.type = set(metricName, value)

  /**
   * Param for whether the regression is through the origin.
   * Default: false.
   *
   * @group expertParam
   */
  @Since("3.0.0")
  val throughOrigin: BooleanParam = new BooleanParam(this, "throughOrigin",
    "Whether the regression is through the origin.")

  /** @group expertGetParam */
  @Since("3.0.0")
  def getThroughOrigin: Boolean = $(throughOrigin)

  /** @group expertSetParam */
  @Since("3.0.0")
  def setThroughOrigin(value: Boolean): this.type = set(throughOrigin, value)

  setDefault(throughOrigin -> false)

  /** @group setParam */
  @Since("1.4.0")
  def setPredictionCol(value: String): this.type = set(predictionCol, value)

  /** @group setParam */
  @Since("1.4.0")
  def setLabelCol(value: String): this.type = set(labelCol, value)

  /** @group setParam */
  @Since("3.0.0")
  def setWeightCol(value: String): this.type = set(weightCol, value)

  setDefault(metricName -> "rmse")

  /**
   * Computes the metric selected by [[metricName]] over `dataset`.
   *
   * The prediction column must be of double or float type and the label column must be numeric.
   * When a non-empty weight column is configured it must be numeric as well; otherwise every
   * row is given a weight of 1.0. All three columns are cast to double before the rows are
   * handed to `RegressionMetrics`.
   *
   * @param dataset a dataset containing the prediction and label columns and, if configured,
   *                the weight column
   * @return the value of the selected metric
   */
  @Since("2.0.0")
  override def evaluate(dataset: Dataset[_]): Double = {
    val schema = dataset.schema
    SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
    SchemaUtils.checkNumericType(schema, $(labelCol))

    val hasWeights = isDefined(weightCol) && $(weightCol).nonEmpty
    if (hasWeights) {
      SchemaUtils.checkNumericType(schema, $(weightCol))
    }
    // The Row pattern below only matches Double values, so the weight column has to be cast
    // just like prediction and label; an Int/Long/Float weight column would otherwise fail
    // with a scala.MatchError at runtime.
    val weight = if (hasWeights) col($(weightCol)).cast(DoubleType) else lit(1.0)

    val predictionAndLabelsWithWeights = dataset
      .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType), weight)
      .rdd
      .map { case Row(prediction: Double, label: Double, weight: Double) =>
        (prediction, label, weight)
      }

    val metrics = new RegressionMetrics(predictionAndLabelsWithWeights, $(throughOrigin))
    // The validator on metricName restricts the value to exactly these five names.
    $(metricName) match {
      case "rmse" => metrics.rootMeanSquaredError
      case "mse" => metrics.meanSquaredError
      case "r2" => metrics.r2
      case "mae" => metrics.meanAbsoluteError
      case "var" => metrics.explainedVariance
    }
  }

  /**
   * Whether a larger metric value is better: true for `"r2"` and `"var"`, false for the
   * remaining (error) metrics.
   */
  @Since("1.4.0")
  override def isLargerBetter: Boolean = $(metricName) match {
    case "r2" | "var" => true
    case _ => false
  }

  /** Returns a copy of this evaluator produced by `defaultCopy` with `extra` applied. */
  @Since("1.5.0")
  override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)

  /** Short description containing the uid and the current metricName and throughOrigin. */
  @Since("3.0.0")
  override def toString: String = {
    s"RegressionEvaluator: uid=$uid, metricName=${$(metricName)}, " +
      s"throughOrigin=${$(throughOrigin)}"
  }
}
/**
 * Companion object of the evaluator; mixes in `DefaultParamsReadable` for
 * `RegressionEvaluator`.
 */
@Since("1.6.0")
object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {

  /**
   * Loads a `RegressionEvaluator` from `path` by delegating to the inherited `load`.
   *
   * @param path location to load the evaluator from
   * @return the loaded evaluator
   */
  @Since("1.6.0")
  override def load(path: String): RegressionEvaluator = {
    super.load(path)
  }
}