Skip to content

[jvm-packages] Fix parameters #11489

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -31,15 +31,16 @@ public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
DMatrix ref,
int nthread,
int maxQuantileBatches,
int minCachePageBytes) throws XGBoostError {
long minCachePageBytes,
float cacheHostRatio) throws XGBoostError {
long[] out = new long[1];
long[] refHandle = null;
if (ref != null) {
refHandle = new long[1];
refHandle[0] = ref.getHandle();
}
String conf = this.getConfig(missing, maxBin, nthread,
maxQuantileBatches, minCachePageBytes);
maxQuantileBatches, minCachePageBytes, cacheHostRatio);
XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
iter, refHandle, conf, out));
handle = out[0];
@@ -50,7 +51,7 @@ public ExtMemQuantileDMatrix(
float missing,
int maxBin,
DMatrix ref) throws XGBoostError {
this(iter, missing, maxBin, ref, 0, -1, -1);
this(iter, missing, maxBin, ref, 0, -1, -1, Float.NaN);
}

public ExtMemQuantileDMatrix(
@@ -61,19 +62,23 @@ public ExtMemQuantileDMatrix(
}

private String getConfig(float missing, int maxBin, int nthread,
int maxQuantileBatches, int minCachePageBytes) {
int maxQuantileBatches, long minCachePageBytes, float cacheHostRatio) {
Map<String, Object> conf = new java.util.HashMap<>();
conf.put("missing", missing);
conf.put("max_bin", maxBin);
conf.put("nthread", nthread);

if (maxQuantileBatches > 0) {
conf.put("max_quantile_batches", maxQuantileBatches);
conf.put("max_quantile_blocks", maxQuantileBatches);
}
if (minCachePageBytes > 0) {
conf.put("min_cache_page_bytes", minCachePageBytes);
}

if (cacheHostRatio >= 0.0 && cacheHostRatio <= 1.0) {
conf.put("cache_host_ratio", cacheHostRatio);
}

conf.put("on_host", true);
conf.put("cache_prefix", ".");
ObjectMapper mapper = new ObjectMapper();
Original file line number Diff line number Diff line change
@@ -28,10 +28,11 @@ class ExtMemQuantileDMatrix private[scala](
ref: Option[QuantileDMatrix],
nthread: Int,
maxQuantileBatches: Int,
minCachePageBytes: Int) {
minCachePageBytes: Long,
cacheHostRatio: Float) {
this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin,
ref.map(_.jDMatrix).orNull,
nthread, maxQuantileBatches, minCachePageBytes))
nthread, maxQuantileBatches, minCachePageBytes, cacheHostRatio))
}

def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int) {
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ package ml.dmlc.xgboost4j.scala.spark

import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.Try

import ai.rapids.cudf.Table
import com.nvidia.spark.rapids.{ColumnarRdd, GpuColumnVectorUtils}
@@ -134,6 +135,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {

val maxQuantileBatches = estimator.getMaxQuantileBatches
val minCachePageBytes = estimator.getMinCachePageBytes
val cacheHostRatio = Try(estimator.getCacheHostRatio).getOrElse(Float.NaN)

/** build QuantileDMatrix on the executor side */
def buildQuantileDMatrix(input: Iterator[Table],
@@ -143,7 +145,7 @@ class GpuXGBoostPlugin extends XGBoostPlugin {
case Some(_) =>
val itr = new ExternalMemoryIterator(input, indices, extMemPath)
new ExtMemQuantileDMatrix(itr, missing, maxBin, ref, nthread,
maxQuantileBatches, minCachePageBytes)
maxQuantileBatches, minCachePageBytes, cacheHostRatio)

case None =>
val itr = input.map { table =>
@@ -188,7 +190,6 @@ class GpuXGBoostPlugin extends XGBoostPlugin {

val sconf = dataset.sparkSession.conf
val rmmEnabled: Boolean = try {
sconf.get("spark.rapids.memory.gpu.pooling.enabled").toBoolean &&
sconf.get("spark.rapids.memory.gpu.pool").trim.toLowerCase != "none"
} catch {
case _: Throwable => false // Any exception will return false
Original file line number Diff line number Diff line change
@@ -193,10 +193,18 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

final def getMaxQuantileBatches: Int = $(maxQuantileBatches)

final val minCachePageBytes = new IntParam(this, "minCachePageBytes", "Minimum number of " +
final val minCachePageBytes = new LongParam(this, "minCachePageBytes", "Minimum number of " +
"bytes for each ellpack page in cache. Only used for in-host")

final def getMinCachePageBytes: Int = $(minCachePageBytes)
final def getMinCachePageBytes: Long = $(minCachePageBytes)

final val cacheHostRatio = new FloatParam(this, "cacheHostRatio",
"Used by the GPU implementation. For GPU-based inputs, XGBoost can split the cache into " +
"host and device caches to reduce the data transfer overhead. This parameter specifies " +
"the size of host cache compared to the size of the entire cache: host / (host + device)",
ParamValidators.inRange(0.0, 1.0))

final def getCacheHostRatio: Float = $(cacheHostRatio)

setDefault(numRound -> 100, numWorkers -> 1, inferBatchSize -> (32 << 10),
numEarlyStoppingRounds -> 0, forceRepartition -> false, missing -> Float.NaN,
@@ -248,7 +256,10 @@ private[spark] trait SparkParams[T <: Params] extends HasFeaturesCols with HasFe

def setMaxQuantileBatches(value: Int): T = set(maxQuantileBatches, value).asInstanceOf[T]

def setMinCachePageBytes(value: Int): T = set(minCachePageBytes, value).asInstanceOf[T]
def setMinCachePageBytes(value: Long): T = set(minCachePageBytes, value).asInstanceOf[T]

def setCacheHostRatio(value: Float): T = set(cacheHostRatio, value)
.asInstanceOf[T]

protected[spark] def featureIsArrayType(schema: StructType): Boolean =
schema(getFeaturesCol).dataType.isInstanceOf[ArrayType]
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
Copyright (c) 2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package ml.dmlc.xgboost4j.scala.spark

import scala.util.Try

import org.scalatest.funsuite.AnyFunSuite


class XGBoostParamsSuite extends AnyFunSuite with PerTest with TmpFolderPerSuite {

  test("invalid parameters") {
    val estimator = new XGBoostClassifier()

    // cacheHostRatio deliberately has no default, so reading it before it is
    // set must throw (Spark ML reports a missing default as a RuntimeException).
    var thrown = intercept[RuntimeException] {
      estimator.getCacheHostRatio
    }
    assert(thrown.getMessage.contains("Failed to find a default value for cacheHostRatio"))

    // Callers (e.g. GpuXGBoostPlugin) fall back to NaN when the param is unset.
    // NOTE: boxed Float.equals treats NaN as equal to itself, unlike `==`,
    // which is false for NaN on both sides — hence .equals here.
    val unset = Try(estimator.getCacheHostRatio).getOrElse(Float.NaN)
    assert(unset.equals(Float.NaN))

    // Values outside [0, 1] must be rejected by the param validator.
    thrown = intercept[RuntimeException] {
      estimator.setCacheHostRatio(-1.0f)
    }
    assert(thrown.getMessage.contains("parameter cacheHostRatio given invalid value -1.0"))

    // Boundary and interior values must round-trip through set/get.
    // FIX: the original discarded the Boolean returned by forall, so these
    // cases were silently unchecked; wrap the whole expression in assert.
    assert(Seq(0.0f, 0.2f, 1.0f).forall { ratio =>
      estimator.setCacheHostRatio(ratio)
      estimator.getCacheHostRatio == ratio
    })

    // Once set, the getter returns the value directly (Try no longer needed,
    // but kept to mirror the production fallback path being exercised above).
    estimator.setCacheHostRatio(0.66f)
    val v1 = Try(estimator.getCacheHostRatio).getOrElse(Float.NaN)
    assert(v1 == 0.66f)
  }

}
Loading
Oops, something went wrong.