Grouped K Fold #540

Closed
trebuchet90 opened this issue Nov 30, 2016 · 5 comments
@trebuchet90 commented Nov 30, 2016

Is there a GroupKFold feature in caret or R?
http://scikit-learn.org/stable/modules/cross_validation.html

@topepo (Owner) commented Dec 1, 2016

Here is some preliminary code. Run it before train, then pass the result to the index argument of trainControl to specify the folds.

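library(caret)  # provides createFolds()

# Build training-row indices so that all rows of a group stay on the same side
# of each split. `x` holds one group label per row; `k` defaults to one fold per group.
group_cv <- function(x, k = length(unique(x))) {
  dat <- data.frame(index = seq_along(x), group = x)
  groups <- data.frame(group = unique(dat$group))
  # Fold the unique groups (not the rows); each element holds the training groups
  group_folds <- createFolds(groups$group, returnTrain = TRUE, k = k)
  group_folds <- lapply(group_folds, function(x, y) y[x, , drop = FALSE], y = groups)
  # Expand each set of training groups back to the matching row indices
  dat_folds <- lapply(group_folds, function(x, y) merge(x, y), y = dat)
  lapply(dat_folds, function(x) sort(x$index))
}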

groups <- c(1, 1, 1, 2, 2, 2, 3, 3, 3, 3)

set.seed(242)
folds <- group_cv(groups)
## check: 
for_mod <- lapply(folds, function(x, y) y[x], y = groups)
holdout <- lapply(folds, function(x, y) y[-unique(x)], y = groups)
for(i in seq(along = folds)) {
  if(any(unique(for_mod[[i]]) %in% unique(holdout[[i]])))
    cat("didn't work!")
}

## Test with more groups and smaller # folds
set.seed(91906)
other_groups <- sort(sample(letters[1:15], size = 50, replace = TRUE))
folds_1 <- group_cv(other_groups, k = 5)

for_mod <- lapply(folds_1, function(x, y) y[x], y = other_groups)
holdout <- lapply(folds_1, function(x, y) y[-unique(x)], y = other_groups)
for(i in seq(along = folds_1)) {
  if(any(unique(for_mod[[i]]) %in% unique(holdout[[i]])))
    cat("didn't work!")
}

Please test this, since I just wrote it on the fly.
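As a rough illustration of the "use the index argument of trainControl" step, here is a minimal sketch. It builds leave-one-group-out training indices by hand for the small `groups` vector above so that the block stands alone; the simulated `dat`, the `lm` model, and the fold names are assumptions made only for the example. The caret calls are guarded so that the fold-building part still runs if caret is not installed.

```r
# Row-level group labels, as in the example above
groups <- c(1, 1, 1, 2, 2, 2, 3, 3, 3, 3)

# Leave-one-group-out training indices, built by hand: each list element
# is the set of row numbers used for fitting in that resample.
folds <- lapply(unique(groups), function(g) which(groups != g))
names(folds) <- paste0("Fold", seq_along(folds))

# Simulated data (hypothetical; only here so there is something to fit)
set.seed(242)
dat <- data.frame(x = rnorm(length(groups)))
dat$y <- 2 * dat$x + rnorm(length(groups), sd = 0.5)

if (requireNamespace("caret", quietly = TRUE)) {
  # `index` supplies the training rows for each resample
  ctrl <- caret::trainControl(method = "cv", number = length(folds), index = folds)
  fit <- caret::train(y ~ x, data = dat, method = "lm", trControl = ctrl)
  print(fit$resample)  # one row of performance estimates per grouped fold
}
```

Because `index` only lists training rows, the rows left out of each element act as that resample's holdout, so no group appears on both sides of a split.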

@trebuchet90 (Author) commented Dec 1, 2016

That's amazing. Thank you so much. If you ever need anything let me know. I love love your library.

Grouped K Fold is, to me, so important for machine learning. Every interesting problem I've ever worked on, and many on Kaggle, require it. Would it be possible to make this an official feature?

@topepo (Owner) commented Dec 1, 2016

@trebuchet90 (Author) commented Dec 1, 2016

I think the way you did it, as an index, was perfect. That's what I was looking for anyway.
In a larger ensemble, the different CVs in a nested CV could be grouped on different indexes.
Should I close this issue, or wait for the feature to be added?
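One possible reading of "grouped on different indexes" in a nested CV is sketched below in base R: the outer folds are grouped on one variable and the inner folds, built inside each outer training set, on another. The helper `group_folds()`, the columns `patient` and `site`, and the fold counts are all hypothetical and chosen only for illustration; this is not caret functionality.

```r
# Hypothetical helper: assign whole groups to k folds and return, for each
# fold, the row indices used for training (all rows outside that fold).
group_folds <- function(groups, k) {
  ids <- unique(groups)
  stopifnot(k <= length(ids))
  # Shuffle the group ids, then deal them round-robin into k folds
  shuffled <- ids[sample.int(length(ids))]
  fold_of_group <- setNames(rep_len(seq_len(k), length(ids)), shuffled)
  fold_of_row <- unname(fold_of_group[as.character(groups)])
  lapply(seq_len(k), function(f) which(fold_of_row != f))
}

set.seed(1)
dat <- data.frame(
  patient = rep(1:6, each = 4),                # outer grouping variable
  site    = rep(c("a", "b", "c", "d"), 6)      # inner grouping variable
)

# Outer loop: keep all rows of a patient together
outer_folds <- group_folds(dat$patient, k = 3)

# Inner loop: within each outer training set, keep all rows of a site together
inner_folds <- lapply(outer_folds, function(train_rows) {
  local_folds <- group_folds(dat$site[train_rows], k = 2)
  # Map positions within the subset back to row numbers of `dat`
  lapply(local_folds, function(i) train_rows[i])
})
```

Each element of `outer_folds`, and each inner list in `inner_folds`, has the same shape as the list returned by `group_cv`, so either could be handed to an `index` argument.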

@topepo (Owner) commented Dec 1, 2016

Go ahead and leave it open so I can reference commits to it.

Thanks
