-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ Add transformer class with Fisher Discriminant Analysis
- Loading branch information
Showing
3 changed files
with
202 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
lib/rumale/metric_learning/fisher_discriminant_analysis.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# frozen_string_literal: true | ||
|
||
require 'rumale/base/base_estimator' | ||
require 'rumale/base/transformer' | ||
|
||
module Rumale
  # Module for metric learning algorithms.
  module MetricLearning
    # FisherDiscriminantAnalysis is a class that implements Fisher Discriminant Analysis.
    #
    # @example
    #   transformer = Rumale::MetricLearning::FisherDiscriminantAnalysis.new
    #   transformer.fit(training_samples, training_labels)
    #   low_samples = transformer.transform(testing_samples)
    #
    # *Reference*
    # - Fisher, R. A., "The use of multiple measurements in taxonomic problems," Annals of Eugenics, vol. 7, pp. 179--188, 1936.
    # - Sugiyama, M., "Local Fisher Discriminant Analysis for Supervised Dimensionality Reduction," Proc. ICML'06, pp. 905--912, 2006.
    class FisherDiscriminantAnalysis
      include Base::BaseEstimator
      include Base::Transformer

      # Returns the transform matrix whose rows are the learned discriminant vectors.
      # @return [Numo::DFloat] (shape: [n_components, n_features], or [n_features] when n_components is 1)
      attr_reader :components

      # Returns the mean vector.
      # @return [Numo::DFloat] (shape: [n_features])
      attr_reader :mean

      # Returns the class mean vectors.
      # @return [Numo::DFloat] (shape: [n_classes, n_features])
      attr_reader :class_means

      # Return the class labels.
      # @return [Numo::Int32] (shape: [n_classes])
      attr_reader :classes

      # Create a new transformer with FisherDiscriminantAnalysis.
      #
      # @param n_components [Integer] The number of components.
      #   If nil is given, the number of components will be set to [n_features, n_classes - 1].min
      def initialize(n_components: nil)
        check_params_numeric_or_nil(n_components: n_components)
        @params = {}
        @params[:n_components] = n_components
      end

      # Fit the model with given training data.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @raise [RuntimeError] If Numo::Linalg is not loaded.
      # @return [FisherDiscriminantAnalysis] The learned transformer itself.
      def fit(x, y)
        x = check_convert_sample_array(x)
        y = check_convert_label_array(y)
        check_sample_label_size(x, y)
        raise 'FisherDiscriminantAnalysis#fit requires Numo::Linalg but that is not loaded.' unless enable_linalg?

        # initialize some variables.
        n_features = x.shape[1]
        @classes = Numo::Int32[*y.to_a.uniq.sort]
        n_classes = @classes.size
        # A requested number of components is capped at the number of features.
        n_components = if @params[:n_components].nil?
                         [n_features, n_classes - 1].min
                       else
                         [n_features, @params[:n_components]].min
                       end

        # calculate within and between scatter matrices.
        within_mat = Numo::DFloat.zeros(n_features, n_features)
        between_mat = Numo::DFloat.zeros(n_features, n_features)
        @class_means = Numo::DFloat.zeros(n_classes, n_features)
        @mean = x.mean(0)
        @classes.each_with_index do |label, i|
          mask_vec = y.eq(label)
          sz_class = mask_vec.count
          class_samples = x[mask_vec, true]
          class_mean = class_samples.mean(0)
          centered = class_samples - class_mean
          mean_diff = class_mean - @mean
          within_mat += centered.transpose.dot(centered)
          # Multiplying a column vector by a row vector broadcasts to their outer product.
          between_mat += sz_class * mean_diff.expand_dims(1) * mean_diff
          @class_means[i, true] = class_mean
        end

        # calculate components.
        # Solve the generalized eigenproblem of the between/within scatter matrices, keeping only
        # the last n_components eigenpairs, then reverse the column order of the eigenvectors.
        # NOTE(review): this relies on Numo::Linalg.eigh ordering eigenvalues ascending, so that the
        # most discriminative direction ends up first - confirm against the Numo::Linalg docs.
        _, evecs = Numo::Linalg.eigh(between_mat, within_mat, vals_range: (n_features - n_components)...n_features)
        comps = evecs.reverse(1).transpose.dup
        # A single component is stored as a 1-D vector rather than a 1 x n_features matrix.
        @components = n_components == 1 ? comps[0, true].dup : comps.dup
        self
      end

      # Fit the model with training data, and then transform them with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data to be used for fitting the model.
      # @param y [Numo::Int32] (shape: [n_samples]) The labels to be used for fitting the model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data
      def fit_transform(x, y)
        x = check_convert_sample_array(x)
        fit(x, y).transform(x)
      end

      # Transform the given data with the learned model.
      #
      # @param x [Numo::DFloat] (shape: [n_samples, n_features]) The data to be transformed with the learned model.
      # @return [Numo::DFloat] (shape: [n_samples, n_components]) The transformed data.
      def transform(x)
        x = check_convert_sample_array(x)
        x.dot(@components.transpose)
      end
    end
  end
end
90 changes: 90 additions & 0 deletions
90
spec/rumale/metric_learning/fisher_discriminant_analysis_spec.rb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# frozen_string_literal: true | ||
|
||
require 'spec_helper' | ||
|
||
RSpec.describe Rumale::MetricLearning::FisherDiscriminantAnalysis do
  let(:dataset) { three_clusters_dataset }
  let(:x) do
    # Append a high-variance random feature to the samples; per the original note, this makes the
    # classes overlap when the data is reduced with PCA.
    Numo::DFloat.hstack([dataset[0], 10 * Rumale::Utils.rand_normal([dataset[0].shape[0], 1], Random.new(1))])
  end
  let(:y) { dataset[1] }
  let(:classes) { y.to_a.uniq.sort }
  let(:n_samples) { x.shape[0] }
  let(:n_features) { x.shape[1] }
  let(:n_classes) { classes.size }
  let(:n_components) { nil }

  let(:transformer) { described_class.new(n_components: n_components) }
  let(:z) { transformer.fit_transform(x, y) }

  # Hold out 10% of the samples to check that the learned projection keeps classes separable.
  let(:splitter) { Rumale::ModelSelection::ShuffleSplit.new(n_splits: 1, test_size: 0.1, train_size: 0.9, random_seed: 1) }
  let(:validation_ids) { splitter.split(x, y).first }
  let(:train_ids) { validation_ids[0] }
  let(:test_ids) { validation_ids[1] }
  let(:x_train) { x[train_ids, true].dup }
  let(:x_test) { x[test_ids, true].dup }
  let(:y_train) { y[train_ids].dup }
  let(:y_test) { y[test_ids].dup }
  let(:classifier) { Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 1) }

  context 'when n_components is not given' do
    it 'projects data into subspace', :aggregate_failures do
      expect(z.class).to eq(Numo::DFloat)
      expect(z.ndim).to eq(2)
      expect(z.shape[0]).to eq(n_samples)
      expect(z.shape[1]).to eq(n_classes - 1)
      expect(transformer.components.class).to eq(Numo::DFloat)
      expect(transformer.components.ndim).to eq(2)
      expect(transformer.components.shape[0]).to eq(n_classes - 1)
      expect(transformer.components.shape[1]).to eq(n_features)
      expect(transformer.mean.class).to eq(Numo::DFloat)
      expect(transformer.mean.ndim).to eq(1)
      expect(transformer.mean.shape[0]).to eq(n_features)
      expect(transformer.class_means.class).to eq(Numo::DFloat)
      expect(transformer.class_means.ndim).to eq(2)
      expect(transformer.class_means.shape[0]).to eq(n_classes)
      expect(transformer.class_means.shape[1]).to eq(n_features)
      expect(transformer.classes.class).to eq(Numo::Int32)
      expect(transformer.classes.ndim).to eq(1)
      expect(transformer.classes.shape[0]).to eq(n_classes)
    end

    it 'projects data into a higly discriminating subspace', :aggregate_failures do
      z_train = transformer.fit_transform(x_train, y_train)
      z_test = transformer.transform(x_test)
      classifier.fit(z_train, y_train)
      expect(classifier.score(z_test, y_test)).to be_within(0.05).of(1.0)
    end
  end

  context 'when subspace dimensionality is one' do
    let(:n_components) { 1 }

    it 'projects data into one-dimensional subspace.', :aggregate_failures do
      expect(z.class).to eq(Numo::DFloat)
      expect(z.ndim).to eq(1)
      expect(z.shape[0]).to eq(n_samples)
      expect(transformer.components.class).to eq(Numo::DFloat)
      expect(transformer.components.ndim).to eq(1)
      expect(transformer.components.shape[0]).to eq(n_features)
      expect(transformer.mean.class).to eq(Numo::DFloat)
      expect(transformer.mean.ndim).to eq(1)
      expect(transformer.mean.shape[0]).to eq(n_features)
      expect(transformer.class_means.class).to eq(Numo::DFloat)
      expect(transformer.class_means.ndim).to eq(2)
      expect(transformer.class_means.shape[0]).to eq(n_classes)
      expect(transformer.class_means.shape[1]).to eq(n_features)
    end
  end

  it 'dumps and restores itself using Marshal module.', :aggregate_failures do
    copied = Marshal.load(Marshal.dump(transformer.fit(x, y)))
    expect(copied.class).to eq(transformer.class)
    expect(copied.params).to eq(transformer.params)
    expect(copied.components).to eq(transformer.components)
    expect(copied.mean).to eq(transformer.mean)
    expect(copied.class_means).to eq(transformer.class_means)
    # Compare against the original transformer; comparing the copy with itself can never fail.
    expect(copied.classes).to eq(transformer.classes)
  end
end