Skip to content

Commit

Permalink
feat: support for probabilistic outputs of one-class SVM
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Oct 10, 2022
1 parent 3ea4629 commit 38e933c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
23 changes: 21 additions & 2 deletions lib/rumale/svm/one_class_svm.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@ class OneClassSVM
# @param gamma [Float] The gamma parameter in rbf/poly/sigmoid kernel function.
# @param coef0 [Float] The coefficient in poly/sigmoid kernel function.
# @param shrinking [Boolean] The flag indicating whether to use the shrinking heuristics.
# @param probability [Boolean] The flag indicating whether to train the parameter for probability estimation.
# @param cache_size [Float] The cache memory size in MB.
# @param tol [Float] The tolerance of termination criterion.
# @param verbose [Boolean] The flag indicating whether to output learning process message
# @param random_seed [Integer/Nil] The seed value using to initialize the random generator.
def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
shrinking: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
shrinking: true, probability: true, cache_size: 200.0, tol: 1e-3, verbose: false, random_seed: nil)
check_params_numeric(nu: nu, degree: degree, gamma: gamma, coef0: coef0, cache_size: cache_size, tol: tol)
check_params_string(kernel: kernel)
check_params_boolean(shrinking: shrinking, verbose: verbose)
check_params_boolean(shrinking: shrinking, probability: probability, verbose: verbose)
check_params_numeric_or_nil(random_seed: random_seed)
@params = {}
@params[:nu] = nu.to_f
Expand All @@ -41,6 +42,7 @@ def initialize(nu: 1.0, kernel: 'rbf', degree: 3, gamma: 1.0, coef0: 0.0,
@params[:gamma] = gamma.to_f
@params[:coef0] = coef0.to_f
@params[:shrinking] = shrinking
@params[:probability] = probability
@params[:cache_size] = cache_size.to_f
@params[:tol] = tol.to_f
@params[:verbose] = verbose
Expand Down Expand Up @@ -82,6 +84,19 @@ def predict(x)
Numo::Int32.cast(Numo::Libsvm.predict(x, libsvm_params, @model))
end

# Predict class probability for samples.
# This method works correctly only if the probability parameter is true.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the probailities.
# If the kernel is 'precomputed', the shape of x must be [n_samples, n_training_samples].
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Predicted probability of each class per sample.
def predict_proba(x)
raise "#{self.class.name}\##{__method__} expects to be called after training the model with the fit method." unless trained?
raise "#{self.class.name}\##{__method__} expects to be called after training the probablity parameters." unless trained_probs?
x = check_convert_sample_array(x)
Numo::Libsvm.predict_proba(x, libsvm_params, @model)
end

# Dump marshal data.
# @return [Hash] The marshal data about SVC.
def marshal_dump
Expand Down Expand Up @@ -150,6 +165,10 @@ def libsvm_params
def trained?
!@model.nil?
end

def trained_probs?
@model[:prob_density_marks].is_a?(Numo::NArray)
end
end
end
end
2 changes: 1 addition & 1 deletion rumale-svm.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ Gem::Specification.new do |spec|
spec.require_paths = ['lib']

spec.add_dependency 'numo-liblinear', '~> 2.0'
spec.add_dependency 'numo-libsvm', '~> 2.0'
spec.add_dependency 'numo-libsvm', '~> 2.1'
spec.add_dependency 'rumale', '~> 0.14', '< 0.24'
end
8 changes: 8 additions & 0 deletions spec/rumale/svm/one_class_svm_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
shared_examples 'distribution estimation task' do
let(:dfs) { ocsvm.decision_function(x) }
let(:predicted) { ocsvm.predict(x) }
let(:probs) { ocsvm.predict_proba(x) }
let(:n_sv) { ocsvm.n_support }

before { ocsvm.fit(x_pos) }
Expand Down Expand Up @@ -58,6 +59,13 @@
expect(dfs.ge(0).count).to eq(predicted.eq(1).count)
end

it 'estimates probabilities', :aggregate_failures do
expect(probs.class).to eq(Numo::DFloat)
expect(probs.ndim).to eq(2)
expect(probs.shape[0]).to eq(n_samples)
expect(probs.shape[1]).to eq(2)
end

it 'dumps and restores itself using Marshal module', :aggregate_failures do
expect(copied.instance_variable_get(:@model)).to eq(ocsvm.instance_variable_get(:@model))
expect(copied.params).to eq(ocsvm.params)
Expand Down

0 comments on commit 38e933c

Please sign in to comment.