Skip to content

Commit

Permalink
Let user specify feature importance type for XGBoost (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
TuanNguyen27 committed Jul 15, 2020
1 parent 9857138 commit 7baf0f8
Showing 1 changed file with 38 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@

package ml.dmlc.xgboost4j.scala.spark

import enumeratum.{Enum, EnumEntry}
import ml.dmlc.xgboost4j.LabeledPoint
import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}

import scala.collection.mutable.ArrayBuffer

/**
* Hack to access [[XGBoostClassifierParams]]
*/
Expand Down Expand Up @@ -86,11 +85,14 @@ case object OpXGBoost {
/**
* Converts feature score map into a vector
*
* @param featureVectorSize size of feature vectors the xgboost model is trained on
* @param importanceType type of feature importance to calculate [Gain, Cover, TotalGain, TotalCover]
* @return vector containing feature scores
*/
def getFeatureScoreVector(featureVectorSize: Option[Int] = None): Vector = {
val featureScore = booster.getFeatureScore()
def getFeatureScoreVector(
featureVectorSize: Option[Int] = None, importanceType: ImportanceType = ImportanceType.Gain
): Vector = {
val featureScore = booster.getScore(featureMap = null, importanceType = importanceType.name)
require(featureScore.nonEmpty, "Feature score map is empty")
val indexScore = featureScore.map { case (fid, score) =>
val index = fid.tail.toInt
Expand All @@ -109,3 +111,34 @@ case object OpXGBoost {
def processMissingValues(xgbLabelPoints: Iterator[LabeledPoint], missing: Float): Iterator[LabeledPoint] = {
  // Pure pass-through: both arguments are handed to XGBoost.processMissingValues
  // unchanged and its result is returned as-is.
  XGBoost.processMissingValues(xgbLabelPoints, missing)
}
}

/**
 * Feature importance flavours that can be requested for an XGBoost model.
 *
 * @param name string identifier of the importance type; `getFeatureScoreVector`
 *             passes it as the `importanceType` argument of the booster's `getScore`
 */
sealed abstract class ImportanceType(val name: String) extends EnumEntry with Serializable

/**
 * Enumeration of every supported [[ImportanceType]].
 */
object ImportanceType extends Enum[ImportanceType] {
  val values: Seq[ImportanceType] = findValues

  /** Average gain over all splits in which the feature is used. */
  case object Gain extends ImportanceType("gain")

  /** Average coverage over all splits in which the feature is used. */
  case object Cover extends ImportanceType("cover")

  /** Total (summed) gain over all splits in which the feature is used. */
  case object TotalGain extends ImportanceType("total_gain")

  /** Total (summed) coverage over all splits in which the feature is used. */
  case object TotalCover extends ImportanceType("total_cover")

}

0 comments on commit 7baf0f8

Please sign in to comment.