Skip to content

Commit

Permalink
✨ Add data splitter for grouped k-fold cross validation
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Aug 18, 2020
1 parent 1e5ab9b commit f460338
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/rumale.rb
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
require 'rumale/preprocessing/binarizer'
require 'rumale/preprocessing/polynomial_features'
require 'rumale/model_selection/k_fold'
require 'rumale/model_selection/group_k_fold'
require 'rumale/model_selection/stratified_k_fold'
require 'rumale/model_selection/shuffle_split'
require 'rumale/model_selection/stratified_shuffle_split'
Expand Down
81 changes: 81 additions & 0 deletions lib/rumale/model_selection/group_k_fold.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# frozen_string_literal: true

require 'rumale/base/splitter'
require 'rumale/preprocessing/label_encoder'

module Rumale
  module ModelSelection
    # GroupKFold is a class that generates the set of data indices for K-fold cross-validation.
    # The data points belonging to the same group are never split across different folds.
    # The number of groups should be greater than or equal to the number of splits.
    #
    # @example
    #   kf = Rumale::ModelSelection::GroupKFold.new(n_splits: 5)
    #   kf.split(samples, labels, groups).each do |train_ids, test_ids|
    #     train_samples = samples[train_ids, true]
    #     test_samples = samples[test_ids, true]
    #     ...
    #   end
    class GroupKFold
      include Base::Splitter

      # Return the number of folds.
      # @return [Integer]
      attr_reader :n_splits

      # Create a new data splitter for grouped K-fold cross validation.
      #
      # @param n_splits [Integer] The number of folds.
      def initialize(n_splits: 5)
        check_params_numeric(n_splits: n_splits)
        @n_splits = n_splits
      end

      # Generate data indices for grouped K-fold cross validation.
      #
      # Each fold is held out once as the testing set while the samples of all the other folds
      # form the training set.
      #
      # @overload split(x, y, groups) -> Array
      #   @param x [Numo::DFloat] (shape: [n_samples, n_features])
      #     The dataset to be used to generate data indices for grouped K-fold cross validation.
      #   @param y [Numo::Int32] (shape: [n_samples])
      #     The labels to be used to generate data indices for grouped K-fold cross validation.
      #     This argument exists to unify the interface between the K-fold methods, it is not used in the method.
      #   @param groups [Numo::Int32] (shape: [n_samples])
      #     The group labels to be used to generate data indices for grouped K-fold cross validation.
      # @raise [ArgumentError] If the number of distinct groups is less than the number of splits.
      # @return [Array] The set of data indices for constructing the training and testing dataset in each fold.
      #   Each element is a pair of index arrays: [train_ids, test_ids].
      def split(x, _y, groups)
        x = check_convert_sample_array(x)
        groups = check_convert_label_array(groups)
        check_sample_label_size(x, groups)

        # Re-encode the group labels so they can be used directly as array indices below
        # (this relies on LabelEncoder producing ids in 0...n_groups).
        encoder = Rumale::Preprocessing::LabelEncoder.new
        groups = encoder.fit_transform(groups)
        n_groups = encoder.classes.size

        raise ArgumentError, 'The number of groups should be greater than or equal to the number of splits.' if n_groups < @n_splits

        # Order the groups from the largest to the smallest.
        n_samples_per_group = groups.bincount
        group_ids = n_samples_per_group.sort_index.reverse
        n_samples_per_group = n_samples_per_group[group_ids]

        n_samples_per_fold = Numo::Int32.zeros(@n_splits)
        group_to_fold = Numo::Int32.zeros(n_groups)

        # Greedy balancing: each group, largest first, goes to the fold that currently has the fewest samples.
        n_samples_per_group.each_with_index do |weight, id|
          min_sample_fold_id = n_samples_per_fold.min_index
          n_samples_per_fold[min_sample_fold_id] += weight
          group_to_fold[group_ids[id]] = min_sample_fold_id
        end

        n_samples = x.shape[0]
        sample_ids = Array(0...n_samples)
        fold_ids = group_to_fold[groups]

        # The samples assigned to fold `fid` are held out for testing; every other sample is used for training.
        Array.new(@n_splits) do |fid|
          test_ids = fold_ids.eq(fid).where.to_a
          train_ids = sample_ids - test_ids
          [train_ids, test_ids]
        end
      end
    end
  end
end
54 changes: 54 additions & 0 deletions spec/rumale/model_selection/group_k_fold_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# frozen_string_literal: true

require 'spec_helper'

# Specification of the grouped K-fold splitter.
# Every element returned by #split is a [train_ids, test_ids] pair: the fold that is held out
# is the testing side (index 1) and all the remaining samples are the training side (index 0).
RSpec.describe Rumale::ModelSelection::GroupKFold do
  let(:x) { Numo::DFloat[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16], [17, 18]] }
  let(:y) { nil }
  # Three distinct groups with 4 (label 0), 3 (label 1) and 2 (label 3) samples.
  let(:groups) { Numo::Int32[0, 0, 1, 3, 1, 0, 3, 1, 0] }
  let(:splitter) { described_class.new(n_splits: n_splits) }
  let(:validation_ids) { splitter.split(x, y, groups) }

  context 'when the number of splits is equal to the number of groups' do
    let(:n_splits) { 3 }

    it 'splits the dataset with group labels', :aggregate_failures do
      expect(splitter.n_splits).to eq(n_splits)
      expect(validation_ids).to be_a(Array)
      expect(validation_ids.size).to eq(n_splits)
      expect(validation_ids[0].size).to eq(2)
      expect(validation_ids[1].size).to eq(2)
      expect(validation_ids[2].size).to eq(2)
      # Groups are assigned largest first, so fold 0 holds out group 0, fold 1 group 1, fold 2 group 3.
      expect(validation_ids[0][0]).to match_array(groups.ne(0).where.to_a)
      expect(validation_ids[0][1]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][0]).to match_array(groups.ne(1).where.to_a)
      expect(validation_ids[1][1]).to match_array(groups.eq(1).where.to_a)
      expect(validation_ids[2][0]).to match_array(groups.ne(3).where.to_a)
      expect(validation_ids[2][1]).to match_array(groups.eq(3).where.to_a)
    end
  end

  context 'when given the number of splits is less than the number of groups' do
    let(:n_splits) { 2 }

    it 'splits the dataset with group labels', :aggregate_failures do
      expect(splitter.n_splits).to eq(n_splits)
      expect(validation_ids).to be_a(Array)
      expect(validation_ids.size).to eq(n_splits)
      expect(validation_ids[0].size).to eq(2)
      expect(validation_ids[1].size).to eq(2)
      # Fold 0 holds out group 0 (4 samples); fold 1 holds out groups 1 and 3 (3 + 2 samples).
      expect(validation_ids[0][0]).to match_array(groups.ne(0).where.to_a)
      expect(validation_ids[0][1]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][0]).to match_array(groups.eq(0).where.to_a)
      expect(validation_ids[1][1]).to match_array(groups.ne(0).where.to_a)
    end
  end

  context 'when given the number of splits is greater than the number of groups' do
    let(:n_splits) { 4 }

    it 'raises ArgumentError' do
      expect { splitter.split(x, y, groups) }.to raise_error(ArgumentError)
    end
  end
end

0 comments on commit f460338

Please sign in to comment.