Skip to content

Trust internal scikit-learn types needed for GB/HGB models#513

Merged
adrinjalali merged 13 commits intomainfrom
copilot/ping-copilot-review-response
Apr 17, 2026
Merged

Trust internal scikit-learn types needed for GB/HGB models#513
adrinjalali merged 13 commits intomainfrom
copilot/ping-copilot-review-response

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Apr 16, 2026

  • Investigate CI failures
  • Fix black formatting in test_persist.py (collapse multi-line startswith to single line)
  • Fix pixi.lock (update sha256, pypi path, add editable flag)
  • Test locally to verify fixes (891 tests pass, black/ruff clean)

Copilot AI and others added 2 commits April 16, 2026 11:44
Address review comments from PR #508:
- Add sklearn internal types to _trusted_types.py instead of creating
  a new SklearnInternalObjectNode (review comment 2)
- Include internal types in ObjectNode's trusted defaults
- Add tests that save/load and verify correctness (review comment 3)
- Filter non-sklearn-prefixed types where appropriate (review comment 1)

Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/6dadf905-6073-4057-bf3d-faed31a559fe

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/6dadf905-6073-4057-bf3d-faed31a559fe

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
Copy link
Copy Markdown
Member

@adrinjalali adrinjalali left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot please check this review

Comment thread skops/io/tests/test_persist.py Outdated
Comment on lines +532 to +535
with warnings.catch_warnings():
warnings.filterwarnings("ignore", module="sklearn")
estimator.fit(X, y)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tell me what warnings we're catching and ignoring here

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No warnings are actually produced during fitting of these estimators — I verified this by running all 9 configs with warnings.catch_warnings(record=True) and none produced any warnings. The suppression was cargo-culted from test_can_persist_fitted which does need it for some sklearn estimators. Removed in f17710e.

Comment on lines +73 to +79
try:
from sklearn._loss._loss import CyHalfMultinomialLoss

_SKLEARN_INTERNAL_TYPES.append(CyHalfMultinomialLoss)
except ImportError:
pass

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this feels sketchy, which you also had trouble while working on it. Tell me where the issue comes from, and what would needs to happen maybe on the sklearn side for this not to be like this.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CyHalfMultinomialLoss is a Cython extension type from sklearn._loss._loss. Its __module__ reports _loss instead of the fully qualified sklearn._loss._loss — this is because the Cython .pyx source file doesn't set the module name to the full package path. All Cy* types in sklearn._loss._loss have this same issue (CyAbsoluteError, CyHalfBinomialLoss, etc.).

On the sklearn side, this could be fixed by explicitly setting __module__ = 'sklearn._loss._loss' in the Cython source (or by using a build-time configuration that embeds the full module path). This is the same issue that affects the existing ALLOWED_LOSSES — all Cy* loss types there also report _loss.Cy* rather than sklearn._loss._loss.Cy*.

Removed CyHalfMultinomialLoss from auto-trust and added a comment explaining this in f17710e. Multiclass GradientBoostingClassifier will surface it as an untrusted type that users need to explicitly trust.

Comment thread skops/io/_trusted_types.py Outdated
except ImportError:
pass

SKLEARN_INTERNAL_TYPE_NAMES = [get_type_name(t) for t in _SKLEARN_INTERNAL_TYPES]
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should have a filter on all sklearn types that they should start with sklearn. when we get their types, so that we don't end up accidentally trusting a monkey-patched type coming from a different lib.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the sklearn. prefix filter in f17710e. SKLEARN_INTERNAL_TYPE_NAMES now filters with get_type_name(t).startswith("sklearn."), matching the pattern used by SKLEARN_ESTIMATOR_TYPE_NAMES. This excludes CyHalfMultinomialLoss (which reports _loss.CyHalfMultinomialLoss) until sklearn fixes the Cython module name.

Copilot AI and others added 2 commits April 16, 2026 12:15
… drop warning suppression

- Filter SKLEARN_INTERNAL_TYPE_NAMES to only include types starting
  with 'sklearn.' to avoid trusting monkey-patched types from other
  packages
- Remove CyHalfMultinomialLoss from auto-trusted list since its
  __module__ reports '_loss' instead of 'sklearn._loss._loss' (Cython
  build issue on sklearn side)
- Add detailed comment explaining the sklearn Cython module name issue
- Remove unnecessary warnings.catch_warnings() in test — no warnings
  are produced during fitting of these estimators
- Update test to use get_untrusted_types() + loads(trusted=...) pattern

Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/ccbb4abe-8ea6-442c-a6dc-41b9489d35cf

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/ccbb4abe-8ea6-442c-a6dc-41b9489d35cf

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
@adrinjalali
Copy link
Copy Markdown
Member

@copilot scikit-learn/scikit-learn#33770 fixes the issue with the Cy* classes. But in the meantime, their __module__ is clearly not reported correctly. They shouldn't be trusted by default when the __module__ is not correct. Adjust the test to indicate that, with the right comment, and include all Cy* objects when their __module__ is correct as a trusted object.

Copilot AI and others added 2 commits April 16, 2026 14:07
…ule__ correctness

Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/a65530ea-7c30-419d-a36e-e6c6a423c5f7

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
…convention)

Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/a65530ea-7c30-419d-a36e-e6c6a423c5f7

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
Copy link
Copy Markdown
Contributor Author

Copilot AI commented Apr 16, 2026

@copilot scikit-learn/scikit-learn#33770 fixes the issue with the Cy* classes. But in the meantime, their __module__ is clearly not reported correctly. They shouldn't be trus...

Added all Cy* types from sklearn._loss._loss to _SKLEARN_INTERNAL_TYPES in _trusted_types.py. The existing startswith("sklearn.") filter automatically excludes them while their __module__ is incorrectly set to _loss, and will automatically include them once sklearn PR #33770 lands.

The test now conditionally asserts: multiclass GBC expects CyHalfMultinomialLoss as untrusted only when CyHalfMultinomialLoss.__module__ doesn't start with sklearn.. Once sklearn fixes __module__, all 9 GB/HGB variants will have zero untrusted types.

See commit 1fd89ac.

Agent-Logs-Url: https://github.com/skops-dev/skops/sessions/9bf5c6f9-7729-4d13-a37a-18a6b934d296

Co-authored-by: adrinjalali <1663864+adrinjalali@users.noreply.github.com>
Comment thread skops/io/_trusted_types.py Outdated
Comment on lines +81 to +86
# Cython loss classes from sklearn._loss._loss. These are internal Cython extension
# types used by GradientBoosting models. Currently their __module__ incorrectly reports
# '_loss' instead of 'sklearn._loss._loss' due to a bug in sklearn's Cython build
# (see https://github.com/scikit-learn/scikit-learn/pull/33770). The startswith
# ("sklearn.") filter below will automatically exclude them while __module__ is wrong,
# and automatically include them once sklearn fixes the issue.
Copy link
Copy Markdown

@cakedev0 cakedev0 Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LossNode already default-trusts several _loss.* Cython classes, not just CyHalfMultinomialLoss. In skops/io/_sklearn.py:70, ALLOWED_LOSSES includes many _loss.* names (_loss.CyAbsoluteError, _loss.CyExponentialLoss, and a tens more)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CyHalfMultinomialLoss is not in this list because it's not a CyLossFunction subclass (it has a slightly different structure).

But a good solution might still be to use the LossNode to support it. I vibe-tried that and it seems to work (with a diff of ~20 lines).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could also just add CyHalfMultinomialLoss to the trusted types, as we already default-trusts several _loss.* Cython classes... According to my LLM, it's not less safe.

@cakedev0
Copy link
Copy Markdown

This is indeed quite simpler, but it doesn't work for multi-class classification with GB/HGB because those depends on CyHalfMultinomialLoss.

@adrinjalali adrinjalali marked this pull request as ready for review April 17, 2026 08:06
Comment thread skops/io/_sklearn.py
# CyHalfMultinomialLoss is not a subclass of CyLossFunction, so it needs its
# own dispatch entry. It's already in ALLOWED_LOSSES so LossNode will trust it.
if CyHalfMultinomialLoss is not None:
GET_STATE_DISPATCH_FUNCTIONS.append((CyHalfMultinomialLoss, loss_get_state))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't work because CyHalfMultinomialLoss doesn't have the same reduce/state shape as other Cy* losses.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are passing, can you give me a test case where it fails?

Copy link
Copy Markdown

@cakedev0 cakedev0 Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this loss_get_state, it works (I tired the fix on nighly, 1.7 and 1.6, it works there too):

def loss_get_state(obj: Any, save_context: SaveContext) -> dict[str, Any]:
    reduce = obj.__reduce__()
    if type(obj) == reduce[0]:
        state = reduce_get_state(obj, save_context)
        state["__loader__"] = "LossNode"
    elif type(obj) == reduce[1][0]:
        # The output is commonly of the form:
        # >>> CyPinballLoss(1).__reduce__()
        # (<cyfunction __pyx_unpickle_CyPinballLoss at 0x7b1d00099ff0>,
        #             (<class '_loss.CyPinballLoss'>, 232784418, (1.0,)))
        #
        # CyHalfMultinomialLoss differs slightly and returns:
        # >>> CyHalfMultinomialLoss().__reduce__()
        # (<cyfunction __pyx_unpickle_CyHalfMultinomialLoss at 0x...>,
        #  (<class '_loss.CyHalfMultinomialLoss'>, 238750788, None), ())
        #
        # In that case, the constructor takes no args and the tuple state lives
        # in reduce[2].
        state = {
            "__class__": obj.__class__.__name__,
            "__module__": get_module(type(obj)),
            "__loader__": "LossNode",
        }
        state["__reduce__"] = {}
        if len(reduce) == 3:
            state["__reduce__"]["args"] = get_state((), save_context)
            state["content"] = get_state(reduce[2], save_context)
        else:
            state["__reduce__"]["args"] = get_state(reduce[1][2], save_context)
            state["content"] = get_state({}, save_context)

    return state

Copy link
Copy Markdown

@cakedev0 cakedev0 Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are passing, can you give me a test case where it fails?

Locally (using pixi env: ci-sklearn18), I have one single test that fails: FAILED skops/io/tests/test_persist.py::test_gradient_boosting_estimators_have_no_untrusted_types[GradientBoostingClassifier-log_loss-multiclass] - TypeError: _loss.CyHalfMultinomialLoss() argument after * must be an iterable, not NoneType

The CI has the same failure, but also other failures for other sklearn versions, which are the ones I skipped in my PR with:

@pytest.mark.skipif(
    SKLEARN_VERSION < parse_version("1.4"),
    reason=(
        "Before scikit-learn 1.4, GradientBoosting uses different internal loss "
        "objects (`sklearn.ensemble._gb_losses`), which we don't try to support "
        "as trusted types."
    ),
)

@adrinjalali
Copy link
Copy Markdown
Member

LGTM, WDYT @cakedev0

@cakedev0
Copy link
Copy Markdown

Yes LGTM too 👍 Thanks!

@adrinjalali adrinjalali merged commit e74ebfb into main Apr 17, 2026
51 checks passed
@adrinjalali adrinjalali deleted the copilot/ping-copilot-review-response branch April 17, 2026 13:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants