Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3326 from OXPHOS/cookbook_mclda
cookbook - multiclass lda
- Loading branch information
Showing
3 changed files
with
76 additions
and
1 deletion.
There are no files selected for viewing
Submodule data
updated
1 files
+16 −0 | testsuite/meta/multiclass_classifier/linear_discriminant_analysis.dat |
41 changes: 41 additions & 0 deletions
41
...cookbook/source/examples/multiclass_classifier/linear_discriminant_analysis.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
============================ | ||
Linear Discriminant Analysis | ||
============================ | ||
|
||
This cookbook page introduces the application of | ||
`linear discriminant analysis <http://shogun.ml/cookbook/latest/examples/binary_classifier/lda.html>`_ | ||
to multi-class classifications. | ||
|
||
------- | ||
Example | ||
------- | ||
|
||
Imagine we have files with training and test data. We create CDenseFeatures (here 64 bit floats aka RealFeatures) and :sgclass:`CMulticlassLabels` as | ||
|
||
.. sgexample:: linear_discriminant_analysis.sg:create_features | ||
|
||
We create an instance of the :sgclass:`CMCLDA` classifier with feature matrix and label list. | ||
:sgclass:`CMCLDA` also has two default parameters, to set tolerance used in training and mark whether to store the within class covariances. | ||
|
||
.. sgexample:: linear_discriminant_analysis.sg:create_instance | ||
|
||
Then we train and apply it to the test data, which here gives :sgclass:`CMulticlassLabels`. | ||
|
||
.. sgexample:: linear_discriminant_analysis.sg:train_and_apply | ||
|
||
We can extract the mean vector of one class. | ||
If we enabled storing covariance when creating instances, we can also extract the covariance matrix: | ||
|
||
.. sgexample:: linear_discriminant_analysis.sg:extract_mean_and_cov | ||
|
||
We can evaluate test performance via e.g. :sgclass:`CMulticlassAccuracy`. | ||
|
||
.. sgexample:: linear_discriminant_analysis.sg:evaluate_accuracy | ||
|
||
---------- | ||
References | ||
---------- | ||
|
||
:wiki:`Linear_discriminant_analysis` | ||
|
||
:wiki:`Linear_discriminant_analysis#Multiclass_LDA` |
34 changes: 34 additions & 0 deletions
34
examples/meta/src/multiclass_classifier/linear_discriminant_analysis.sg
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
CSVFile f_feats_train("../../data/classifier_4class_2d_linear_features_train.dat") | ||
CSVFile f_feats_test("../../data/classifier_4class_2d_linear_features_test.dat") | ||
CSVFile f_labels_train("../../data/classifier_4class_2d_linear_labels_train.dat") | ||
CSVFile f_labels_test("../../data/classifier_4class_2d_linear_labels_test.dat") | ||
|
||
#![create_features] | ||
RealFeatures features_train(f_feats_train) | ||
RealFeatures features_test(f_feats_test) | ||
MulticlassLabels labels_train(f_labels_train) | ||
MulticlassLabels labels_test(f_labels_test) | ||
#![create_features] | ||
|
||
#![create_instance] | ||
MCLDA mc_lda(features_train, labels_train, 0.0001, True) | ||
#![create_instance] | ||
|
||
#![train_and_apply] | ||
mc_lda.train() | ||
MulticlassLabels labels_predict = mc_lda.apply_multiclass(features_test) | ||
#![train_and_apply] | ||
|
||
#![extract_mean_and_cov] | ||
int classlabel = 1 | ||
RealVector m = mc_lda.get_mean(classlabel) | ||
RealMatrix c = mc_lda.get_cov() | ||
#![extract_mean_and_cov] | ||
|
||
#![evaluate_accuracy] | ||
MulticlassAccuracy evals() | ||
real accuracy = evals.evaluate(labels_predict, labels_test) | ||
#![evaluate_accuracy] | ||
|
||
# additional integration testing variables | ||
RealVector output = labels_predict.get_labels() |