From 187c7fa9f578c803e88248218247c460247b91f1 Mon Sep 17 00:00:00 2001 From: yoshoku Date: Mon, 2 Mar 2020 21:19:59 +0900 Subject: [PATCH] :art: Add max_iter parameter to NCA --- .../neighbourhood_component_analysis.rb | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lib/rumale/metric_learning/neighbourhood_component_analysis.rb b/lib/rumale/metric_learning/neighbourhood_component_analysis.rb index a52b99d4..83cc9405 100644 --- a/lib/rumale/metric_learning/neighbourhood_component_analysis.rb +++ b/lib/rumale/metric_learning/neighbourhood_component_analysis.rb @@ -35,17 +35,19 @@ class NeighbourhoodComponentAnalysis # # @param n_components [Integer] The number of components. # @param init [String] The initialization method for components ('random' or 'pca'). + # @param max_iter [Integer] The maximum number of iterations. # @param tol [Float] The tolerance of termination criterion. # @param verbose [Boolean] The flag indicating whether to output loss during iteration. # @param random_seed [Integer] The seed value using to initialize the random generator. - def initialize(n_components: nil, init: 'random', tol: 1e-6, verbose: false, random_seed: nil) + def initialize(n_components: nil, init: 'random', max_iter: 100, tol: 1e-6, verbose: false, random_seed: nil) check_params_numeric_or_nil(n_components: n_components, random_seed: random_seed) - check_params_numeric(tol: tol) + check_params_numeric(max_iter: max_iter, tol: tol) check_params_string(init: init) check_params_boolean(verbose: verbose) @params = {} @params[:n_components] = n_components @params[:init] = init + @params[:max_iter] = max_iter @params[:tol] = tol @params[:verbose] = verbose @params[:random_seed] = random_seed @@ -114,7 +116,9 @@ def optimize_components(x, y, n_features, n_components) res[:n_iter] = 0 # perform optimization. optimizer = Mopti::ScaledConjugateGradient.new( - fnc: method(:nca_loss), jcb: method(:nca_dloss), x_init: comp_init, args: [x, y], ftol: @params[:tol] + fnc: method(:nca_loss), jcb: method(:nca_dloss), + x_init: comp_init, args: [x, y], + max_iter: @params[:max_iter], ftol: @params[:tol] ) fold = 0.0 dold = 0.0