-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add classifier class with variable-random tree
- Loading branch information
Showing
3 changed files
with
275 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# frozen_string_literal: true

require 'rumale/tree/decision_tree_classifier'

module Rumale
  module Tree
    # VariableRandomTreeClassifier is a class that implements variable-random tree for classification.
    # At each node, with probability alpha it performs the deterministic (best) test-selection
    # over the candidate features; otherwise it splits a single randomly chosen feature at a
    # uniformly random threshold.
    #
    # @example
    #   estimator =
    #     Rumale::Tree::VariableRandomTreeClassifier.new(
    #       criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5,
    #       alpha: 0.3, random_seed: 1
    #     )
    #   estimator.fit(training_samples, training_labels)
    #   results = estimator.predict(testing_samples)
    #
    # *Reference*
    # - F. T. Liu, K. M. Ting, Y. Yu, and Z-H. Zhou, "Spectrum of Variable-Random Trees," Journal of Artificial Intelligence Research, vol. 32, pp. 355--384, 2008.
    class VariableRandomTreeClassifier < DecisionTreeClassifier
      # Return the class labels.
      # @return [Numo::Int32] (size: n_classes)
      attr_reader :classes

      # Return the importance for each feature.
      # @return [Numo::DFloat] (size: n_features)
      attr_reader :feature_importances

      # Return the learned tree.
      # @return [Node]
      attr_reader :tree

      # Return the random generator for random selection of feature index.
      # @return [Random]
      attr_reader :rng

      # Return the labels assigned each leaf.
      # @return [Numo::Int32] (size: n_leafs)
      attr_reader :leaf_labels

      # Create a new classifier with variable-random tree algorithm.
      #
      # @param criterion [String] The function to evaluate splitting point. Supported criteria are 'gini' and 'entropy'.
      # @param max_depth [Integer] The maximum depth of the tree.
      #   If nil is given, decision tree grows without concern for depth.
      # @param max_leaf_nodes [Integer] The maximum number of leaves on decision tree.
      #   If nil is given, number of leaves is not limited.
      # @param min_samples_leaf [Integer] The minimum number of samples at a leaf node.
      # @param max_features [Integer] The number of features to consider when searching optimal split point.
      #   If nil is given, split process considers all features.
      # @param alpha [Float] The probability of choosing deterministic test-selection.
      # @param random_seed [Integer] The seed value using to initialize the random generator.
      #   It is used to randomly determine the order of features when deciding splitting point.
      def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                     alpha: 0.3, random_seed: nil)
        check_params_numeric_or_nil(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                    max_features: max_features, random_seed: random_seed)
        check_params_numeric(min_samples_leaf: min_samples_leaf, alpha: alpha)
        check_params_string(criterion: criterion)
        check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                              min_samples_leaf: min_samples_leaf, max_features: max_features)
        # Collect this method's keyword arguments by reflection so the superclass receives
        # every parameter it knows about; :alpha is specific to this class, so it is removed
        # from the forwarded hash and stored in @params directly after super.
        keywd_args = method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h
        keywd_args.delete(:alpha)
        super(keywd_args)
        @params[:alpha] = alpha
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [VariableRandomTreeClassifier] The learned classifier itself.
      def fit(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        super
      end

      # Predict class labels for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
      # @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
      def predict(x)
        x = check_convert_sample_array(x)
        super
      end

      # Predict probability for samples.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probabilities.
      # @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
      def predict_proba(x)
        x = check_convert_sample_array(x)
        super
      end

      private

      # Choose the split for a node. A uniform draw from @sub_rng decides the branch:
      # if it is <= alpha, search deterministically for the best split over the candidate
      # feature indices (rand_ids) and keep the one with maximum gain (last tuple element);
      # otherwise split the first candidate feature at a random threshold.
      # NOTE(review): rand_ids, best_split, and impurity are inherited from
      # DecisionTreeClassifier — presumably rand_ids also consumes @sub_rng; verify there.
      def find_best_split(x, y, impurity)
        if @sub_rng.rand <= @params[:alpha]
          rand_ids.map { |n| [n, *best_split(x[true, n], y, impurity)] }.max_by(&:last)
        else
          n = rand_ids.first
          [n, *best_split_rand(x[true, n], y, impurity)]
        end
      end

      # Draw a uniformly random threshold within the feature's value range, partition the
      # samples at it, and compute the impurities of both sides and the resulting gain.
      # Empty partitions contribute zero impurity.
      # @return [Array] [left_impurity, right_impurity, threshold, gain]
      def best_split_rand(features, y, whole_impurity)
        threshold = @sub_rng.rand(features.min..features.max)
        l_ids = features.le(threshold).where
        r_ids = features.gt(threshold).where
        l_impurity = l_ids.empty? ? 0.0 : impurity(y[l_ids, true])
        r_impurity = r_ids.empty? ? 0.0 : impurity(y[r_ids, true])
        # Gain is the parent impurity minus the size-weighted child impurities.
        gain = whole_impurity -
               l_impurity * l_ids.size.fdiv(y.shape[0]) -
               r_impurity * r_ids.size.fdiv(y.shape[0])
        [l_impurity, r_impurity, threshold, gain]
      end
    end
  end
end
152 changes: 152 additions & 0 deletions
152
spec/rumale/tree/variable_random_tree_classifier_spec.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
# frozen_string_literal: true

require 'spec_helper'

RSpec.describe Rumale::Tree::VariableRandomTreeClassifier do
  let(:x) { dataset[0] }
  let(:y) { dataset[1] }
  let(:classes) { y.to_a.uniq.sort }
  let(:n_samples) { x.shape[0] }
  let(:n_features) { x.shape[1] }
  let(:n_classes) { classes.size }
  let(:criterion) { 'gini' }
  let(:max_depth) { nil }
  let(:max_leaf_nodes) { nil }
  let(:min_samples_leaf) { 1 }
  let(:max_features) { nil }
  let(:alpha) { 0.3 }
  let(:estimator) do
    described_class.new(criterion: criterion, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                        min_samples_leaf: min_samples_leaf, max_features: max_features,
                        alpha: alpha, random_seed: 1).fit(x, y)
  end
  let(:probs) { estimator.predict_proba(x) }
  let(:predicted_by_probs) { Numo::Int32[*(Array.new(n_samples) { |n| classes[probs[n, true].max_index] })] }
  let(:score) { estimator.score(x, y) }
  let(:copied) { Marshal.load(Marshal.dump(estimator)) }

  context 'when binary classification problem' do
    let(:dataset) { two_clusters_dataset }

    it 'classifies two clusters data.', :aggregate_failures do
      expect(estimator.tree.class).to eq(Rumale::Tree::Node)
      expect(estimator.classes.class).to eq(Numo::Int32)
      expect(estimator.classes.ndim).to eq(1)
      expect(estimator.classes.shape[0]).to eq(n_classes)
      expect(estimator.feature_importances.class).to eq(Numo::DFloat)
      expect(estimator.feature_importances.ndim).to eq(1)
      expect(estimator.feature_importances.shape[0]).to eq(n_features)
      expect(score).to eq(1.0)
    end
  end

  context 'when multiclass classification problem' do
    let(:dataset) { three_clusters_dataset }

    it 'classifies three clusters data.', :aggregate_failures do
      expect(estimator.tree.class).to eq(Rumale::Tree::Node)
      expect(estimator.classes.class).to eq(Numo::Int32)
      expect(estimator.classes.ndim).to eq(1)
      expect(estimator.classes.shape[0]).to eq(n_classes)
      expect(estimator.feature_importances.class).to eq(Numo::DFloat)
      expect(estimator.feature_importances.ndim).to eq(1)
      expect(estimator.feature_importances.shape[0]).to eq(n_features)
      expect(score).to eq(1.0)
    end

    it 'estimates class probabilities with three clusters dataset.', :aggregate_failures do
      expect(probs.class).to eq(Numo::DFloat)
      expect(probs.ndim).to eq(2)
      expect(probs.shape[0]).to eq(n_samples)
      expect(probs.shape[1]).to eq(n_classes)
      expect(predicted_by_probs).to eq(y)
    end

    # FIX: the tag was previously misspelled :aggregated_failures, which RSpec treats as
    # an unrelated inert tag, so failure aggregation was silently disabled here.
    it 'dumps and restores itself using Marshal module.', :aggregate_failures do
      expect(estimator.class).to eq(copied.class)
      expect(estimator.classes).to eq(copied.classes)
      expect(estimator.feature_importances).to eq(copied.feature_importances)
      expect(estimator.rng).to eq(copied.rng)
      # FIXME: A slight error on the value of the threshold parameter occurs.
      #   It seems to be caused by rounding error of Float.
      # expect(estimator.tree).to eq(copied.tree)
      expect(score).to eq(copied.score(x, y))
    end

    context 'when max_depth parameter is given' do
      let(:max_depth) { 1 }

      it 'learns model with given parameters.', :aggregate_failures do
        expect(estimator.params[:max_depth]).to eq(max_depth)
        expect(estimator.tree.left.left).to be_nil
        expect(estimator.tree.left.right).to be_nil
        expect(estimator.tree.right.left).to be_nil
        expect(estimator.tree.right.right).to be_nil
      end
    end

    context 'when max_leaf_nodes parameter is given' do
      let(:max_leaf_nodes) { 2 }

      it 'learns model with given parameters.', :aggregate_failures do
        expect(estimator.params[:max_leaf_nodes]).to eq(max_leaf_nodes)
        expect(estimator.leaf_labels.size).to eq(max_leaf_nodes)
      end
    end

    context 'when min_samples_leaf parameter is given' do
      let(:min_samples_leaf) { 200 }

      it 'learns model with given parameters.', :aggregate_failures do
        expect(estimator.params[:min_samples_leaf]).to eq(min_samples_leaf)
        expect(estimator.tree.left.leaf).to be_truthy
        expect(estimator.tree.left.n_samples).to be >= min_samples_leaf
        expect(estimator.tree.right).to be_nil
      end
    end

    context 'when alpha parameter is given' do
      context 'with zero' do
        let(:alpha) { 0 }

        it 'behaves like a random tree' do
          expect(estimator.leaf_labels.size).to be > n_classes
        end
      end

      context 'with one' do
        let(:alpha) { 1 }

        it 'behaves like a decision tree' do
          expect(estimator.leaf_labels.size).to eq(n_classes)
        end
      end
    end

    context 'when max_features parameter is given' do
      context 'with negative value' do
        let(:max_features) { -10 }

        it 'raises ArgumentError by validation' do
          expect { estimator }.to raise_error(ArgumentError)
        end
      end

      context 'with value larger than number of features' do
        let(:max_features) { 10 }

        it 'value of max_features is equal to the number of features' do
          expect(estimator.params[:max_features]).to eq(x.shape[1])
        end
      end

      context 'with valid value' do
        let(:max_features) { 2 }

        it 'learns model with given parameters.' do
          expect(estimator.params[:max_features]).to eq(2)
        end
      end
    end
  end
end