Skip to content

Commit

Permalink
✨ Add transformer class with Fisher Discriminant Analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
yoshoku committed Feb 27, 2020
1 parent bf1ef25 commit 6ecd1b8
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/rumale.rb
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
require 'rumale/decomposition/fast_ica'
require 'rumale/manifold/tsne'
require 'rumale/manifold/mds'
require 'rumale/metric_learning/fisher_discriminant_analysis'
require 'rumale/neural_network/adam'
require 'rumale/neural_network/base_mlp'
require 'rumale/neural_network/mlp_regressor'
Expand Down
111 changes: 111 additions & 0 deletions lib/rumale/metric_learning/fisher_discriminant_analysis.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# frozen_string_literal: true

require 'rumale/base/base_estimator'
require 'rumale/base/transformer'

module Rumale
  # Module for metric learning algorithms.
  module MetricLearning
    # FisherDiscriminantAnalysis is a class that implements Fisher Discriminant Analysis.
    # It learns a linear projection from the within-class and between-class scatter
    # matrices of labeled training samples.
    #
    # @example
    #   transformer = Rumale::MetricLearning::FisherDiscriminantAnalysis.new
    #   transformer.fit(training_samples, training_labels)
    #   low_samples = transformer.transform(testing_samples)
    #
    # *Reference*
    # - Fisher, R. A., "The use of multiple measurements in taxonomic problems," Annals of Eugenics, vol. 7, pp. 179--188, 1936.
    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
    class FisherDiscriminantAnalysis
      include Base::BaseEstimator
      include Base::Transformer

      # Returns the transform matrix.
      # @return [Numo::DFloat] (shape: [n_components, n_features], or [n_features] when only one component is extracted)
      attr_reader :components

      # Returns the mean vector.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :mean

      # Returns the class mean vectors.
      # @return [Numo::DFloat] (shape: [n_classes, n_features])
      attr_reader :class_means

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new transformer with FisherDiscriminantAnalysis.
      #
      # @param n_components [Integer] The number of components.
      #   If nil is given, the number of components will be set to [n_features, n_classes - 1].min
      def initialize(n_components: nil)
        check_params_numeric_or_nil(n_components: n_components)
        @params = {}
        @params[:n_components] = n_components
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @return [FisherDiscriminantAnalysis] The learned transformer itself.
      def fit(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        raise 'FisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?

        # initialize some variables.
        n_features = x.shape[1]
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size
        # the subspace dimensionality never exceeds n_features;
        # it defaults to n_classes - 1 when the user did not specify it.
        n_components = [n_features, @params[:n_components] || (n_classes - 1)].min

        # calculate within and between scatter matrices.
        within_mat = Numo::DFloat.zeros(n_features, n_features)
        between_mat = Numo::DFloat.zeros(n_features, n_features)
        @class_means = Numo::DFloat.zeros(n_classes, n_features)
        @mean = x.mean(0)
        @classes.each_with_index do |label, i|
          mask_vec = y.eq(label)
          sz_class = mask_vec.count
          class_samples = x[mask_vec, true]
          class_mean = class_samples.mean(0)
          centered_samples = class_samples - class_mean
          mean_diff = class_mean - @mean
          within_mat += centered_samples.transpose.dot(centered_samples)
          # column vector * row vector broadcasts to the outer product of mean_diff with itself.
          between_mat += sz_class * mean_diff.expand_dims(1) * mean_diff
          @class_means[i, true] = class_mean
        end

        # calculate components.
        # generalized eigenproblem of (between_mat, within_mat); vals_range keeps only the last
        # n_components eigenpairs, and reversing the columns puts them in the opposite order.
        _, evecs = Numo::Linalg.eigh(between_mat, within_mat, vals_range: (n_features - n_components)...n_features)
        comps = evecs.reverse(1).transpose.dup
        # a single component is stored as a vector; dup detaches the row slice from comps.
        @components = n_components == 1 ? comps[0, true].dup : comps
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, y)
        x = check_convert_sample_array(x)
        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components], or [n_samples] when only one component is extracted)
      #   The transformed data.
      def transform(x)
        x = check_convert_sample_array(x)
        x.dot(@components.transpose)
      end
    end
  end
end
90 changes: 90 additions & 0 deletions spec/rumale/metric_learning/fisher_discriminant_analysis_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# frozen_string_literal: true

require 'spec_helper'

# Spec for Rumale::MetricLearning::FisherDiscriminantAnalysis: checks the types and shapes of the
# fitted attributes and projections, the nearest-neighbor accuracy in the projected subspace,
# and the Marshal round trip.
RSpec.describe Rumale::MetricLearning::FisherDiscriminantAnalysis do
  let(:dataset) { three_clusters_dataset }
  let(:x) do
    # The extra column is Gaussian noise scaled by 10. It is added so that an unsupervised
    # reduction such as PCA would make samples of different classes overlap.
    Numo::DFloat.hstack([dataset[0], 10 * Rumale::Utils.rand_normal([dataset[0].shape[0], 1], Random.new(1))])
  end
  let(:y) { dataset[1] }
  let(:classes) { y.to_a.uniq.sort }
  let(:n_samples) { x.shape[0] }
  let(:n_features) { x.shape[1] }
  let(:n_classes) { classes.size }
  let(:n_components) { nil }

  let(:transformer) { described_class.new(n_components: n_components) }
  let(:z) { transformer.fit_transform(x, y) }

  # Single 90% / 10% train / test split used by the classification example.
  let(:splitter) { Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, train_size: 0.9, random_seed: 1) }
  let(:validation_ids) { splitter.split(x, y).first }
  let(:train_ids) { validation_ids[0] }
  let(:test_ids) { validation_ids[1] }
  let(:x_train) { x[train_ids, true].dup }
  let(:x_test) { x[test_ids, true].dup }
  let(:y_train) { y[train_ids].dup }
  let(:y_test) { y[test_ids].dup }
  let(:classifier) { Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1) }

  context 'when n_components is not given' do
    it 'projects data into subspace', :aggregate_failures do
      expect(z.class).to eq(Numo::DFloat)
      expect(z.ndim).to eq(2)
      expect(z.shape[0]).to eq(n_samples)
      expect(z.shape[1]).to eq(n_classes - 1)
      expect(transformer.components.class).to eq(Numo::DFloat)
      expect(transformer.components.ndim).to eq(2)
      expect(transformer.components.shape[0]).to eq(n_classes - 1)
      expect(transformer.components.shape[1]).to eq(n_features)
      expect(transformer.mean.class).to eq(Numo::DFloat)
      expect(transformer.mean.ndim).to eq(1)
      expect(transformer.mean.shape[0]).to eq(n_features)
      expect(transformer.class_means.class).to eq(Numo::DFloat)
      expect(transformer.class_means.ndim).to eq(2)
      expect(transformer.class_means.shape[0]).to eq(n_classes)
      expect(transformer.class_means.shape[1]).to eq(n_features)
      expect(transformer.classes.class).to eq(Numo::Int32)
      expect(transformer.classes.ndim).to eq(1)
      expect(transformer.classes.shape[0]).to eq(n_classes)
    end

    it 'projects data into a highly discriminating subspace', :aggregate_failures do
      z_train = transformer.fit_transform(x_train, y_train)
      z_test = transformer.transform(x_test)
      classifier.fit(z_train, y_train)
      expect(classifier.score(z_test, y_test)).to be_within(0.05).of(1.0)
    end
  end

  context 'when subspace dimensionality is one' do
    let(:n_components) { 1 }

    it 'projects data into one-dimensional subspace.', :aggregate_failures do
      expect(z.class).to eq(Numo::DFloat)
      expect(z.ndim).to eq(1)
      expect(z.shape[0]).to eq(n_samples)
      expect(transformer.components.class).to eq(Numo::DFloat)
      expect(transformer.components.ndim).to eq(1)
      expect(transformer.components.shape[0]).to eq(n_features)
      expect(transformer.mean.class).to eq(Numo::DFloat)
      expect(transformer.mean.ndim).to eq(1)
      expect(transformer.mean.shape[0]).to eq(n_features)
      expect(transformer.class_means.class).to eq(Numo::DFloat)
      expect(transformer.class_means.ndim).to eq(2)
      expect(transformer.class_means.shape[0]).to eq(n_classes)
      expect(transformer.class_means.shape[1]).to eq(n_features)
    end
  end

  it 'dumps and restores itself using Marshal module.', :aggregate_failures do
    copied = Marshal.load(Marshal.dump(transformer.fit(x, y)))
    expect(copied.class).to eq(transformer.class)
    expect(copied.params).to eq(transformer.params)
    expect(copied.components).to eq(transformer.components)
    expect(copied.mean).to eq(transformer.mean)
    expect(copied.class_means).to eq(transformer.class_means)
    # compare against the original transformer; comparing copied with itself can never fail.
    expect(copied.classes).to eq(transformer.classes)
  end
end

0 comments on commit 6ecd1b8

Please sign in to comment.