Skip to content

Commit

Permalink
🚀 Change mopti gem to non-rutime dependent library
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed May 16, 2020
1 parent 7975316 commit 62a9782
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
1 change: 1 addition & 0 deletions Gemfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ gemspec

gem 'coveralls', '~> 0.8'
gem 'mmh3', '>= 1.0'
gem 'mopti', '>= 0.1.0'
gem 'numo-linalg', '>= 0.1.4'
gem 'parallel', '>= 1.17.0'
gem 'rake', '~> 12.0'
Expand Down
14 changes: 13 additions & 1 deletion lib/rumale/metric_learning/neighbourhood_component_analysis.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'
require 'mopti/scaled_conjugate_gradient'

module Rumale
module MetricLearning
# NeighbourhoodComponentAnalysis is a class that implements Neighbourhood Component Analysis.
#
# @example
# require 'mopti'
# require 'rumale'
#
# transformer = Rumale::MetricLearning::NeighbourhoodComponentAnalysis.new
# transformer.fit(training_samples, traininig_labels)
# low_samples = transformer.transform(testing_samples)
Expand Down Expand Up @@ -63,6 +65,8 @@ def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verb
# @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
# @return [NeighbourhoodComponentAnalysis] The learned classifier itself.
def fit(x, y)
raise 'NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded.' unless enable_mopti?

x = check_convert_sample_array(x)
y = check_convert_label_array(y)
check_sample_label_size(x, y)
Expand Down Expand Up @@ -98,6 +102,14 @@ def transform(x)

private

def enable_mopti?
if defined?(Mopti).nil?
warn('NeighbourhoodComponentAnalysis#fit requires Mopti but that is not loaded. You should intall and load mopti gem in advance.')
return false
end
true
end

def init_components(x, n_features, n_components)
if @params[:init] == 'pca'
pca = Rumale::Decomposition::PCA.new(n_components: n_components, solver: 'evd')
Expand Down
1 change: 0 additions & 1 deletion rumale.gemspec
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,4 @@ Gem::Specification.new do |spec|
}

spec.add_runtime_dependency 'numo-narray', '>= 0.9.1'
spec.add_runtime_dependency 'mopti', '>= 0.1.0'
end
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,19 @@
end
end

context 'when Mopti is not loaded' do
before do
@backup = Mopti
Object.send(:remove_const, :Mopti)
end

it 'raises Runtime error' do
expect { transformer.fit(x, y) }.to raise_error(RuntimeError)
end

after { Mopti = @backup }
end

it 'dumps and restores itself using Marshal module.', :aggregate_failures do
copied = Marshal.load(Marshal.dump(transformer.fit(x, y)))
expect(copied.class).to eq(transformer.class)
Expand Down
1 change: 1 addition & 0 deletions spec/spec_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
require 'rumale'
require 'parallel'
require 'mmh3'
require 'mopti'

def two_clusters_dataset
rng = Random.new(8)
Expand Down

0 comments on commit 62a9782

Please sign in to comment.