Skip to content

Commit ad0130a

Browse files
authored
[jvm-packages] Expose the global configuration. (dmlc#11238)
1 parent 138e146 commit ad0130a

File tree

5 files changed

+163
-4
lines changed

5 files changed

+163
-4
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
Copyright (c) 2025 by Contributors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
package ml.dmlc.xgboost4j.java;
17+
18+
import java.util.HashMap;
19+
import java.util.Map;
20+
21+
import com.fasterxml.jackson.core.JsonProcessingException;
22+
import com.fasterxml.jackson.core.type.TypeReference;
23+
import com.fasterxml.jackson.databind.ObjectMapper;
24+
25+
/**
26+
* Global configuration context for XGBoost.
27+
*
28+
* @version 3.0.0
29+
*
30+
* See the parameter document for supported global configuration. The configuration is
31+
* restored upon close.
32+
*/
33+
public class ConfigContext implements AutoCloseable {
34+
String orig;
35+
36+
ConfigContext() throws XGBoostError {
37+
this.orig = getImpl();
38+
}
39+
40+
static String getImpl() throws XGBoostError {
41+
String[] config = new String[1];
42+
XGBoostJNI.checkCall(XGBoostJNI.XGBGetGlobalConfig(config));
43+
return config[0];
44+
}
45+
46+
public static Map<String, Object> get() throws XGBoostError {
47+
String jconfig = getImpl();
48+
ObjectMapper mapper = new ObjectMapper();
49+
try {
50+
Map<String, Object> config = mapper.readValue(jconfig,
51+
new TypeReference<Map<String, Object>>() {
52+
});
53+
return config;
54+
} catch (JsonProcessingException ex) {
55+
throw new XGBoostError("Failed to get the global config due to a decode error.", ex);
56+
}
57+
}
58+
59+
public <T> ConfigContext set(String key, T value) throws XGBoostError {
60+
HashMap<String, Object> map = new HashMap<String, Object>();
61+
map.put(key, value);
62+
ObjectMapper mapper = new ObjectMapper();
63+
try {
64+
String config = mapper.writeValueAsString(map);
65+
XGBoostJNI.checkCall(XGBoostJNI.XGBSetGlobalConfig(config));
66+
} catch (JsonProcessingException ex) {
67+
throw new XGBoostError("Failed to set the global config due to an encode error.", ex);
68+
}
69+
return this;
70+
}
71+
72+
@Override
73+
public void close() throws XGBoostError {
74+
XGBoostJNI.checkCall(XGBoostJNI.XGBSetGlobalConfig(this.orig));
75+
}
76+
};

jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoostJNI.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,4 +183,7 @@ public final static native int XGDMatrixCreateFromArrayInterfaceColumns(
183183

184184
public final static native int XGDMatrixGetQuantileCut(long handle, long[][] outIndptr, float[][] outValues);
185185

186+
public final static native int XGBSetGlobalConfig(String config);
187+
188+
public final static native int XGBGetGlobalConfig(String[] out);
186189
}

jvm-packages/xgboost4j/src/native/xgboost4j.cpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,10 +1206,6 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWaitFor(JNI
12061206
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_TrackerWorkerArgs(
12071207
JNIEnv *jenv, jclass, jlong jhandle, jlong timeout, jobjectArray jout) {
12081208
using namespace xgboost; // NOLINT
1209-
1210-
Json config{Object{}};
1211-
config["timeout"] = Integer{static_cast<Integer::Int>(timeout)};
1212-
std::string sconfig = Json::Dump(config);
12131209
auto handle = reinterpret_cast<TrackerHandle>(jhandle);
12141210
char const *args;
12151211
JVM_CHECK_CALL(XGTrackerWorkerArgs(handle, &args));
@@ -1534,3 +1530,37 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGDMatrixGetQuanti
15341530

15351531
return ret;
15361532
}
1533+
1534+
/*
1535+
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
1536+
* Method: XGBSetGlobalConfig
1537+
* Signature: (Ljava/lang/String;)I
1538+
*/
1539+
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBSetGlobalConfig(JNIEnv *jenv,
1540+
jclass,
1541+
jstring config) {
1542+
std::unique_ptr<char const, Deleter<char const>> args{
1543+
jenv->GetStringUTFChars(config, nullptr), [&](char const *ptr) {
1544+
if (ptr) {
1545+
jenv->ReleaseStringUTFChars(config, ptr);
1546+
}
1547+
}};
1548+
auto ret = XGBSetGlobalConfig(args.get());
1549+
JVM_CHECK_CALL(ret);
1550+
return ret;
1551+
}
1552+
1553+
/*
1554+
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
1555+
* Method: XGBGetGlobalConfig
1556+
* Signature: ([Ljava/lang/String;)I
1557+
*/
1558+
JNIEXPORT jint JNICALL
1559+
Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBGetGlobalConfig(JNIEnv *jenv, jclass, jobjectArray jout) {
1560+
char const *args;
1561+
auto ret = XGBGetGlobalConfig(&args);
1562+
JVM_CHECK_CALL(ret);
1563+
jstring jret = jenv->NewStringUTF(args);
1564+
jenv->SetObjectArrayElement(jout, 0, jret);
1565+
return 0;
1566+
}

jvm-packages/xgboost4j/src/native/xgboost4j.h

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/*
2+
Copyright (c) 2025 by Contributors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
package ml.dmlc.xgboost4j.java;
17+
18+
import junit.framework.TestCase;
19+
import org.junit.Test;
20+
21+
/**
22+
* Test cases for the config context.
23+
*/
24+
public class ConfigContextTest {
25+
@Test
26+
public void testBasic() throws XGBoostError {
27+
try (ConfigContext ctx = new ConfigContext().set("verbosity", 3)) {
28+
int v = (int) ConfigContext.get().get("verbosity");
29+
TestCase.assertEquals(3, v);
30+
}
31+
int v = (int) ConfigContext.get().get("verbosity");
32+
TestCase.assertEquals(1, v);
33+
}
34+
}

0 commit comments

Comments
 (0)