Skip to content

Commit

Permalink
Support sort_dicts and underscore_numbers args
Browse files Browse the repository at this point in the history
  • Loading branch information
deepyaman committed May 23, 2024
1 parent 60e4acd commit f5d1e9d
Showing 1 changed file with 89 additions and 15 deletions.
104 changes: 89 additions & 15 deletions sklearn/utils/_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@

import inspect
import pprint
import sys

from .._config import get_config
from ..base import BaseEstimator
Expand Down Expand Up @@ -171,10 +172,24 @@ def __init__(
stream=None,
*,
compact=False,
sort_dicts=True,
underscore_numbers=False,
indent_at_name=True,
n_max_elements_to_show=None,
):
super().__init__(indent, width, depth, stream, compact=compact)
super().__init__(
indent,
width,
depth,
stream,
compact=compact,
sort_dicts=sort_dicts,
**(
{}
if sys.version_info < (3, 10)
else {"underscore_numbers": underscore_numbers}
),
)
self._indent_at_name = indent_at_name
if self._indent_at_name:
self._indent_per_level = 1 # ignore indent param
Expand All @@ -186,7 +201,13 @@ def __init__(

def format(self, object, context, maxlevels, level):
return _safe_repr(
object, context, maxlevels, level, changed_only=self._changed_only
object,
context,
maxlevels,
level,
sort_dicts=self._sort_dicts,
underscore_numbers=getattr(self, "_underscore_numbers", False),
changed_only=self._changed_only,
)

def _pprint_estimator(self, object, stream, indent, allowance, context, level):
Expand All @@ -199,9 +220,12 @@ def _pprint_estimator(self, object, stream, indent, allowance, context, level):
else:
params = object.get_params(deep=False)

self._format_params(
sorted(params.items()), stream, indent, allowance + 1, context, level
)
if self._sort_dicts:
items = sorted(params.items(), key=pprint._safe_tuple)
else:
items = params.items()

self._format_params(items, stream, indent, allowance + 1, context, level)
stream.write(")")

def _format_dict_items(self, items, stream, indent, allowance, context, level):
Expand Down Expand Up @@ -349,15 +373,29 @@ def _pprint_key_val_tuple(self, object, stream, indent, allowance, context, leve
_dispatch[KeyValTuple.__repr__] = _pprint_key_val_tuple


def _safe_repr(object, context, maxlevels, level, changed_only=False):
def _safe_repr(
object,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=False,
):
"""Same as the builtin _safe_repr, with added support for Estimator
objects."""
typ = type(object)

if typ in pprint._builtin_scalars:
return repr(object), True, False

r = getattr(typ, "__repr__", None)

if issubclass(typ, int) and r is int.__repr__:
if underscore_numbers:
return f"{object:_d}", True, False
else:
return repr(object), True, False

if issubclass(typ, dict) and r is dict.__repr__:
if not object:
return "{}", True, False
Expand All @@ -373,13 +411,28 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
append = components.append
level += 1
saferepr = _safe_repr
items = sorted(object.items(), key=pprint._safe_tuple)
if sort_dicts:
items = sorted(object.items(), key=pprint._safe_tuple)
else:
items = object.items()
for k, v in items:
krepr, kreadable, krecur = saferepr(
k, context, maxlevels, level, changed_only=changed_only
k,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=changed_only,
)
vrepr, vreadable, vrecur = saferepr(
v, context, maxlevels, level, changed_only=changed_only
v,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=changed_only,
)
append("%s: %s" % (krepr, vrepr))
readable = readable and kreadable and vreadable
Expand Down Expand Up @@ -414,7 +467,13 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
level += 1
for o in object:
orepr, oreadable, orecur = _safe_repr(
o, context, maxlevels, level, changed_only=changed_only
o,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=changed_only,
)
append(orepr)
if not oreadable:
Expand All @@ -427,7 +486,7 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
if issubclass(typ, BaseEstimator):
objid = id(object)
if maxlevels and level >= maxlevels:
return "{...}", False, objid in context
return f"{typ.__name__}(...)", False, objid in context
if objid in context:
return pprint._recursion(object), False, True
context[objid] = 1
Expand All @@ -441,13 +500,28 @@ def _safe_repr(object, context, maxlevels, level, changed_only=False):
append = components.append
level += 1
saferepr = _safe_repr
items = sorted(params.items(), key=pprint._safe_tuple)
if sort_dicts:
items = sorted(params.items(), key=pprint._safe_tuple)
else:
items = params.items()
for k, v in items:
krepr, kreadable, krecur = saferepr(
k, context, maxlevels, level, changed_only=changed_only
k,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=changed_only,
)
vrepr, vreadable, vrecur = saferepr(
v, context, maxlevels, level, changed_only=changed_only
v,
context,
maxlevels,
level,
sort_dicts,
underscore_numbers,
changed_only=changed_only,
)
append("%s=%s" % (krepr.strip("'"), vrepr))
readable = readable and kreadable and vreadable
Expand Down

0 comments on commit f5d1e9d

Please sign in to comment.