-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
RandomForest.h
176 lines (152 loc) · 6.15 KB
/
RandomForest.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
/*
* Copyright (c) The Shogun Machine Learning Toolbox
* Written (w) 2014 Parijat Mazumdar
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* The views and conclusions contained in the software and documentation are those
* of the authors and should not be interpreted as representing official policies,
* either expressed or implied, of the Shogun Development Team.
*/
#ifndef _RANDOMFOREST_H__
#define _RANDOMFOREST_H__
#include <shogun/lib/config.h>
#include <shogun/machine/BaggingMachine.h>
namespace shogun
{
/** @brief This class implements the Random Forests algorithm. In Random Forests algorithm, we train a number of randomized CART trees
* (see class CRandomCARTree) using the supplied training data. The number of trees to be trained is a parameter (called number of bags)
* controlled by the user. Test feature vectors are classified/regressed by combining the outputs of all these trained candidate trees using a
* combination rule (see class CombinationRule). The feature for calculating out-of-box error is also provided to help determine the
* appropriate number of bags. The evaluatin criteria for calculating this out-of-box error is specified by the user (see class CEvaluation).
*/
class RandomForest : public BaggingMachine
{
public:
/** constructor */
RandomForest();
/** constructor
*
* @param num_rand_feats number of attributes chosen randomly during node split in candidate trees
* @param num_bags number of trees in forest
*/
RandomForest(int32_t num_rand_feats, int32_t num_bags=10);
/** constructor
*
* @param features training features
* @param labels training labels
* @param num_bags number of trees in forest
* @param num_rand_feats number of attributes chosen randomly during node split in candidate trees
*/
RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, int32_t num_bags=10, int32_t num_rand_feats=0);
/** constructor
*
* @param features training features
* @param labels training labels
* @param weights weights of training feature vectors
* @param num_bags number of trees in forest
* @param num_rand_feats number of attributes chosen randomly during node split in candidate trees
*/
RandomForest(std::shared_ptr<Features> features, std::shared_ptr<Labels> labels, SGVector<float64_t> weights, int32_t num_bags=10, int32_t num_rand_feats=0);
/** destructor */
virtual ~RandomForest();
/** get name
*
* @return RandomForest
*/
virtual const char* get_name() const { return "RandomForest"; }
/** machine is set to modified CART(RandomCART) and cannot be changed
*
* @param machine the machine to use for bagging
*/
virtual void set_machine(std::shared_ptr<Machine> machine);
/** set weights
*
* @param weights of training feature vectors
*/
void set_weights(SGVector<float64_t> weights);
/** get weights
*
* @return weights of training feature vectors
*/
SGVector<float64_t> get_weights() const;
/** set feature types of various features
*
* @param ft bool vector true for nominal feature false for continuous feature type
*/
void set_feature_types(SGVector<bool> ft);
/** get feature types of various features
*
* @return bool vector - true for nominal feature false for continuous feature type
*/
SGVector<bool> get_feature_types() const;
/** get problem type - multiclass classification or regression
*
* @return PT_MULTICLASS or PT_REGRESSION
*/
virtual EProblemType get_machine_problem_type() const;
/** set problem type - multiclass classification or regression
*
* @param mode EProblemType PT_MULTICLASS or PT_REGRESSION
*/
void set_machine_problem_type(EProblemType mode);
/** set number of random features to be chosen during node splits
*
* @param rand_featsize number of randomly chosen features during each node split
*/
void set_num_random_features(int32_t rand_featsize);
/** get number of random features to be chosen during node splits
*
* @return number of randomly chosen features during each node split
*/
int32_t get_num_random_features() const;
/** get feature importances of previous trained, use Mean Decrease
* Impurity(MDI)
*
* @return arrays of feature importance
*/
SGVector<float64_t> get_feature_importances() const;
protected:
virtual bool train_machine(std::shared_ptr<Features> data=NULL);
/** sets parameters of CARTree - sets machine labels and weights here
*
* @param m machine
* @param idx indices of training vectors chosen in current bag
*/
virtual void set_machine_parameters(std::shared_ptr<Machine> m, SGVector<index_t> idx);
private:
/** initialize parameters */
void init();
private:
/** weights */
SGVector<float64_t> m_weights;
/** Pre-sorted features */
SGMatrix<float64_t> m_sorted_transposed_feats;
/** Indices of pre-sorted features */
SGMatrix<index_t> m_sorted_indices;
#ifndef SWIG
public:
static constexpr std::string_view kWeights = "weights";
#endif
};
} /* namespace shogun */
#endif /* _RANDOMFOREST_H__ */