Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3301 from Saurabh7/rfcookbook
rf cookbook
- Loading branch information
Showing
3 changed files
with
87 additions
and
0 deletions.
There are no files selected for viewing
41 changes: 41 additions & 0 deletions
41
doc/cookbook/source/examples/multiclass_classifier/random_forest.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
============= | ||
Random Forest | ||
============= | ||
|
||
A Random Forest is an ensemble learning method which implements multiple decision trees during training. It predicts by using a combination rule on the outputs of individual decision trees. | ||
|
||
See :cite:`Breiman2001` for a detailed introduction. | ||
|
||
------- | ||
Example | ||
------- | ||
|
||
CDenseFeatures (here 64 bit floats aka RealFeatures) and :sgclass:`CMulticlassLabels` are created from training and test data file | ||
|
||
.. sgexample:: random_forest.sg:create_features | ||
|
||
Combination rules to be used for prediction are derived form the :sgclass:`CCombinationRule` class. Here we create a :sgclass:`CMajorityVote` class to be used as a combination rule. | ||
|
||
.. sgexample:: random_forest.sg:create_combination_rule | ||
|
||
Next an instance of :sgclass:`CRandomForest` is created. The parameters provided are the number of attributes to be chosen randomly to select from and the number of trees. | ||
|
||
.. sgexample:: random_forest.sg:create_instance | ||
|
||
Then we run the train random forest and apply it to test data, which here gives :sgclass:`CMulticlassLabels`. | ||
|
||
.. sgexample:: random_forest.sg:train_and_apply | ||
|
||
We can evaluate test performance via e.g. :sgclass:`CMulticlassAccuracy` as well as get the "out of bag error". | ||
|
||
.. sgexample:: random_forest.sg:evaluate_accuracy | ||
|
||
---------- | ||
References | ||
---------- | ||
:wiki:`Random_forest` | ||
|
||
:wiki:`Out-of-bag_error` | ||
|
||
.. bibliography:: ../../references.bib | ||
:filter: docname in docnames |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
CSVFile f_feats_train("../../data/classifier_4class_2d_linear_features_train.dat") | ||
CSVFile f_feats_test("../../data/classifier_4class_2d_linear_features_test.dat") | ||
CSVFile f_labels_train("../../data/classifier_4class_2d_linear_labels_train.dat") | ||
CSVFile f_labels_test("../../data/classifier_4class_2d_linear_labels_test.dat") | ||
Math:init_random(1) | ||
|
||
#![create_features] | ||
RealFeatures features_train(f_feats_train) | ||
RealFeatures features_test(f_feats_test) | ||
MulticlassLabels labels_train(f_labels_train) | ||
MulticlassLabels labels_test(f_labels_test) | ||
#![create_features] | ||
|
||
#![create_combination_rule] | ||
MajorityVote m_vote() | ||
#![create_combination_rule] | ||
|
||
#![create_instance] | ||
RandomForest rand_forest(1,10) | ||
rand_forest.set_combination_rule(m_vote) | ||
rand_forest.set_labels(labels_train) | ||
#![create_instance] | ||
|
||
#![train_and_apply] | ||
rand_forest.train(features_train) | ||
MulticlassLabels labels_predict = rand_forest.apply_multiclass(features_test) | ||
#![train_and_apply] | ||
|
||
#![evaluate_accuracy] | ||
MulticlassAccuracy acc() | ||
real oob = rand_forest.get_oob_error(acc) | ||
real accuracy = acc.evaluate(labels_predict, labels_test) | ||
#![evaluate_accuracy] | ||
|
||
# additional integration testing variables | ||
RealVector output = labels_predict.get_labels() | ||
|