Skip to content

Commit

Permalink
Merge pull request #3293 from Saurabh7/cart
Browse files Browse the repository at this point in the history
CART update
  • Loading branch information
vigsterkr committed Jun 28, 2016
2 parents d6353fe + 2219440 commit bb07222
Show file tree
Hide file tree
Showing 7 changed files with 333 additions and 299 deletions.
18 changes: 17 additions & 1 deletion src/shogun/machine/RandomForest.cpp
Expand Up @@ -159,11 +159,27 @@ void CRandomForest::set_machine_parameters(CMachine* m, SGVector<index_t> idx)
}

tree->set_weights(weights);

tree->set_sorted_features(m_sorted_transposed_feats, m_sorted_indices);
// equate the machine problem types - cloning does not do this
tree->set_machine_problem_type(dynamic_cast<CRandomCARTree*>(m_machine)->get_machine_problem_type());
}

bool CRandomForest::train_machine(CFeatures* data)
{
if (data)
{
SG_REF(data);
SG_UNREF(m_features);
m_features = data;
}

REQUIRE(m_features, "Training features not set!\n");

dynamic_cast<CRandomCARTree*>(m_machine)->pre_sort_features(m_features, m_sorted_transposed_feats, m_sorted_indices);

return CBaggingMachine::train_machine();
}

void CRandomForest::init()
{
m_machine=new CRandomCARTree();
Expand Down
7 changes: 7 additions & 0 deletions src/shogun/machine/RandomForest.h
Expand Up @@ -139,6 +139,8 @@ class CRandomForest : public CBaggingMachine
int32_t get_num_random_features() const;

protected:

virtual bool train_machine(CFeatures* data=NULL);
/** sets parameters of CARTree - sets machine labels and weights here
*
* @param m machine
Expand All @@ -154,6 +156,11 @@ class CRandomForest : public CBaggingMachine
/** weights */
SGVector<float64_t> m_weights;

/** Pre-sorted features */
SGMatrix<float64_t> m_sorted_transposed_feats;

/** Indices of pre-sorted features */
SGMatrix<index_t> m_sorted_indices;
};
} /* namespace shogun */
#endif /* _RANDOMFOREST_H__ */

0 comments on commit bb07222

Please sign in to comment.