diff --git a/src/shogun/machine/StochasticGBMachine.cpp b/src/shogun/machine/StochasticGBMachine.cpp new file mode 100644 index 00000000000..3cfe1cc26c5 --- /dev/null +++ b/src/shogun/machine/StochasticGBMachine.cpp @@ -0,0 +1,400 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2014 Parijat Mazumdar + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include + +using namespace shogun; + +CStochasticGBMachine::CStochasticGBMachine(CMachine* machine, CLossFunction* loss, int32_t num_iterations, + float64_t learning_rate, float64_t subset_fraction) +: CMachine() +{ + init(); + + if (machine!=NULL) + { + SG_REF(machine); + m_machine=machine; + } + + if (loss!=NULL) + { + SG_REF(loss); + m_loss=loss; + } + + m_num_iter=num_iterations; + m_subset_frac=subset_fraction; + m_learning_rate=learning_rate; +} + +CStochasticGBMachine::~CStochasticGBMachine() +{ + SG_UNREF(m_machine); + SG_UNREF(m_loss); + SG_UNREF(m_weak_learners); + SG_UNREF(m_gamma); +} + +void CStochasticGBMachine::set_machine(CMachine* machine) +{ + REQUIRE(machine,"Supplied machine is NULL\n") + + if (m_machine!=NULL) + SG_UNREF(m_machine); + + SG_REF(machine); + m_machine=machine; +} + +CMachine* CStochasticGBMachine::get_machine() const +{ + if (m_machine==NULL) + SG_ERROR("machine not set yet!\n") + + SG_REF(m_machine); + return m_machine; +} + +void CStochasticGBMachine::set_loss_function(CLossFunction* f) +{ + REQUIRE(f,"Supplied loss function is NULL\n") + if (m_loss!=NULL) + SG_UNREF(m_loss); + + SG_REF(f); + m_loss=f; +} + +CLossFunction* CStochasticGBMachine::get_loss_function() const +{ + if (m_loss==NULL) + SG_ERROR("Loss function not set yet!\n") + + SG_REF(m_loss) + return m_loss; +} + +void CStochasticGBMachine::set_num_iterations(int32_t iter) +{ + REQUIRE(iter,"Number of iterations\n") + m_num_iter=iter; +} + +int32_t CStochasticGBMachine::get_num_iterations() const +{ + return m_num_iter; +} + +void CStochasticGBMachine::set_subset_fraction(float64_t frac) +{ + REQUIRE((frac>0)&&(frac<=1),"subset fraction should lie between 0 and 1. Supplied value is %f\n",frac) + + m_subset_frac=frac; +} + +float64_t CStochasticGBMachine::get_subset_fraction() const +{ + return m_subset_frac; +} + +void CStochasticGBMachine::set_learning_rate(float64_t lr) +{ + REQUIRE((lr>0)&&(lr<=1),"learning rate should lie between 0 and 1. Supplied value is %f\n",lr) + + m_learning_rate=lr; +} + +float64_t CStochasticGBMachine::get_learning_rate() const +{ + return m_learning_rate; +} + +CRegressionLabels* CStochasticGBMachine::apply_regression(CFeatures* data) +{ + REQUIRE(data,"test data supplied is NULL\n") + CDenseFeatures* feats=CDenseFeatures::obtain_from_generic(data); + + SGVector retlabs(feats->get_num_vectors()); + retlabs.fill_vector(retlabs.vector,retlabs.vlen,0); + for (int32_t i=0;iget_element(i); + + CSGObject* element=m_weak_learners->get_element(i); + REQUIRE(element,"%d element of the array of weak learners is NULL. This is not expected\n",i) + CMachine* machine=dynamic_cast(element); + + CRegressionLabels* dlabels=machine->apply_regression(feats); + SGVector delta=dlabels->get_labels(); + + for (int32_t j=0;j* feats=CDenseFeatures::obtain_from_generic(data); + + // initialize weak learners array and gamma array + initialize_learners(); + + // cache predicted labels for intermediate models + CRegressionLabels* interf=new CRegressionLabels(feats->get_num_vectors()); + SG_REF(interf); + for (int32_t i=0;iget_num_labels();i++) + interf->set_label(i,0); + + for (int32_t i=0;ipush_back(wlearner); + + // compute multiplier + CRegressionLabels* hm=wlearner->apply_regression(feats); + SG_REF(hm); + float64_t gamma=compute_multiplier(interf,hm); + m_gamma->push_back(gamma); + + // remove subset + if (m_subset_frac!=1.0) + { + feats->remove_subset(); + m_labels->remove_subset(); + interf->remove_subset(); + } + + // update intermediate function value + CRegressionLabels* dlabels=wlearner->apply_regression(feats); + SGVector delta=dlabels->get_labels(); + for (int32_t j=0;jget_num_labels();j++) + interf->set_label(j,interf->get_label(j)+delta[j]*gamma*m_learning_rate); + + SG_UNREF(dlabels); + SG_UNREF(hm); + SG_UNREF(wlearner); + } + + SG_UNREF(interf); + return true; +} + +float64_t CStochasticGBMachine::compute_multiplier(CRegressionLabels* f, CRegressionLabels* hm) +{ + REQUIRE(f->get_num_labels()==hm->get_num_labels(),"The number of labels in both input parameters should be equal\n") + + CDynamicObjectArray* instance=new CDynamicObjectArray(); + instance->push_back(m_labels); + instance->push_back(f); + instance->push_back(hm); + instance->push_back(m_loss); + + float64_t ret=get_gamma(instance); + + SG_UNREF(instance); + return ret; +} + +CMachine* CStochasticGBMachine::fit_model(CDenseFeatures* feats, CRegressionLabels* labels) +{ + // clone base machine + CMachine* c=dynamic_cast(m_machine->clone()); + + // train cloned machine + c->set_labels(labels); + c->train(feats); + + return c; +} + +CRegressionLabels* CStochasticGBMachine::compute_pseudo_residuals(CRegressionLabels* inter_f) +{ + REQUIRE(m_labels,"training labels not set!\n") + SGVector labels=(dynamic_cast(m_labels))->get_labels(); + SGVector f=inter_f->get_labels(); + + SGVector residuals(f.vlen); + for (int32_t i=0;ifirst_derivative(f[i],labels[i]); + + return new CRegressionLabels(residuals); +} + +void CStochasticGBMachine::apply_subset(CDenseFeatures* f, CLabels* interf) +{ + int32_t subset_size=m_subset_frac*(f->get_num_vectors()); + SGVector idx(f->get_num_vectors()); + idx.range_fill(0); + idx.randperm(); + + SGVector subset(subset_size); + memcpy(subset.vector,idx.vector,subset.vlen*sizeof(index_t)); + + f->add_subset(subset); + interf->add_subset(subset); + m_labels->add_subset(subset); +} + +void CStochasticGBMachine::initialize_learners() +{ + SG_UNREF(m_weak_learners); + m_weak_learners=new CDynamicObjectArray(); + SG_REF(m_weak_learners); + + SG_UNREF(m_gamma); + m_gamma=new CDynamicArray(); + SG_REF(m_gamma); +} + +float64_t CStochasticGBMachine::get_gamma(void* instance) +{ + lbfgs_parameter_t lbfgs_param; + lbfgs_parameter_init(&lbfgs_param); + lbfgs_param.linesearch=2; + + float64_t gamma=0; + lbfgs(1,&gamma,NULL,CStochasticGBMachine::lbfgs_evaluate,NULL,instance,&lbfgs_param); + + return gamma; +} + +float64_t CStochasticGBMachine::lbfgs_evaluate(void *obj, const float64_t *parameters, float64_t *gradient, const int dim, + const float64_t step) +{ + REQUIRE(obj,"object cannot be NULL\n") + CDynamicObjectArray* objects=static_cast(obj); + REQUIRE((objects->get_num_elements()==2) || (objects->get_num_elements()==4),"Number of elements in obj array" + " (%d) does not match expectations(2 or 4)\n",objects->get_num_elements()) + + if (objects->get_num_elements()==2) + { + // extract labels + CSGObject* element=objects->get_element(0); + REQUIRE(element,"0 index element of objects is NULL\n") + CDenseLabels* lab=dynamic_cast(element); + SGVector labels=lab->get_labels(); + + // extract loss function + element=objects->get_element(1); + REQUIRE(element,"1 index element of objects is NULL\n") + CLossFunction* lossf=dynamic_cast(element); + + *gradient=0; + float64_t ret=0; + for (int32_t i=0;ifirst_derivative((*parameters),labels[i]); + ret+=lossf->loss((*parameters),labels[i]); + } + + SG_UNREF(lab); + SG_UNREF(lossf); + return ret; + } + + // extract labels + CSGObject* element=objects->get_element(0); + REQUIRE(element,"0 index element of objects is NULL\n") + CDenseLabels* lab=dynamic_cast(element); + SGVector labels=lab->get_labels(); + + // extract f + element=objects->get_element(1); + REQUIRE(element,"1 index element of objects is NULL\n") + CDenseLabels* func=dynamic_cast(element); + SGVector f=func->get_labels(); + + // extract hm + element=objects->get_element(2); + REQUIRE(element,"2 index element of objects is NULL\n") + CDenseLabels* delta=dynamic_cast(element); + SGVector hm=delta->get_labels(); + + // extract loss function + element=objects->get_element(3); + REQUIRE(element,"3 index element of objects is NULL\n") + CLossFunction* lossf=dynamic_cast(element); + + *gradient=0; + float64_t ret=0; + for (int32_t i=0;ifirst_derivative((*parameters)*hm[i]+f[i],labels[i]); + ret+=lossf->loss((*parameters)*hm[i]+f[i],labels[i]); + } + + SG_UNREF(lab); + SG_UNREF(delta); + SG_UNREF(func); + SG_UNREF(lossf) + return ret; +} + +void CStochasticGBMachine::init() +{ + m_machine=NULL; + m_loss=NULL; + m_num_iter=0; + m_subset_frac=0; + m_learning_rate=0; + + m_weak_learners=new CDynamicObjectArray(); + SG_REF(m_weak_learners); + + m_gamma=new CDynamicArray(); + SG_REF(m_gamma); + + SG_ADD((CSGObject**)&m_machine,"m_machine","machine",MS_NOT_AVAILABLE); + SG_ADD((CSGObject**)&m_loss,"m_loss","loss function",MS_NOT_AVAILABLE); + SG_ADD(&m_num_iter,"m_num_iter","number of iterations",MS_NOT_AVAILABLE); + SG_ADD(&m_subset_frac,"m_subset_frac","subset fraction",MS_NOT_AVAILABLE); + SG_ADD(&m_learning_rate,"m_learning_rate","learning rate",MS_NOT_AVAILABLE); + SG_ADD((CSGObject**)&m_weak_learners,"m_weak_learners","array of weak learners",MS_NOT_AVAILABLE); + SG_ADD((CSGObject**)&m_gamma,"m_gamma","array of learner weights",MS_NOT_AVAILABLE); +} diff --git a/src/shogun/machine/StochasticGBMachine.h b/src/shogun/machine/StochasticGBMachine.h new file mode 100644 index 00000000000..2fb667539ed --- /dev/null +++ b/src/shogun/machine/StochasticGBMachine.h @@ -0,0 +1,227 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2014 Parijat Mazumdar + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#ifndef _StochasticGBMachine_H__ +#define _StochasticGBMachine_H__ + + +#include + +#include +#include +#include + +namespace shogun +{ + +/** @brief This class implements the stochastic gradient boosting algorithm for ensemble learning invented by Jerome H. Friedman. This class + * works with a variety of loss functions like squared loss, exponential loss, Huber loss etc which can be accessed through Shogun's + * CLossFunction interface (cf. http://www.shogun-toolbox.org/doc/en/latest/classshogun_1_1CLossFunction.html). Additionally, it can create + * an ensemble of any regressor class derived from the CMachine class (cf. http://www.shogun-toolbox.org/doc/en/latest/classshogun_1_1CMachine.html). + * For one dimensional optimization, this class uses the backtracking linesearch accessed via Shogun's L-BFGS class. + * A concise description of the algorithm implemented can be found in the following link : + * http://en.wikipedia.org/wiki/Gradient_boosting#Algorithm + */ +class CStochasticGBMachine : public CMachine +{ +public: + /** Constructor + * + * @param machine The class of machine which will constitute the ensemble + * @param loss loss function + * @param num_iterations number of iterations of boosting + * @param subset_fraction fraction of trainining vectors to be chosen randomly w/o replacement + * @param learning_rate shrinkage factor + */ + CStochasticGBMachine(CMachine* machine=NULL, CLossFunction* loss=NULL, int32_t num_iterations=100, + float64_t learning_rate=1.0, float64_t subset_fraction=0.6); + + /** Destructor */ + virtual ~CStochasticGBMachine(); + + /** get name + * + * @return StochasticGBMachine + */ + virtual const char* get_name() const { return "StochasticGBMachine"; } + + /** set machine + * + * @param machine machine + */ + void set_machine(CMachine* machine); + + /** get machine + * + * @return machine + */ + CMachine* get_machine() const; + + /** set loss function + * + * @param f loss function + */ + virtual void set_loss_function(CLossFunction* f); + + /** get loss function + * + * @return loss function + */ + virtual CLossFunction* get_loss_function() const; + + /** set number of iterations + * + * @param iter number of iterations + */ + void set_num_iterations(int32_t iter); + + /** get number of iterations + * + * @return number of iterations + */ + int32_t get_num_iterations() const; + + /** set subset fraction + * + * @param frac subset fraction (should lie between 0 and 1) + */ + void set_subset_fraction(float64_t frac); + + /** get subset fraction + * + * @return subset fraction + */ + float64_t get_subset_fraction() const; + + /** set learning rate + * + * @param lr learning rate + */ + void set_learning_rate(float64_t lr); + + /** get learning rate + * + * @return learning rate + */ + float64_t get_learning_rate() const; + + /** apply_regression + * + * @param test data + * @param Regression labels + */ + virtual CRegressionLabels* apply_regression(CFeatures* data=NULL); + +protected: + /** train machine + * + * @param data training data + * @return true + */ + virtual bool train_machine(CFeatures* data=NULL); + + /** compute gamma values + * + * @param f labels from the intermediate model + * @param hm labels from the newly trained base model + * @return gamma - the scalar weights given to individual weak learners in the ensemble model + */ + float64_t compute_multiplier(CRegressionLabels* f, CRegressionLabels* hm); + + /** train base model + * + * @param feats training data + * @param labels training labels + * @return trained base model + */ + CMachine* fit_model(CDenseFeatures* feats, CRegressionLabels* labels); + + /** compute pseudo_residuals + * + * @param inter_f intermediate boosted model labels for training data + * @return pseudo_residuals + */ + CRegressionLabels* compute_pseudo_residuals(CRegressionLabels* inter_f); + + /** add randomized subset to relevant parameters + * + * @param f training data + * @param interf intermediate boosted model labels for training data + */ + void apply_subset(CDenseFeatures* f, CLabels* interf); + + /** reset arrays of weak learners and gamma values */ + void initialize_learners(); + + /** apply lbfgs to get gamma + * + * @param instance stores parameters to be passed to lbfgs_evaluate + * @return gamma + */ + float64_t get_gamma(void* instance); + + /** call-back evaluate method for lbfgs + * + * @param obj object parameters required for loss calculation + * @param paramaters current state of variables of target function + * @param gradient stores gradient computed by this method + * @param dim dimensions + * @param step step in linesearch + */ + static float64_t lbfgs_evaluate(void *obj, const float64_t *parameters, float64_t *gradient, const int dim, const float64_t step); + + /** initialize */ + void init(); + +protected: + /** machine to be used for GBoosting */ + CMachine* m_machine; + + /** loss function */ + CLossFunction* m_loss; + + /** num of iterations */ + int32_t m_num_iter; + + /** subset fraction */ + float64_t m_subset_frac; + + /** learning_rate */ + float64_t m_learning_rate; + + /** array of weak learners */ + CDynamicObjectArray* m_weak_learners; + + /** gamma - weak learner weights */ + CDynamicArray* m_gamma; +}; +}/* shogun */ + +#endif /* _StochasticGBMachine_H__ */ diff --git a/src/shogun/multiclass/tree/CARTree.cpp b/src/shogun/multiclass/tree/CARTree.cpp index 0eea1e7323b..bdcf8a5c524 100644 --- a/src/shogun/multiclass/tree/CARTree.cpp +++ b/src/shogun/multiclass/tree/CARTree.cpp @@ -65,6 +65,20 @@ CCARTree::~CCARTree() SG_UNREF(m_alphas); } +void CCARTree::set_labels(CLabels* lab) +{ + if (lab->get_label_type()==LT_MULTICLASS) + set_machine_problem_type(PT_MULTICLASS); + else if (lab->get_label_type()==LT_REGRESSION) + set_machine_problem_type(PT_REGRESSION); + else + SG_ERROR("label type supplied is not supported\n") + + SG_REF(lab); + SG_UNREF(m_labels); + m_labels=lab; +} + void CCARTree::set_machine_problem_type(EProblemType mode) { m_mode=mode; diff --git a/src/shogun/multiclass/tree/CARTree.h b/src/shogun/multiclass/tree/CARTree.h index 5a4e93f4002..59c48841c90 100644 --- a/src/shogun/multiclass/tree/CARTree.h +++ b/src/shogun/multiclass/tree/CARTree.h @@ -99,6 +99,11 @@ class CCARTree : public CTreeMachine /** destructor */ virtual ~CCARTree(); + /** set labels - automagically switch machine problem type based on type of labels supplied + * @param lab labels + */ + virtual void set_labels(CLabels* lab); + /** get name * @return class name CARTree */ diff --git a/tests/unit/machine/StochasticGBMachine_unittest.cc b/tests/unit/machine/StochasticGBMachine_unittest.cc new file mode 100644 index 00000000000..12727c2c677 --- /dev/null +++ b/tests/unit/machine/StochasticGBMachine_unittest.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) The Shogun Machine Learning Toolbox + * Written (w) 2014 Parijat Mazumdar + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * The views and conclusions contained in the software and documentation are those + * of the authors and should not be interpreted as representing official policies, + * either expressed or implied, of the Shogun Development Team. + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace shogun; + +SGMatrix get_sinusoid_samples(int32_t num_samples, SGVector labels) +{ + SGMatrix ret(1,num_samples); + SGVector::random_vector(ret.matrix,num_samples,0,15); + + for (int32_t i=0;iset_seed(10); + + int32_t num_train_samples=100; + SGVector lab(num_train_samples); + SGMatrix data=get_sinusoid_samples(num_train_samples,lab); + CDenseFeatures* train_feats=new CDenseFeatures(data); + CRegressionLabels* train_labels=new CRegressionLabels(lab); + + SGVector tlab(10); + SGMatrix tdata(1,10); + + tlab[0]=-0.999585752311506259; + tlab[1]=0.75965469336929492; + tlab[2]=-0.425832103506531334; + tlab[3]=0.298135616000050285; + tlab[4]=-0.48828775732556795; + tlab[5]=-0.031677813420380535; + tlab[6]=0.144672857935527394; + tlab[7]=-0.0810247683026898424; + tlab[8]=-0.767723534099077121; + tlab[9]=0.639868456911451666; + + tdata(0,0)=10.9667896982205075; + tdata(0,1)=0.862781976084872615; + tdata(0,2)=12.1264892751645501; + tdata(0,3)=9.12203911322216499; + tdata(0,4)=9.93490458930258313; + tdata(0,5)=6.25150219333625934; + tdata(0,6)=0.145182344164974608; + tdata(0,7)=3.22270633960671393; + tdata(0,8)=11.6910897047936668; + tdata(0,9)=2.44726557225158103; + + CDenseFeatures* test_feats=new CDenseFeatures(tdata); + CRegressionLabels* test_labels=new CRegressionLabels(tlab); + + SGVector ft(1); + ft[0]=false; + CCARTree* tree=new CCARTree(ft); + tree->set_max_depth(2); + CSquaredLoss* sq=new CSquaredLoss(); + + CStochasticGBMachine* sgbm=new CStochasticGBMachine(tree,sq,100,0.1,1.0); + sgbm->set_labels(train_labels); + sgbm->train(train_feats); + + CRegressionLabels* ret_labels=sgbm->apply_regression(test_feats); + SGVector ret=ret_labels->get_labels(); + + float64_t epsilon=1e-8; + EXPECT_NEAR(ret[0],-0.943157980,epsilon); + EXPECT_NEAR(ret[1],0.769725470,epsilon); + EXPECT_NEAR(ret[2],-0.065691733,epsilon); + EXPECT_NEAR(ret[3],0.251266829,epsilon); + EXPECT_NEAR(ret[4],-0.577155330,epsilon); + EXPECT_NEAR(ret[5],0.113875818,epsilon); + EXPECT_NEAR(ret[6],0.427405429,epsilon); + EXPECT_NEAR(ret[7],-0.098310066,epsilon); + EXPECT_NEAR(ret[8],-0.416565932,epsilon); + EXPECT_NEAR(ret[9],0.542023083,epsilon); + + SG_UNREF(train_feats); + SG_UNREF(test_feats); + SG_UNREF(test_labels); + SG_UNREF(ret_labels); + SG_UNREF(sgbm); +}