Skip to content

Commit

Permalink
🚀 Change automatically selected solver from sgd to lbfgs in LinearReg…
Browse files Browse the repository at this point in the history
…ression and Ridge
  • Loading branch information
yoshoku committed Apr 4, 2021
1 parent 7368b22 commit b2a4ec2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
8 changes: 5 additions & 3 deletions lib/rumale/linear_model/linear_regression.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# frozen_string_literal: true

require 'lbfgsb'

require 'rumale/linear_model/base_sgd'
require 'rumale/base/regressor'

Expand Down Expand Up @@ -58,7 +60,7 @@ class LinearRegression < BaseSGD
# @param tol [Float] The tolerance of loss for terminating optimization.
# If solver is 'svd', this parameter is ignored.
# @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd' or 'lbfgs').
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
# 'sgd' uses the stochastic gradient descent optimization.
# 'svd' performs singular value decomposition of samples.
# 'lbfgs' uses the L-BFGS method for optimization.
Expand All @@ -82,9 +84,9 @@ def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
super()
@params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
@params[:solver] = if solver == 'auto'
enable_linalg?(warning: false) ? 'svd' : 'sgd'
enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
else
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
end
@params[:decay] ||= @params[:learning_rate]
@params[:random_seed] ||= srand
Expand Down
6 changes: 3 additions & 3 deletions lib/rumale/linear_model/ridge.rb
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Ridge < BaseSGD
# @param tol [Float] The tolerance of loss for terminating optimization.
# If solver is 'svd', this parameter is ignored.
# @param solver [String] The algorithm to calculate weights. ('auto', 'sgd', 'svd', or 'lbfgs').
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'sgd' solver.
# 'auto' chooses the 'svd' solver if Numo::Linalg is loaded. Otherwise, it chooses the 'lbfgs' solver.
# 'sgd' uses the stochastic gradient descent optimization.
# 'svd' performs singular value decomposition of samples.
# 'lbfgs' uses the L-BFGS method for optimization.
Expand All @@ -87,9 +87,9 @@ def initialize(learning_rate: 0.01, decay: nil, momentum: 0.9,
super()
@params.merge!(method(:initialize).parameters.map { |_t, arg| [arg, binding.local_variable_get(arg)] }.to_h)
@params[:solver] = if solver == 'auto'
enable_linalg?(warning: false) ? 'svd' : 'sgd'
enable_linalg?(warning: false) ? 'svd' : 'lbfgs'
else
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'sgd'
solver.match?(/^svd$|^sgd$|^lbfgs$/) ? solver : 'lbfgs'
end
@params[:decay] ||= @params[:reg_param] * @params[:learning_rate]
@params[:random_seed] ||= srand
Expand Down
4 changes: 2 additions & 2 deletions spec/rumale/linear_model/linear_regression_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@

after { Numo::Linalg = @backup }

it 'chooses "sgd" solver' do
expect(estimator.params[:solver]).to eq('sgd')
it 'chooses "lbfgs" solver' do
expect(estimator.params[:solver]).to eq('lbfgs')
end
end
end
Expand Down
4 changes: 2 additions & 2 deletions spec/rumale/linear_model/ridge_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@

after { Numo::Linalg = @backup }

it 'chooses "sgd" solver' do
expect(estimator.params[:solver]).to eq('sgd')
it 'chooses "lbfgs" solver' do
expect(estimator.params[:solver]).to eq('lbfgs')
end
end
end
Expand Down

0 comments on commit b2a4ec2

Please sign in to comment.