Issue with StackingCVClassifier train_meta_features_? #366
Thanks for pointing that out. In the docs, it was probably a copy & paste error when it was ported from the StackingClassifier docs (have to double-check). I think it's correct though if
I think you are right. I was just inspecting the code: the meta features are saved first, and only after that are the labels reordered:

```python
if self.store_train_meta_features:
    self.train_meta_features_ = all_model_predictions

# We have to shuffle the labels in the same order as we generated
# predictions during CV (we kinda shuffled them when we did
# Stratified CV).
# We also do the same with the features (we will need this only IF
# use_features_in_secondary is True)
reordered_labels = np.array([]).astype(y.dtype)
reordered_features = np.array([]).reshape((0, X.shape[1]))\
    .astype(X.dtype)
for train_index, test_index in skf:
    reordered_labels = np.concatenate((reordered_labels,
                                       y[test_index]))
    reordered_features = np.concatenate((reordered_features,
                                         X[test_index]))
```

Instead of reordering the labels, it might be better to reorder the meta features (aka
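To make the difference between the two options concrete, here is a minimal sketch using plain scikit-learn on made-up toy data (not mlxtend's actual implementation). Out-of-fold predictions stacked fold by fold end up in the order of the concatenated test indices; the quoted code moves the labels into that order, while the alternative applies the inverse permutation (`np.argsort` of the concatenated test indices) to move the meta-features back into the original row order. The variable names `meta_cv_order`, `test_order`, and `meta_original_order` are hypothetical, chosen for illustration.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold

# Toy data (illustrative only).
rng = np.random.RandomState(0)
X = rng.randn(20, 3)
y = np.array([0, 1] * 10)

# Materialize the splits once so they can be iterated like `skf` above.
skf = list(StratifiedKFold(n_splits=4, shuffle=True, random_state=0).split(X, y))

fold_preds, fold_indices = [], []
for train_index, test_index in skf:
    clf = LogisticRegression().fit(X[train_index], y[train_index])
    fold_preds.append(clf.predict_proba(X[test_index]))
    fold_indices.append(test_index)

# Meta-features stacked fold by fold: row k belongs to sample test_order[k],
# not to sample k.
meta_cv_order = np.vstack(fold_preds)
test_order = np.concatenate(fold_indices)

# What the quoted code does: bring the labels into the CV order.
reordered_labels = y[test_order]

# The alternative: bring the meta-features back into the original row order
# with the inverse permutation, so row i lines up with y[i].
meta_original_order = meta_cv_order[np.argsort(test_order)]
```

Indexing with `np.argsort(test_order)` is equivalent to preallocating an array and assigning `meta[test_order] = meta_cv_order`; either way, a stored `train_meta_features_` computed like this would line up with the caller's original `y`.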
I just merged a fix; the meta features should now be saved in the order of the original labels! Note that I also renamed the
Anyway, thanks a lot for pointing these issues out!
Using version 0.11.0
There seem to be a couple of issues with this attribute, which would certainly be useful.
1) sclf.train_meta_features_.shape is actually (number of training rows, number of classifiers * 2), because the predictions for both classes (in the case of a binary problem) are kept.
2) I have doubts that the index order is maintained when setting stratify=True or shuffle=True.
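Point 1 can be reproduced outside mlxtend. The sketch below uses plain scikit-learn and assumes the meta-features are built from each base classifier's predicted class probabilities (presumably what happens when probabilities rather than labels are stacked, e.g. with `use_probas=True`). Each classifier then contributes one column per class, so a binary problem with two classifiers yields four columns. The classifiers and data are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_predict
from sklearn.naive_bayes import GaussianNB

# Binary toy problem (make_classification defaults to n_classes=2).
X, y = make_classification(n_samples=100, n_features=5, random_state=1)
classifiers = [LogisticRegression(), GaussianNB()]

# One (n_samples, n_classes) block of out-of-fold probabilities per base
# classifier; cross_val_predict returns rows in the original order of X.
blocks = [cross_val_predict(clf, X, y, cv=5, method="predict_proba")
          for clf in classifiers]
meta_features = np.hstack(blocks)

print(meta_features.shape)  # (100, 4) = (n_samples, n_classifiers * n_classes)
```

For a binary problem the two columns of each block sum to 1, so one of them is redundant; keeping only the positive-class columns (`meta_features[:, 1::2]`) gives one column per classifier.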
Here is an example where the resulting meta_features do not appear to be in the same order as the original Y:
Here is an example where the order is apparently correct:
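A dataset-independent way to check the ordering question is a sanity test where the base classifier is guaranteed to predict perfectly. In this sketch (plain scikit-learn, toy data, not the reporter's example), the only feature is the label itself, so correctly aligned out-of-fold predictions must equal `y` exactly. Predictions stacked in CV order match only the reordered labels, while predictions written back to their original rows match `y`.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.tree import DecisionTreeClassifier

# Sorted labels: 25 zeros followed by 25 ones.
y = np.array([0] * 25 + [1] * 25)
# Diagnostic feature: the label itself, so the base classifier predicts y
# perfectly and any row misalignment shows up as a mismatch.
X = y.reshape(-1, 1).astype(float)

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
concat_preds, concat_idx = [], []
aligned = np.empty_like(y)
for train_index, test_index in cv.split(X, y):
    clf = DecisionTreeClassifier(random_state=0).fit(X[train_index], y[train_index])
    pred = clf.predict(X[test_index])
    concat_preds.append(pred)   # stacked in CV order
    concat_idx.append(test_index)
    aligned[test_index] = pred  # written back to the original rows

concat_preds = np.concatenate(concat_preds)
concat_idx = np.concatenate(concat_idx)

print((aligned == y).all())                   # True: rows line up with y
print((concat_preds == y[concat_idx]).all())  # True: matches only the reordered y
print((concat_preds == y).all())              # False: CV order != original order
```

Replacing the toy classifier with the actual base classifiers is not needed for this check: the point is that only the row order, not model quality, decides whether the third comparison can pass.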