@@ -1392,50 +1392,46 @@ def __init__(
1392
1392
raise TypeError ('Unknown type:' , model_file )
1393
1393
1394
1394
params = params or {}
1395
- params = _configure_metrics (params .copy ())
1396
- params = self ._configure_constraints (params )
1397
- if isinstance (params , list ):
1398
- params .append ((' validate_parameters' , True ))
1395
+ params_processed = _configure_metrics (params .copy ())
1396
+ params_processed = self ._configure_constraints (params_processed )
1397
+ if isinstance (params_processed , list ):
1398
+ params_processed .append ((" validate_parameters" , True ))
1399
1399
else :
1400
- params [ ' validate_parameters' ] = True
1400
+ params_processed [ " validate_parameters" ] = True
1401
1401
1402
- self .set_param (params or {})
1403
- if (params is not None ) and ('booster' in params ):
1404
- self .booster = params ['booster' ]
1405
- else :
1406
- self .booster = 'gbtree'
1402
+ self .set_param (params_processed or {})
1407
1403
1408
- def _transform_monotone_constrains (self , value : Union [Dict [str , int ], str ]) -> str :
1404
+ def _transform_monotone_constrains (
1405
+ self , value : Union [Dict [str , int ], str ]
1406
+ ) -> Union [Tuple [int , ...], str ]:
1409
1407
if isinstance (value , str ):
1410
1408
return value
1411
1409
1412
1410
constrained_features = set (value .keys ())
1413
- if not constrained_features .issubset (set (self .feature_names or [])):
1414
- raise ValueError ('Constrained features are not a subset of '
1415
- 'training data feature names' )
1411
+ feature_names = self .feature_names or []
1412
+ if not constrained_features .issubset (set (feature_names )):
1413
+ raise ValueError (
1414
+ "Constrained features are not a subset of training data feature names"
1415
+ )
1416
1416
1417
- return '(' + ',' .join ([str (value .get (feature_name , 0 ))
1418
- for feature_name in self .feature_names ]) + ')'
1417
+ return tuple (value .get (name , 0 ) for name in feature_names )
1419
1418
1420
1419
def _transform_interaction_constraints (
1421
- self , value : Union [List [ Tuple [str ]], str ]
1422
- ) -> str :
1420
+ self , value : Union [Sequence [ Sequence [str ]], str ]
1421
+ ) -> Union [ str , List [ List [ int ]]] :
1423
1422
if isinstance (value , str ):
1424
1423
return value
1425
-
1426
- feature_idx_mapping = {k : str (v ) for v , k in enumerate (self .feature_names or [])}
1424
+ feature_idx_mapping = {
1425
+ name : idx for idx , name in enumerate (self .feature_names or [])
1426
+ }
1427
1427
1428
1428
try :
1429
- s = "["
1429
+ result = []
1430
1430
for constraint in value :
1431
- s += (
1432
- "["
1433
- + "," .join (
1434
- [feature_idx_mapping [feature_name ] for feature_name in constraint ]
1435
- )
1436
- + "],"
1431
+ result .append (
1432
+ [feature_idx_mapping [feature_name ] for feature_name in constraint ]
1437
1433
)
1438
- return s [: - 1 ] + "]"
1434
+ return result
1439
1435
except KeyError as e :
1440
1436
raise ValueError (
1441
1437
"Constrained features are not a subset of training data feature names"
@@ -1444,17 +1440,16 @@ def _transform_interaction_constraints(
1444
1440
def _configure_constraints (self , params : Union [List , Dict ]) -> Union [List , Dict ]:
1445
1441
if isinstance (params , dict ):
1446
1442
value = params .get ("monotone_constraints" )
1447
- if value :
1448
- params [
1449
- "monotone_constraints"
1450
- ] = self . _transform_monotone_constrains ( value )
1443
+ if value is not None :
1444
+ params ["monotone_constraints" ] = self . _transform_monotone_constrains (
1445
+ value
1446
+ )
1451
1447
1452
1448
value = params .get ("interaction_constraints" )
1453
- if value :
1449
+ if value is not None :
1454
1450
params [
1455
1451
"interaction_constraints"
1456
1452
] = self ._transform_interaction_constraints (value )
1457
-
1458
1453
elif isinstance (params , list ):
1459
1454
for idx , param in enumerate (params ):
1460
1455
name , value = param
@@ -2462,11 +2457,9 @@ def trees_to_dataframe(self, fmap: Union[str, os.PathLike] = '') -> DataFrame:
2462
2457
if not PANDAS_INSTALLED :
2463
2458
raise ImportError (('pandas must be available to use this method.'
2464
2459
'Install pandas before calling again.' ))
2465
-
2466
- if getattr (self , 'booster' , None ) is not None and self .booster not in {'gbtree' , 'dart' }:
2467
- raise ValueError (
2468
- f"This method is not defined for Booster type { self .booster } "
2469
- )
2460
+ booster = json .loads (self .save_config ())["learner" ]["gradient_booster" ]["name" ]
2461
+ if booster not in {"gbtree" , "dart" }:
2462
+ raise ValueError (f"This method is not defined for Booster type { booster } " )
2470
2463
2471
2464
tree_ids = []
2472
2465
node_ids = []
0 commit comments