From 22194405e445121633ee82b1d9d4ab19e4bbc1ec Mon Sep 17 00:00:00 2001 From: Saurabh7 Date: Thu, 16 Jun 2016 21:26:51 +0530 Subject: [PATCH] update CART and random Forest --- src/shogun/machine/RandomForest.cpp | 18 +- src/shogun/machine/RandomForest.h | 7 + src/shogun/multiclass/tree/CARTree.cpp | 238 +++++++++++++---- src/shogun/multiclass/tree/CARTree.h | 27 +- src/shogun/multiclass/tree/RandomCARTree.cpp | 243 ++---------------- src/shogun/multiclass/tree/RandomCARTree.h | 6 +- .../multiclass/tree/RandomForest_unittest.cc | 93 ++++++- 7 files changed, 333 insertions(+), 299 deletions(-) diff --git a/src/shogun/machine/RandomForest.cpp b/src/shogun/machine/RandomForest.cpp index 6fc4798bf74..3cbe9cae426 100644 --- a/src/shogun/machine/RandomForest.cpp +++ b/src/shogun/machine/RandomForest.cpp @@ -159,11 +159,27 @@ void CRandomForest::set_machine_parameters(CMachine* m, SGVector idx) } tree->set_weights(weights); - + tree->set_sorted_features(m_sorted_transposed_feats, m_sorted_indices); // equate the machine problem types - cloning does not do this tree->set_machine_problem_type(dynamic_cast(m_machine)->get_machine_problem_type()); } +bool CRandomForest::train_machine(CFeatures* data) +{ + if (data) + { + SG_REF(data); + SG_UNREF(m_features); + m_features = data; + } + + REQUIRE(m_features, "Training features not set!\n"); + + dynamic_cast(m_machine)->pre_sort_features(m_features, m_sorted_transposed_feats, m_sorted_indices); + + return CBaggingMachine::train_machine(); +} + void CRandomForest::init() { m_machine=new CRandomCARTree(); diff --git a/src/shogun/machine/RandomForest.h b/src/shogun/machine/RandomForest.h index 2c5bb2f6549..405d317c462 100644 --- a/src/shogun/machine/RandomForest.h +++ b/src/shogun/machine/RandomForest.h @@ -139,6 +139,8 @@ class CRandomForest : public CBaggingMachine int32_t get_num_random_features() const; protected: + + virtual bool train_machine(CFeatures* data=NULL); /** sets parameters of CARTree - sets machine labels and weights here * * @param m machine @@ -154,6 +156,11 @@ class CRandomForest : public CBaggingMachine /** weights */ SGVector m_weights; + /** Pre-sorted features */ + SGMatrix m_sorted_transposed_feats; + + /** Indices of pre-sorted features */ + SGMatrix m_sorted_indices; }; } /* namespace shogun */ #endif /* _RANDOMFOREST_H__ */ diff --git a/src/shogun/multiclass/tree/CARTree.cpp b/src/shogun/multiclass/tree/CARTree.cpp index deeea3fed0d..1b0f62361fb 100644 --- a/src/shogun/multiclass/tree/CARTree.cpp +++ b/src/shogun/multiclass/tree/CARTree.cpp @@ -30,7 +30,10 @@ #include #include +#include +#include +using namespace Eigen; using namespace shogun; const float64_t CCARTree::MISSING=CMath::MAX_REAL_NUMBER; @@ -102,6 +105,8 @@ CMulticlassLabels* CCARTree::apply_multiclass(CFeatures* data) // apply multiclass starting from root bnode_t* current=dynamic_cast(get_root()); + + REQUIRE(current, "Tree machine not yet trained.\n"); CLabels* ret=apply_from_current_node(dynamic_cast*>(data), current); SG_UNREF(current); @@ -282,6 +287,33 @@ bool CCARTree::train_machine(CFeatures* data) return true; } +void CCARTree::set_sorted_features(SGMatrix& sorted_feats, SGMatrix& sorted_indices) +{ + m_pre_sort=true; + m_sorted_features=sorted_feats; + m_sorted_indices=sorted_indices; +} + +void CCARTree::pre_sort_features(CFeatures* data, SGMatrix& sorted_feats, SGMatrix& sorted_indices) +{ + SGMatrix mat=(dynamic_cast*>(data))->get_feature_matrix(); + sorted_feats = SGMatrix(mat.num_cols, mat.num_rows); + sorted_indices = SGMatrix(mat.num_cols, mat.num_rows); + for(int32_t i=0; i map_sorted_feats(sorted_feats.matrix, mat.num_cols, mat.num_rows); + Map map_data(mat.matrix, mat.num_rows, mat.num_cols); + + map_sorted_feats=map_data.transpose(); + + #pragma omp parallel for + for(int32_t i=0; i* CCARTree::CARTtrain(CFeatures* data, SGVector weights, CLabels* labels, int32_t level) { REQUIRE(labels,"labels have to be supplied\n"); @@ -381,8 +413,21 @@ CBinaryTreeMachineNode* CCARTree::CARTtrain(CFeatures* data, SG int32_t num_missing_final=0; int32_t c_left=-1; int32_t c_right=-1; - - int32_t best_attribute=compute_best_attribute(mat,weights,labels_vec,left,right,left_final,num_missing_final,c_left,c_right); + int32_t best_attribute; + + SGVector indices(num_vecs); + if (m_pre_sort) + { + CSubsetStack* subset_stack = data->get_subset_stack(); + if (subset_stack->has_subsets()) + indices=(subset_stack->get_last_subset())->get_subset_idx(); + else + indices.range_fill(); + SG_UNREF(subset_stack); + best_attribute=compute_best_attribute(m_sorted_features,weights,labels,left,right,left_final,num_missing_final,c_left,c_right,0,indices); + } + else + best_attribute=compute_best_attribute(mat,weights,labels,left,right,left_final,num_missing_final,c_left,c_right); if (best_attribute==-1) { @@ -483,12 +528,17 @@ SGVector CCARTree::get_unique_labels(SGVector labels_vec, return ulabels; } -int32_t CCARTree::compute_best_attribute(SGMatrix mat, SGVector weights, SGVector labels_vec, - SGVector left, SGVector right, SGVector is_left_final, int32_t &num_missing_final, int32_t &count_left, - int32_t &count_right) +int32_t CCARTree::compute_best_attribute(const SGMatrix& mat, const SGVector& weights, CLabels* labels, + SGVector& left, SGVector& right, SGVector& is_left_final, int32_t &num_missing_final, int32_t &count_left, + int32_t &count_right, int32_t subset_size, const SGVector& active_indices) { - int32_t num_vecs=mat.num_cols; - int32_t num_feats=mat.num_rows; + SGVector labels_vec=(dynamic_cast(labels))->get_labels(); + int32_t num_vecs=labels->get_num_labels(); + int32_t num_feats; + if (m_pre_sort) + num_feats=mat.num_cols; + else + num_feats=mat.num_rows; int32_t n_ulabels; SGVector ulabels=get_unique_labels(labels_vec,n_ulabels); @@ -517,56 +567,112 @@ int32_t CCARTree::compute_best_attribute(SGMatrix mat, SGVector idx(num_feats); + idx.range_fill(); + if (subset_size) + { + num_feats=subset_size; + CMath::permute(idx); + } float64_t max_gain=MIN_SPLIT_GAIN; int32_t best_attribute=-1; float64_t best_threshold=0; + + SGVector indices_mask; + SGVector count_indices(mat.num_rows); + count_indices.zero(); + SGVector dupes(num_vecs); + dupes.range_fill(); + if (m_pre_sort) + { + indices_mask = SGVector(mat.num_rows); + indices_mask.set_const(-1); + for(int32_t j=0;j=0) + dupes[indices_mask[active_indices[j]]]=j; + + indices_mask[active_indices[j]]=j; + count_indices[active_indices[j]]++; + } + } + for (int32_t i=0;i feats(num_vecs); - for (int32_t j=0;j sorted_args(num_vecs); + SGVector temp_count_indices(count_indices.size()); + memcpy(temp_count_indices.vector, count_indices.vector, sizeof(int32_t)*count_indices.size()); - // O(N*logN) - SGVector sorted_args=CMath::argsort(feats); + if (m_pre_sort) + { + SGVector temp_col(mat.get_column_vector(idx[i]), mat.num_rows, false); + SGVector sorted_indices(m_sorted_indices.get_column_vector(idx[i]), mat.num_rows, false); + int32_t count=0; + for(int32_t j=0;j=0) + { + while(temp_count_indices[sorted_indices[j]]>0) + { + feats[count]=temp_col[j]; + sorted_args[count]=indices_mask[sorted_indices[j]]; + ++count; + --temp_count_indices[sorted_indices[j]]; + } + if (count==num_vecs) + break; + } + } + } + else + { + for (int32_t j=0;j simple_feats(num_vecs); simple_feats.fill_vector(simple_feats.vector,simple_feats.vlen,-1); // convert to simple values - simple_feats[sorted_args[0]]=0; + simple_feats[0]=0; int32_t c=0; for (int32_t j=1;j ufeats(c+1); - ufeats[0]=feats[sorted_args[0]]; + ufeats[0]=feats[0]; int32_t u=0; for (int32_t j=1;j mat, SGVector=0;j--) + { + if(dupes[j]!=j) + is_left[j]=is_left[dupes[j]]; + } + float64_t g=0; if (m_mode==PT_MULTICLASS) g=gain(wleft,wright,total_wclasses); @@ -609,7 +720,7 @@ int32_t CCARTree::compute_best_attribute(SGMatrix mat, SGVectormax_gain) { - best_attribute=i; + best_attribute=idx[i]; max_gain=g; memcpy(is_left_final.vector,is_left.vector,is_left.vlen*sizeof(bool)); num_missing_final=num_vecs-n_nm_vecs; @@ -641,18 +752,17 @@ int32_t CCARTree::compute_best_attribute(SGMatrix mat, SGVector mat, SGVectormax_gain) { max_gain=g; - best_attribute=i; + best_attribute=idx[i]; best_threshold=z; num_missing_final=num_vecs-n_nm_vecs; } - z=feats[sorted_args[j]]; - if (feats[sorted_args[n_nm_vecs-1]]<=z+EQ_DELTA) + z=feats[j]; + if (feats[n_nm_vecs-1]<=z+EQ_DELTA) break; - right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]]; left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]]; } @@ -696,8 +805,33 @@ int32_t CCARTree::compute_best_attribute(SGMatrix mat, SGVector temp_vec(mat.get_column_vector(best_attribute), mat.num_rows, false); + SGVector sorted_indices(m_sorted_indices.get_column_vector(best_attribute), mat.num_rows, false); + int32_t count=0; + for(int32_t i=0;i=0) + { + is_left_final[indices_mask[sorted_indices[i]]]=(temp_vec[i]<=best_threshold); + ++count; + if (count==num_vecs) + break; + } + } + for (int32_t i=num_vecs-1;i>=0;i--) + { + if(dupes[i]!=i) + is_left_final[i]=is_left_final[dupes[i]]; + } + + } + else + { + for (int32_t i=0;i wleft, SGVector wright, return lsd_n-(lsd_l*(total_lweight/total_weight))-(lsd_r*(total_rweight/total_weight)); } -float64_t CCARTree::gain(SGVector wleft, SGVector wright, SGVector wtotal) +float64_t CCARTree::gain(const SGVector& wleft, const SGVector& wright, const SGVector& wtotal) { float64_t total_lweight=0; float64_t total_rweight=0; @@ -941,29 +1075,23 @@ float64_t CCARTree::gain(SGVector wleft, SGVector wright, return gini_n-(gini_l*(total_lweight/total_weight))-(gini_r*(total_rweight/total_weight)); } -float64_t CCARTree::gini_impurity_index(SGVector weighted_lab_classes, float64_t &total_weight) +float64_t CCARTree::gini_impurity_index(const SGVector& weighted_lab_classes, float64_t &total_weight) { - total_weight=0; - float64_t gini=0; - for (int32_t i=0;i map_weighted_lab_classes(weighted_lab_classes.vector, weighted_lab_classes.size()); + total_weight=map_weighted_lab_classes.sum(); + float64_t gini=map_weighted_lab_classes.dot(map_weighted_lab_classes); gini=1.0-(gini/(total_weight*total_weight)); return gini; } -float64_t CCARTree::least_squares_deviation(SGVector feats, SGVector weights, float64_t &total_weight) +float64_t CCARTree::least_squares_deviation(const SGVector& feats, const SGVector& weights, float64_t &total_weight) { - float64_t mean=0; - total_weight=0; - for (int32_t i=0;i map_weights(weights.vector, weights.size()); + Map map_feats(feats.vector, weights.size()); + float64_t mean=map_weights.dot(map_feats); + total_weight=map_weights.sum(); mean/=total_weight; float64_t dev=0; @@ -976,6 +1104,8 @@ float64_t CCARTree::least_squares_deviation(SGVector feats, SGVector< CLabels* CCARTree::apply_from_current_node(CDenseFeatures* feats, bnode_t* current) { int32_t num_vecs=feats->get_num_vectors(); + REQUIRE(num_vecs>0, "No data provided in apply\n"); + SGVector labels(num_vecs); for (int32_t i=0;i* tree) { + REQUIRE(tree, "Tree not provided for pruning.\n"); + CDynamicObjectArray* trees=new CDynamicObjectArray(); SG_UNREF(m_alphas); m_alphas=new CDynamicArray(); @@ -1363,6 +1495,7 @@ void CCARTree::init() m_nominal=SGVector(); m_weights=SGVector(); m_mode=PT_MULTICLASS; + m_pre_sort=false; m_types_set=false; m_weights_set=false; m_apply_cv_pruning=false; @@ -1373,6 +1506,9 @@ void CCARTree::init() m_min_node_size=0; m_label_epsilon=1e-7; + SG_ADD(&m_pre_sort,"m_pre_sort","presort", MS_NOT_AVAILABLE); + SG_ADD(&m_sorted_features,"m_sorted_features", "sorted feats", MS_NOT_AVAILABLE); + SG_ADD(&m_sorted_indices,"m_sorted_indices", "sorted indices", MS_NOT_AVAILABLE); SG_ADD(&m_nominal,"m_nominal", "feature types", MS_NOT_AVAILABLE); SG_ADD(&m_weights,"m_weights", "weights", MS_NOT_AVAILABLE); SG_ADD(&m_weights_set,"m_weights_set", "weights set", MS_NOT_AVAILABLE); diff --git a/src/shogun/multiclass/tree/CARTree.h b/src/shogun/multiclass/tree/CARTree.h index 7426c34ef17..20336393392 100644 --- a/src/shogun/multiclass/tree/CARTree.h +++ b/src/shogun/multiclass/tree/CARTree.h @@ -225,8 +225,11 @@ class CCARTree : public CTreeMachine */ void set_label_epsilon(float64_t epsilon); -protected: + void pre_sort_features(CFeatures* data, SGMatrix& sorted_feats, SGMatrix& sorted_indices); + + void set_sorted_features(SGMatrix& sorted_feats, SGMatrix& sorted_indices); +protected: /** train machine - build CART from training data * @param data training data * @return true @@ -264,9 +267,9 @@ class CCARTree : public CTreeMachine * @param count_right stores number of feature values for right transition * @return index to the best attribute */ - virtual int32_t compute_best_attribute(SGMatrix mat, SGVector weights, SGVector labels_vec, - SGVector left, SGVector right, SGVector is_left_final, int32_t &num_missing, - int32_t &count_left, int32_t &count_right); + virtual int32_t compute_best_attribute(const SGMatrix& mat, const SGVector& weights, CLabels* labels, + SGVector& left, SGVector& right, SGVector& is_left_final, int32_t &num_missing, + int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector& active_indices=SGVector()); /** handles missing values through surrogate splits @@ -329,7 +332,7 @@ class CCARTree : public CTreeMachine * @param wtotal label distribution in current node * @return Gini gain achieved after spliting the node */ - float64_t gain(SGVector wleft, SGVector wright, SGVector wtotal); + float64_t gain(const SGVector& wleft, const SGVector& wright, const SGVector& wtotal); /** returns Gini impurity of a node * @@ -337,7 +340,7 @@ class CCARTree : public CTreeMachine * @param total_weight stores the total weight of all classes * @return Gini index of the node */ - float64_t gini_impurity_index(SGVector weighted_lab_classes, float64_t &total_weight); + float64_t gini_impurity_index(const SGVector& weighted_lab_classes, float64_t &total_weight); /** returns least squares deviation * @@ -346,7 +349,7 @@ class CCARTree : public CTreeMachine * @param total_weight stores sum of weights in weights vector * @return least squares deviation of the data */ - float64_t least_squares_deviation(SGVector labels, SGVector weights, float64_t &total_weight); + float64_t least_squares_deviation(const SGVector& labels, const SGVector& weights, float64_t &total_weight); /** uses current subtree to classify/regress data * @@ -404,6 +407,7 @@ class CCARTree : public CTreeMachine /** initializes members of class */ void init(); + public: /** denotes that a feature in a vector is missing MISSING = NOT_A_NUMBER */ static const float64_t MISSING; @@ -424,6 +428,15 @@ class CCARTree : public CTreeMachine /** weights of samples in training set **/ SGVector m_weights; + /** sorted transposed features */ + SGMatrix m_sorted_features; + + /** sorted indices */ + SGMatrix m_sorted_indices; + + /** If pre sorted features are used in train */ + bool m_pre_sort; + /** flag storing whether the type of various feature dimensions are specified using is_nominal_feature **/ bool m_types_set; diff --git a/src/shogun/multiclass/tree/RandomCARTree.cpp b/src/shogun/multiclass/tree/RandomCARTree.cpp index 84c4654fc0c..f762e52ef11 100644 --- a/src/shogun/multiclass/tree/RandomCARTree.cpp +++ b/src/shogun/multiclass/tree/RandomCARTree.cpp @@ -49,238 +49,27 @@ void CRandomCARTree::set_feature_subset_size(int32_t size) m_randsubset_size=size; } -int32_t CRandomCARTree::compute_best_attribute(SGMatrix mat, SGVector weights, SGVector labels_vec, - SGVector left, SGVector right, SGVector is_left_final, int32_t &num_missing_final, int32_t &count_left, - int32_t &count_right) -{ - int32_t num_vecs=mat.num_cols; - int32_t num_feats=mat.num_rows; - - int32_t n_ulabels; - SGVector ulabels=get_unique_labels(labels_vec,n_ulabels); - - // if all labels same early stop - if (n_ulabels==1) - return -1; - - SGVector total_wclasses(n_ulabels); - total_wclasses.zero(); - - SGVector simple_labels(num_vecs); - float64_t delta=0; - if (m_mode==PT_REGRESSION) - delta=m_label_epsilon; - - for (int32_t i=0;i& mat, const SGVector& weights, CLabels* labels, + SGVector& left, SGVector& right, SGVector& is_left_final, int32_t &num_missing_final, int32_t &count_left, + int32_t &count_right, int32_t subset_size, const SGVector& active_indices) +{ + int32_t num_feats; + if(m_pre_sort) + num_feats=mat.num_cols; + else + num_feats=mat.num_rows; + // if subset size is not set choose sqrt(num_feats) by default if (m_randsubset_size==0) - m_randsubset_size=CMath::sqrt(num_feats-0.f); - - // randomly choose w/o replacement the attributes from which best will be chosen - // randomly permute and choose 1st randsubset_size elements - SGVector idx(num_feats); - idx.range_fill(); - CMath::permute(idx); - - float64_t max_gain=MIN_SPLIT_GAIN; - int32_t best_attribute=-1; - float64_t best_threshold=0; - for (int32_t i=0;i feats(num_vecs); - for (int32_t j=0;j sorted_args=CMath::argsort(feats); - - // number of non-missing vecs - int32_t n_nm_vecs=feats.vlen; - while (feats[sorted_args[n_nm_vecs-1]]==MISSING) - { - total_wclasses[simple_labels[sorted_args[n_nm_vecs-1]]]-=weights[sorted_args[n_nm_vecs-1]]; - n_nm_vecs--; - } - - // if only one unique value - it cannot be used to split - if (feats[sorted_args[n_nm_vecs-1]]<=feats[sorted_args[0]]+EQ_DELTA) - continue; - - if (m_nominal[idx[i]]) - { - SGVector simple_feats(num_vecs); - simple_feats.fill_vector(simple_feats.vector,simple_feats.vlen,-1); - - // convert to simple values - simple_feats[sorted_args[0]]=0; - int32_t c=0; - for (int32_t j=1;j ufeats(c+1); - ufeats[0]=feats[sorted_args[0]]; - int32_t u=0; - for (int32_t j=1;j wleft(n_ulabels); - SGVector wright(n_ulabels); - wleft.zero(); - wright.zero(); - - // stores which vectors are assigned to left child - SGVector is_left(num_vecs); - is_left.fill_vector(is_left.vector,is_left.vlen,false); - - // stores which among the categorical values of chosen attribute are assigned left child - SGVector feats_left(c+1); - - // fill feats_left in a unique way corresponding to the case - for (int32_t p=0;pmax_gain) - { - best_attribute=idx[i]; - max_gain=g; - memcpy(is_left_final.vector,is_left.vector,is_left.vlen*sizeof(bool)); - num_missing_final=num_vecs-n_nm_vecs; - - count_left=0; - for (int32_t l=0;l right_wclasses=total_wclasses.clone(); - SGVector left_wclasses(n_ulabels); - left_wclasses.zero(); - - // O(N) - // find best split for non-nominal attribute - choose threshold (z) - float64_t z=feats[sorted_args[0]]; - right_wclasses[simple_labels[sorted_args[0]]]-=weights[sorted_args[0]]; - left_wclasses[simple_labels[sorted_args[0]]]+=weights[sorted_args[0]]; - for (int32_t j=1;jmax_gain) - { - max_gain=g; - best_attribute=idx[i]; - best_threshold=z; - num_missing_final=num_vecs-n_nm_vecs; - } - - z=feats[sorted_args[j]]; - if (feats[sorted_args[n_nm_vecs-1]]<=z+EQ_DELTA) - break; - - right_wclasses[simple_labels[sorted_args[j]]]-=weights[sorted_args[j]]; - left_wclasses[simple_labels[sorted_args[j]]]+=weights[sorted_args[j]]; - } - } - - // restore total_wclasses - while (n_nm_vecs mat, SGVector weights, SGVector labels_vec, - SGVector left, SGVector right, SGVector is_left_final, int32_t &num_missing, int32_t &count_left, - int32_t &count_right); + virtual int32_t compute_best_attribute(const SGMatrix& mat, const SGVector& weights, CLabels* labels, + SGVector& left, SGVector& right, SGVector& is_left_final, int32_t &num_missing, + int32_t &count_left, int32_t &count_right, int32_t subset_size=0, const SGVector& active_indices=SGVector()); private: /** initialize parameters */ diff --git a/tests/unit/multiclass/tree/RandomForest_unittest.cc b/tests/unit/multiclass/tree/RandomForest_unittest.cc index 92cda14988e..61fa729c3cd 100644 --- a/tests/unit/multiclass/tree/RandomForest_unittest.cc +++ b/tests/unit/multiclass/tree/RandomForest_unittest.cc @@ -51,12 +51,8 @@ using namespace shogun; #define weak 1. #define strong 2. -TEST(RandomForest,classify_test) +void generate_nm_data(SGMatrix& data, SGVector& lab) { - sg_rand->set_seed(1); - - SGMatrix data(4,14); - //vector = [Outlook Temperature Humidity Wind] data(0,0)=sunny; data(1,0)=hot; @@ -127,11 +123,7 @@ TEST(RandomForest,classify_test) data(1,13)=mild; data(2,13)=high; data(3,13)=strong; - - CDenseFeatures* feats=new CDenseFeatures(data); - - // yes 1. no 0. - SGVector lab(14); + lab[0]=0.0; lab[1]=0.0; lab[2]=1.0; @@ -146,6 +138,18 @@ TEST(RandomForest,classify_test) lab[11]=1.0; lab[12]=1.0; lab[13]=0.0; +} + +TEST(RandomForest,classify_nominal_test) +{ + sg_rand->set_seed(1); + + SGMatrix data(4,14); + SGVector lab(14); + + generate_nm_data(data, lab); + + CDenseFeatures* feats=new CDenseFeatures(data); SGVector ft=SGVector(4); ft[0]=true; @@ -204,3 +208,72 @@ TEST(RandomForest,classify_test) SG_UNREF(c); SG_UNREF(eval); } + +TEST(RandomForest,classify_non_nominal_test) +{ + sg_rand->set_seed(1); + + SGMatrix data(4,14); + SGVector lab(14); + + generate_nm_data(data, lab); + + CDenseFeatures* feats=new CDenseFeatures(data); + + SGVector ft=SGVector(4); + ft[0]=false; + ft[1]=false; + ft[2]=false; + ft[3]=false; + + CMulticlassLabels* labels=new CMulticlassLabels(lab); + + CRandomForest* c=new CRandomForest(feats, labels, 100,2); + c->set_feature_types(ft); + CMajorityVote* mv = new CMajorityVote(); + c->set_combination_rule(mv); + c->train(feats); + + SGMatrix test(4,5); + test(0,0)=overcast; + test(0,1)=rain; + test(0,2)=sunny; + test(0,3)=rain; + test(0,4)=sunny; + + test(1,0)=hot; + test(1,1)=cool; + test(1,2)=mild; + test(1,3)=mild; + test(1,4)=hot; + + test(2,0)=normal; + test(2,1)=high; + test(2,2)=high; + test(2,3)=normal; + test(2,4)=normal; + + test(3,0)=strong; + test(3,1)=strong; + test(3,2)=weak; + test(3,3)=weak; + test(3,4)=strong; + + CDenseFeatures* test_feats=new CDenseFeatures(test); + CMulticlassLabels* result=(CMulticlassLabels*) c->apply(test_feats); + SGVector res_vector=result->get_labels(); + + EXPECT_EQ(1.0,res_vector[0]); + EXPECT_EQ(0.0,res_vector[1]); + EXPECT_EQ(0.0,res_vector[2]); + EXPECT_EQ(1.0,res_vector[3]); + EXPECT_EQ(1.0,res_vector[4]); + + CMulticlassAccuracy* eval=new CMulticlassAccuracy(); + EXPECT_NEAR(0.571428,c->get_oob_error(eval),1e-6); + + SG_UNREF(test_feats); + SG_UNREF(result); + SG_UNREF(c); + SG_UNREF(eval); +}