In [1]:
import time
import xgboost as xgb
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import jax.numpy as jnp
from jax import grad, jacfwd
import numpy as np


# Custom binary logistic loss using JAX
def custom_logistic_loss(preds, labels):
    preds = 1 / (1 + jnp.exp(-preds))  # sigmoid
    eps = 1e-7
    return -jnp.sum(labels * jnp.log(preds + eps) + (1 - labels) * jnp.log(1 - preds + eps))


def custom_obj_jax(preds, dtrain):
    labels = dtrain.get_label()
    labels = jnp.array(labels)

    def loss_fn(preds_):
        return custom_logistic_loss(preds_, labels)

    # Get gradient and hessian using JAX
    grad_fn = grad(loss_fn)
    hess_fn = jacfwd(grad_fn)

    preds_jax = jnp.array(preds)
    grad_val = grad_fn(preds_jax)
    hess_val = hess_fn(preds_jax)

    return np.array(grad_val), np.array(hess_val)


# Generate synthetic data
X, y = make_classification(n_samples=1_000_000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# Base parameters (we override 'objective' to use our custom one)
params = {
    'tree_method': 'hist',
    'verbosity': 1,
    'eta': 0.3,
    'device': 'cuda',
}


def train_model(params, label):
    print(f"Training with {label}...")
    start = time.time()
    bst = xgb.train(params, dtrain, num_boost_round=100, obj=custom_obj_jax)
    end = time.time()

    preds = bst.predict(dtest)
    preds_binary = [1 if p > 0.5 else 0 for p in preds]
    acc = accuracy_score(y_test, preds_binary)

    print(f"{label} Accuracy: {acc:.4f}")
    print(f"{label} Time: {end - start:.2f} seconds")


# Run with custom JAX objective
train_model(params, "CPU w/ JAX custom objective")


Training with CPU w/ JAX custom objective...


2025-06-03 18:27:08.936208: E external/xla/xla/service/gpu/gpu_hlo_schedule.cc:654] The byte size of input/output arguments (2560000000000) exceeds the base limit (31805177856). This indicates an error in the calculation!
2025-06-03 18:27:08.936348: W external/xla/xla/hlo/transforms/simplifiers/hlo_rematerialization.cc:3022] Can't reduce memory use below 0B (0 bytes) by rematerialization; only reduced to 2.33TiB (2560000000000 bytes), down from 2.33TiB (2560000000000 bytes) originally
2025-06-03 18:27:18.961465: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.33TiB (rounded to 2560000000000)requested by op 
2025-06-03 18:27:18.961554: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] *___________________________________________________________________________________________________
E0603 18:27:18.961568   30137 pjrt_stream_executor_client.cc:2917] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory w

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2560000000000 bytes.