-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add data splitter for grouped k-fold cross validation
- Loading branch information
Showing
3 changed files
with
136 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# frozen_string_literal: true | ||
|
||
require 'rumale/base/splitter' | ||
require 'rumale/preprocessing/label_encoder' | ||
|
||
module Rumale
  module ModelSelection
    # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
    # Data points belonging to the same group are never split across different folds.
    # The number of groups should be greater than or equal to the number of splits.
    #
    # @example
    #   kf = Rumale::ModelSelection::GroupKFold.new(n_splits: 5)
    #   kf.split(samples, labels, groups).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    class GroupKFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Create a new data splitter for grouped K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      def initialize(n_splits: 5)
        check_params_numeric(n_splits: n_splits)
        @n_splits = n_splits
      end

      # Generate data indices for grouped K-fold cross validation.
      #
      # Every group is assigned to exactly one fold. Each fold is held out once as the testing set,
      # and the samples of all other folds form the corresponding training set.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for grouped K-fold cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     The labels to be used to generate data indices for grouped K-fold cross validation.
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for grouped K-fold cross validation.
      # @raise [ArgumentError] If the number of distinct groups is less than the number of splits.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   Each element is a pair of index arrays: [train_ids, test_ids].
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        # Re-encode the group labels as consecutive integers starting at 0,
        # so that they can be used directly as array indices below.
        encoder = Rumale::Preprocessing::LabelEncoder.new
        groups = encoder.fit_transform(groups)
        n_groups = encoder.classes.size

        if n_groups < @n_splits
          raise ArgumentError,
                'The number of groups should be greater than or equal to the number of splits.'
        end

        # Visit the groups from the largest to the smallest.
        n_samples_per_group = groups.bincount
        group_ids = n_samples_per_group.sort_index.reverse
        n_samples_per_group = n_samples_per_group[group_ids]

        n_samples_per_fold = Numo::Int32.zeros(@n_splits)
        group_to_fold = Numo::Int32.zeros(n_groups)

        # Greedy balancing: each group goes to the fold that currently holds the fewest samples.
        n_samples_per_group.each_with_index do |weight, id|
          min_sample_fold_id = n_samples_per_fold.min_index
          n_samples_per_fold[min_sample_fold_id] += weight
          group_to_fold[group_ids[id]] = min_sample_fold_id
        end

        # Fold number of every sample, looked up through its (encoded) group.
        fold_ids = group_to_fold[groups]

        # The fold itself is the held-out testing set; everything else is used for training.
        Array.new(@n_splits) do |fid|
          test_ids = fold_ids.eq(fid).where.to_a
          train_ids = fold_ids.ne(fid).where.to_a
          [train_ids, test_ids]
        end
      end
    end
  end
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# frozen_string_literal: true | ||
|
||
require 'spec_helper' | ||
|
||
RSpec.describe Rumale::ModelSelection::GroupKFold do
  let(:x) { Numo::DFloat[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 18]] }
  let(:y) { nil }
  # Group 0 has four samples, group 1 has three, and group 3 has two.
  let(:groups) { Numo::Int32[0, 0, 1, 3, 1, 0, 3, 1, 0] }
  let(:splitter) { described_class.new(n_splits: n_splits) }
  # Each element is a [train_ids, test_ids] pair; the test side is the held-out fold.
  let(:validation_ids) { splitter.split(x, y, groups) }

  context 'when the number of splits is equal to the number of groups' do
    let(:n_splits) { 3 }

    # With as many folds as groups, every fold holds out exactly one group.
    it 'splits the dataset with group labels', :aggregate_failures do
      expect(splitter.n_splits).to eq(n_splits)
      expect(validation_ids).to be_a(Array)
      expect(validation_ids.size).to eq(n_splits)
      expect(validation_ids[0].size).to eq(2)
      expect(validation_ids[1].size).to eq(2)
      expect(validation_ids[2].size).to eq(2)
      expect(validation_ids[0][0]).to match_array(groups.ne(0).where.to_a)
      expect(validation_ids[0][1]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][0]).to match_array(groups.ne(1).where.to_a)
      expect(validation_ids[1][1]).to match_array(groups.eq(1).where.to_a)
      expect(validation_ids[2][0]).to match_array(groups.ne(3).where.to_a)
      expect(validation_ids[2][1]).to match_array(groups.eq(3).where.to_a)
    end
  end

  context 'when given the number of splits is less than the number of groups' do
    let(:n_splits) { 2 }

    # The largest group (0) fills the first fold; groups 1 and 3 share the second one.
    it 'splits the dataset with group labels', :aggregate_failures do
      expect(splitter.n_splits).to eq(n_splits)
      expect(validation_ids).to be_a(Array)
      expect(validation_ids.size).to eq(n_splits)
      expect(validation_ids[0].size).to eq(2)
      expect(validation_ids[1].size).to eq(2)
      expect(validation_ids[0][0]).to match_array(groups.ne(0).where.to_a)
      expect(validation_ids[0][1]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][0]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][1]).to match_array(groups.ne(0).where.to_a)
    end
  end

  context 'when given the number of splits is greater than the number of groups' do
    let(:n_splits) { 4 }

    it 'raises ArgumentError' do
      expect { splitter.split(x, y, groups) }
        .to raise_error(ArgumentError, 'The number of groups should be greater than or equal to the number of splits.')
    end
  end
end