Skip to content

Commit

Permalink
🎨 Add n_features parameter to load_libsvm_file
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Mar 11, 2021
1 parent bad82e7 commit f065345
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
10 changes: 7 additions & 3 deletions lib/rumale/dataset.rb
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,26 @@ class << self
# Load a dataset with the libsvm file format into Numo::NArray.
#
# @param filename [String] A path to a dataset file.
# @param n_features [Integer, nil] The number of features of data to load.
#   If nil is given, it will be detected automatically from the given file.
#   A value smaller than the detected number of features is ignored, so
#   this acts as a lower bound on the number of columns.
# @param zero_based [Boolean] Whether the column index starts from 0 (true) or 1 (false).
# @param dtype [Numo::NArray] Data type of Numo::NArray for features to be loaded.
#
# @return [Array<Numo::NArray>]
#   Returns array containing the (n_samples x n_features) matrix for feature vectors
#   and (n_samples) vector for labels or target values.
def load_libsvm_file(filename, n_features: nil, zero_based: false, dtype: Numo::DFloat)
  ftvecs = []
  labels = []
  # Track the largest max_idx reported while parsing, kept separate from the
  # caller-supplied n_features so the argument is never clobbered.
  n_features_detected = 0
  # "\s" in a double-quoted Ruby string is a single space character.
  CSV.foreach(filename, col_sep: "\s", headers: false) do |line|
    label, ftvec, max_idx = parse_libsvm_line(line, zero_based)
    labels.push(label)
    ftvecs.push(ftvec)
    n_features_detected = max_idx if n_features_detected < max_idx
  end
  # Use the requested width unless the file needs more columns; a nil request
  # falls back to the detected width.
  n_features = [n_features, n_features_detected].compact.max
  [convert_to_matrix(ftvecs, n_features, dtype), Numo::NArray.asarray(labels)]
end

Expand Down
13 changes: 13 additions & 0 deletions spec/rumale/datasets_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,19 @@
m, = described_class.load_libsvm_file(__dir__ + '/../test_zb.t', zero_based: true)
expect(m).to eq(matrix_dbl)
end

it 'loads libsvm .t file with the number of features', :aggregate_failures do
  # Requesting more features than the file holds pads the matrix with zero columns.
  m, = described_class.load_libsvm_file(__dir__ + '/../test_dbl.t', n_features: 6)
  expect(m.shape[1]).to eq(6)
  expect(m).to eq(matrix_dbl.concatenate(Numo::DFloat.zeros(6, 2), axis: 1))
  # Requesting fewer features than the file holds keeps the detected width.
  m, = described_class.load_libsvm_file(__dir__ + '/../test_dbl.t', n_features: 2)
  expect(m.shape[1]).to eq(matrix_dbl.shape[1])
  # The same two behaviours must hold for zero-based column indices.
  m, = described_class.load_libsvm_file(__dir__ + '/../test_zb.t', zero_based: true, n_features: 6)
  expect(m.shape[1]).to eq(6)
  expect(m).to eq(matrix_dbl.concatenate(Numo::DFloat.zeros(6, 2), axis: 1))
  m, = described_class.load_libsvm_file(__dir__ + '/../test_zb.t', zero_based: true, n_features: 2)
  expect(m.shape[1]).to eq(matrix_dbl.shape[1])
end
end

describe '#dump_libsvm_file' do
Expand Down

0 comments on commit f065345

Please sign in to comment.