-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
created todo comment for stratification fix #29
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@elemets @panas89 I went ahead and debugged these changes using the AIDS research example code notebook by doing the following:
- passing
stratify_cols=["gender", "race"]
, - passing
stratify_y=True
, - placing two print statements at the end of the two respective
if stratify_cols and stratify_y:
blocks:- first:
print(stratify_key)
- second:
print(strat_key_val_test)
- first:
No errors, exceptions, or warnings were thrown, thus yielding the following output of a successful run:
gender race cid
0 0 0 0
1 0 0 1
2 1 0 0
3 1 0 0
4 1 0 0
... ... ... ...
2134 1 0 0
2135 1 1 0
2136 1 1 0
2137 1 0 1
2138 1 0 0
[2139 rows x 3 columns]
gender race cid
1943 0 1 0
1583 1 0 0
1891 1 0 0
1316 1 0 1
1117 1 0 1
... ... ... ...
6 1 0 1
1135 1 1 0
125 1 0 0
359 1 0 1
492 1 0 0
[856 rows x 3 columns]
100%|██████████| 324/324 [01:20<00:00, 4.02it/s]Best score/param set found on validation set:
{'params': {'selectKBest__k': 6,
'xgb__colsample_bytree': 1.0,
'xgb__early_stopping_rounds': 10,
'xgb__eval_metric': 'logloss',
'xgb__learning_rate': 0.01,
'xgb__max_depth': 7,
'xgb__n_estimators': 99,
'xgb__subsample': 1.0},
'score': 0.9262226970560304}
Best roc_auc: 0.926
The print statement confirms that the stratify_y
col ("cid"
) is being concatenated to the stratify_cols
DataFrame subset for x ("gender", "race"
), thus showing that the stratification for both cases is now working.
@@ -941,8 +946,14 @@ def train_val_test_split( | |||
# if calibrate: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove unnecessary comments and TODO
Line 1207 function get_cross_validate(), variable stratify is redundant |
We need to make a note in the documentation that stratify_cols cannot be used when using cross_validation |
Checked with debugger code changes! works! |
done |
Description:
Currently, the
train_val_test_split
method allows for stratification either by y (stratify_y
) or by specified columns (stratify_cols
), but not both at the same time. There are use cases where stratification by both the target variable (y) and specific columns is necessary to ensure a balanced and representative split across different data segments.Proposed Enhancement:
Modify the method to support simultaneous stratification by both y and
stratify_cols
. This can be achieved by combining the stratification keys or implementing logic that ensures both y and the specified columns are considered during the stratification process.Current Method Implementation: