From 7baf0f8e5f9a43d0a3e6c214da983e32ae4b1343 Mon Sep 17 00:00:00 2001 From: Tuan Nguyen Date: Wed, 15 Jul 2020 10:38:54 -0700 Subject: [PATCH] Let user specify feature importance type for XGBoost (#490) --- .../xgboost4j/scala/spark/XGBoostParams.scala | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostParams.scala b/core/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostParams.scala index 3c8959997d..18643f82f8 100644 --- a/core/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostParams.scala +++ b/core/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostParams.scala @@ -30,14 +30,13 @@ package ml.dmlc.xgboost4j.scala.spark +import enumeratum.{Enum, EnumEntry} import ml.dmlc.xgboost4j.LabeledPoint import ml.dmlc.xgboost4j.scala.Booster import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams import org.apache.log4j.{Level, Logger} import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} -import scala.collection.mutable.ArrayBuffer - /** * Hack to access [[XGBoostClassifierParams]] */ @@ -86,11 +85,14 @@ case object OpXGBoost { /** * Converts feature score map into a vector * - * @param featureVectorSize size of feature vectors the xgboost model is trained on + * @param featureVectorSize size of feature vectors the xgboost model is trained on + * @param importanceType type of feature importance to calculate [Gain, Cover, TotalGain, TotalCover] * @return vector containing feature scores */ - def getFeatureScoreVector(featureVectorSize: Option[Int] = None): Vector = { - val featureScore = booster.getFeatureScore() + def getFeatureScoreVector( + featureVectorSize: Option[Int] = None, importanceType: ImportanceType = ImportanceType.Gain + ): Vector = { + val featureScore = booster.getScore(featureMap = null, importanceType = importanceType.name) require(featureScore.nonEmpty, "Feature score map is empty") val indexScore = featureScore.map { case (fid, score) => val index = fid.tail.toInt @@ -109,3 +111,34 @@ case object OpXGBoost { def processMissingValues(xgbLabelPoints: Iterator[LabeledPoint], missing: Float): Iterator[LabeledPoint] = XGBoost.processMissingValues(xgbLabelPoints, missing) } + +/** + * Settings for XGBoost feature importance type + */ +sealed abstract class ImportanceType(val name: String) extends EnumEntry with Serializable + +object ImportanceType extends Enum[ImportanceType] { + val values: Seq[ImportanceType] = findValues + + /** + * The average gain across all splits the feature is used in. + */ + case object Gain extends ImportanceType(name = "gain") + + /** + * The average coverage across all splits the feature is used in. + */ + case object Cover extends ImportanceType(name = "cover") + + /** + * The total gain across all splits the feature is used in. + */ + case object TotalGain extends ImportanceType(name = "total_gain") + + /** + * The total coverage across all splits the feature is used in. + */ + case object TotalCover extends ImportanceType(name = "total_cover") + +} +