
mlr3resampling provides new cross-validation algorithms for the mlr3 framework in R

tests: https://github.com/tdhock/mlr3resampling/workflows/R-CMD-check/badge.svg
coverage: https://codecov.io/gh/tdhock/mlr3resampling/branch/main/graph/badge.svg

Installation

install.packages("mlr3resampling")#release version from CRAN
## OR: development version from GitHub:
install.packages("remotes")
remotes::install_github("tdhock/mlr3resampling")

Description

For an overview of functionality, please read my recent blog post.

Algorithm 1: cross-validation for comparing training on same and other groups

See examples in ResamplingSameOtherSizesCV vignette and data viz for regression and classification.

A supervised learning algorithm inputs a train set and outputs a prediction function, which can be used on a test set. If each data point belongs to a group (such as geographic region, year, etc.), then how do we know if it is possible to train on one group and predict accurately on another group? Cross-validation can be used to determine the extent to which this is possible. First, assign fold IDs from 1 to K to all data (possibly using stratification, usually by group and label). Then loop over test sets (group/fold combinations) and train sets (same group, other groups, all groups), computing test/prediction accuracy for each combination (a hand-rolled sketch of this loop follows the list below). Comparing test accuracy between same and other tells us the extent to which cross-group prediction is possible:

  • in the ideal case, same and other have similar test accuracy for each group, and all is even more accurate;
  • in real data, other/all are usually somewhat less accurate than same;
  • other can be just as inaccurate as the featureless baseline when the groups have different patterns.
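
To make the loop concrete, here is a hand-rolled sketch on a toy two-group data set. This is a simplified illustration of the idea, not the package's actual implementation, which also handles stratification, grouped observations, and down-sizing:

library(data.table)
toy.dt <- data.table(group = rep(c("A", "B"), each = 6))
K <- 3
set.seed(1)
toy.dt[, fold := sample(rep(1:K, l = .N)), by = group] # assign folds within each group
for (test.group in unique(toy.dt$group)) for (test.fold in 1:K) {
  is.test <- toy.dt$group == test.group & toy.dt$fold == test.fold
  not.test.fold <- toy.dt$fold != test.fold
  train.index.list <- list( # three train sets compared for the same test set
    same  = which(toy.dt$group == test.group & not.test.fold),
    other = which(toy.dt$group != test.group & not.test.fold), # package details may differ
    all   = which(not.test.fold))
  # ...for each train set: fit a learner, predict on which(is.test), record accuracy.
}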

This is implemented in ResamplingSameOtherSizesCV when you use it on a task that defines the subset role, for example the Arizona trees data, for which each row is a pixel in an image, and we want to do binary classification – does the pixel contain a tree or not?

> data(AZtrees,package="mlr3resampling")
> table(AZtrees$region3)

  NE   NW    S 
1464 1563 2929 

We see in the output above that the region3 column has three values (NE, NW, S), each representing the region in which the pixel was found. If we want good predictions in the south (S), can we train on the north (NE+NW)? We can use the code below to set up the CV experiment. Rows 12, 15, and 18 of the output below represent splits that attempt to answer that question (test.subset=S, train.subsets=other).

> task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
> task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
> task.obj$col_roles$strata <- "y"  #keep data proportional when splitting.
> task.obj$col_roles$group <- "polygon"  #keep data together when splitting.
> task.obj$col_roles$subset <- "region3" #fix one test region, train on same/other/all region(s).
> same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
> same_other_sizes_cv$instantiate(task.obj)
> same_other_sizes_cv$instance$iteration.dt[, .(test.subset, train.subsets, test.fold)]
    test.subset train.subsets test.fold
	 <char>        <char>     <int>
 1:          NE           all         1
 2:          NW           all         1
 3:           S           all         1
 4:          NE           all         2
 5:          NW           all         2
 6:           S           all         2
 7:          NE           all         3
 8:          NW           all         3
 9:           S           all         3
10:          NE         other         1
11:          NW         other         1
12:           S         other         1
13:          NE         other         2
14:          NW         other         2
15:           S         other         2
16:          NE         other         3
17:          NW         other         3
18:           S         other         3
19:          NE          same         1
20:          NW          same         1
21:           S          same         1
22:          NE          same         2
23:          NW          same         2
24:           S          same         2
25:          NE          same         3
26:          NW          same         3
27:           S          same         3
    test.subset train.subsets test.fold

The rows in the output above represent different kinds of splits:

  • train.subsets=same is used as a baseline.
  • train.subsets=other is used to answer the question posed above: can we train on other subsets and predict accurately on this one? (these are the rows filtered in the example after this list)
  • train.subsets=all is used to answer the question, “is it beneficial to combine all subsets when training?”
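
For example, the splits that test generalization from north to south can be picked out of the instance data.table directly (an illustrative query, using the columns shown in the output above):

same_other_sizes_cv$instance$iteration.dt[
  test.subset == "S" & train.subsets == "other"]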

Code to re-run:

data(AZtrees,package="mlr3resampling")
table(AZtrees$region3)
task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
task.obj$col_roles$strata <- "y"  #keep data proportional when splitting.
task.obj$col_roles$group <- "polygon"  #keep data together when splitting.
task.obj$col_roles$subset <- "region3" #fix one test region, train on same/other/all region(s).
same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
same_other_sizes_cv$instantiate(task.obj)
same_other_sizes_cv$instance$iteration.dt[, .(test.subset, train.subsets, test.fold)]
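
The resampling object above only defines the splits; to compare same/other/all accuracy we still need to run learners over them. Below is a minimal sketch using standard mlr3 functions (not code from this README); the rpart learner and classification error measure are illustrative choices:

library(mlr3)
library(data.table)
resample.result <- resample(task.obj, lrn("classif.rpart"), same_other_sizes_cv)
score.dt <- resample.result$score(msr("classif.ce"))
meta.dt <- same_other_sizes_cv$instance$iteration.dt
score.dt[, `:=`( # attach split meta-data by CV iteration number
  test.subset = meta.dt$test.subset[iteration],
  train.subsets = meta.dt$train.subsets[iteration])]
score.dt[, .(mean.error = mean(classif.ce)), # average over the 3 folds
  by = .(test.subset, train.subsets)]

Comparing mean.error between train.subsets=same and train.subsets=other for test.subset=S then answers the question posed above.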

Algorithm 2: cross-validation for comparing different-sized train sets

See examples in ResamplingSameOtherSizesCV vignette and data viz for regression and classification.

How many train samples are required to get accurate predictions on a test set? Cross-validation with variable-sized train sets can be used to answer this question. For example, consider the Arizona trees data below:

> dim(AZtrees)
[1] 5956   25
> length(unique(AZtrees$polygon))
[1] 189

The output above indicates that we have 5956 rows and 189 polygons. Cross-validation can be done on either polygons (if the task has a group role) or rows (if no group role is set). The code below sets a down-sampling ratio of 0.8 and requests four sizes of down-sampled train sets.

> same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
> same_other_sizes_cv$param_set$values$ratio <- 0.8
> same_other_sizes_cv$param_set$values$sizes <- 4
> same_other_sizes_cv$instantiate(task.obj)
> same_other_sizes_cv$instance$iteration.dt[, .(n.train.groups, test.fold)]
    n.train.groups test.fold
             <int>     <int>
 1:             51         1
 2:             64         1
 3:             80         1
 4:            100         1
 5:            126         1
 6:             51         2
 7:             64         2
 8:             80         2
 9:            100         2
10:            126         2
11:             51         3
12:             64         3
13:             80         3
14:            100         3
15:            126         3

The output above has one row per train/test split that will be computed in the cross-validation experiment. The full train set has 126 polygons, and there are four smaller train set sizes, each 0.8 times the size of the previous. Each train set size is computed for each fold ID from 1 to 3.
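
The sequence of sizes is geometric: each smaller size is ratio=0.8 times the previous one, rounded down. Flooring reproduces the numbers above (the exact rounding is an internal detail of the package):

> floor(126 * 0.8^(4:0))
[1]  51  64  80 100 126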

Code to re-run:

data(AZtrees,package="mlr3resampling")
dim(AZtrees)
length(unique(AZtrees$polygon))
task.obj <- mlr3::TaskClassif$new("AZtrees3", AZtrees, target="y")
task.obj$col_roles$feature <- grep("SAMPLE", names(AZtrees), value=TRUE)
task.obj$col_roles$strata <- "y"  #keep data proportional when splitting.
task.obj$col_roles$group <- "polygon"  #keep data together when splitting.
same_other_sizes_cv <- mlr3resampling::ResamplingSameOtherSizesCV$new()
same_other_sizes_cv$param_set$values$sizes <- 4
same_other_sizes_cv$param_set$values$ratio <- 0.8
same_other_sizes_cv$instantiate(task.obj)
same_other_sizes_cv$instance$iteration.dt[, .(n.train.groups, test.fold)]
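
As with Algorithm 1, a learner still has to be run over these splits. The sketch below (illustrative, using standard mlr3 functions and an rpart learner, as before) summarizes test error as a function of train set size, which answers the question at the top of this section:

library(mlr3)
library(data.table)
size.result <- resample(task.obj, lrn("classif.rpart"), same_other_sizes_cv)
size.score <- size.result$score(msr("classif.ce"))
size.meta <- same_other_sizes_cv$instance$iteration.dt
size.score[, n.train.groups := size.meta$n.train.groups[iteration]] # train size per split
size.score[, .(mean.error = mean(classif.ce)), by = n.train.groups] # error vs train set size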

More Usage Examples and Discussion

Older examples in ResamplingSameOtherCV vignette and data viz for regression and classification.

Older examples in ResamplingVariableSizeTrainCV vignette and data viz for regression and classification.

The data viz examples linked above use larger data sizes than the corresponding CRAN vignettes.

Related work

The mlr3resampling code was copied and modified from the Resampling and ResamplingCV classes in the excellent mlr3 package.
