Skip to content

Commit

Permalink
add n_class arg to binary mat creation
Browse files Browse the repository at this point in the history
  • Loading branch information
schalkdaniel committed Jan 2, 2020
1 parent eb68913 commit 11101f9
Show file tree
Hide file tree
Showing 10 changed files with 71 additions and 15 deletions.
Binary file added src/.baselearner.h.swp
Binary file not shown.
Binary file added src/.baselearner_factory.cpp.swp
Binary file not shown.
Binary file added src/.baselearner_factory.h.swp
Binary file not shown.
Binary file added src/.compboost_modules.cpp.swp
Binary file not shown.
9 changes: 5 additions & 4 deletions src/baselearner.cpp
Expand Up @@ -357,10 +357,11 @@ BaselearnerPSpline::~BaselearnerPSpline () {}
// -----------------------

BaselearnerCategorical::BaselearnerCategorical (std::shared_ptr<data::Data> data, const std::string& identifier,
const double& learning_rate, const double& penalty, const unsigned int& iters)
const double& learning_rate, const double& penalty, const unsigned int& iters, const unsigned int& n_classes)
: learning_rate ( learning_rate ),
penalty ( penalty ),
iters ( iters )
iters ( iters ),
n_classes ( n_classes )
{
// Called from parent class 'Baselearner':
Baselearner::setData(data);
Expand All @@ -379,7 +380,7 @@ Baselearner* BaselearnerCategorical::clone ()
// Build the design matrix for new observations: the entries of `newdata` are converted to
// unsigned integer class codes, expanded to an indicator matrix by `helper::binaryMat`
// (the call with `n_classes` passes the stored number of classes), and returned as a dense `arma::mat`.
// NOTE(review): the conversion to `arma::Row` presumably expects `newdata` to be vector-shaped — confirm.
arma::mat BaselearnerCategorical::instantiateData (const arma::mat& newdata) const
{
arma::Row<unsigned int> temp = arma::conv_to<arma::Row<unsigned int>>::from(newdata);
arma::mat out (helper::binaryMat(temp));
arma::mat out (helper::binaryMat(temp, n_classes));
return out;
}

Expand All @@ -397,7 +398,7 @@ arma::mat BaselearnerCategorical::predict () const
// Predict on new data: the data held by `newdata` is converted to unsigned integer class codes,
// expanded to an indicator matrix by `helper::binaryMat`, and multiplied with `parameter`.
// NOTE(review): `parameter` is declared outside this excerpt; presumably it holds one row per class — confirm.
arma::mat BaselearnerCategorical::predict (std::shared_ptr<data::Data> newdata) const
{
arma::Row<unsigned int> temp = arma::conv_to<arma::Row<unsigned int>>::from(newdata->getData());
return helper::binaryMat(temp) * parameter;
return helper::binaryMat(temp, n_classes) * parameter;
}

// Destructor:
Expand Down
5 changes: 4 additions & 1 deletion src/baselearner.h
Expand Up @@ -207,14 +207,17 @@ class BaselearnerCategorical : public Baselearner
/// Iteration used for optimization:
const unsigned int iters;

/// Number of classes, required to always create data matrix with correct dimensions:
const unsigned int n_classes;

// /// Hashmap of levels:
// const std::map<unsigned int, std::string>;

public:

/// Default constructor:
BaselearnerCategorical (std::shared_ptr<data::Data>, const std::string&, const double&,
const double&, const unsigned int&);
const double&, const unsigned int&, const unsigned int&);

// /// Constructor when passing a list of levels:
// BaselearnerCategorical (std::shared_ptr<data::Data>, const std::string&, const double&,
Expand Down
52 changes: 48 additions & 4 deletions src/baselearner_factory.cpp
Expand Up @@ -319,7 +319,8 @@ arma::mat BaselearnerPSplineFactory::instantiateData (const arma::mat& newdata)
// -----------------------

/**
* \brief Default constructor of class `PSplineCategoricalFactory`
* \brief Constructor of class `BaselearnerCategoricalFactory`
*
* The constructor creates the binary design matrix as a sparse matrix. The factory also stores
* the learning rate and penalty term since the base-learner uses a coordinate descent
Expand Down Expand Up @@ -349,9 +350,52 @@ BaselearnerCategoricalFactory::BaselearnerCategoricalFactory (const std::string&
// Make sure that the data identifier is set correctly:
data_target->setDataIdentifier(data_source->getDataIdentifier());
arma::Row<unsigned int> temp = arma::conv_to<arma::Row<unsigned int>>::from(data_source->getData());
data_target->sparse_data_mat = helper::binaryMat(temp);

n_classes = arma::max(temp);

data_target->sparse_data_mat = helper::binaryMat(temp, n_classes);
}

// BaselearnerCategorical:
// -----------------------

/**
* \brief Constructor of class `BaselearnerCategoricalFactory` with an explicit number of classes
*
* The constructor creates the binary design matrix as a sparse matrix. The factory also stores
* the learning rate and penalty term since the base-learner uses a coordinate descent
* search strategy with L1 penalty.
*
* \param blearner_type0 `std::string` Name of the base-learner type (set by
* the Rcpp wrapper classes in `compboost_modules.cpp`)
* \param data_source0 `std::shared_ptr<data::Data>` Source of the data
* \param data_target0 `std::shared_ptr<data::Data>` Object to store the transformed data source
* \param learning_rate `double` Coordinate descent learning rate.
* \param penalty `double` L1 penalty term.
* \param iters `unsigned int` Maximum number of iterations.
* \param n_classes `unsigned int` Number of classes; passed explicitly so that the design matrix
* has the correct number of columns even if a class has no observations.
*/

BaselearnerCategoricalFactory::BaselearnerCategoricalFactory (const std::string& blearner_type0,
  std::shared_ptr<data::Data> data_source0, std::shared_ptr<data::Data> data_target0,
  const double& learning_rate, const double& penalty, const unsigned int& iters, const unsigned int& n_classes)
  : learning_rate ( learning_rate ),
    penalty ( penalty ),
    iters ( iters ),
    n_classes ( n_classes )
{
  // Keep handles to the raw feature and to the container receiving the
  // transformed feature, then record the base-learner type label.
  data_source   = data_source0;
  data_target   = data_target0;
  blearner_type = blearner_type0;

  // The transformed data is labelled with the same identifier as its source:
  data_target->setDataIdentifier(data_source->getDataIdentifier());

  // Convert the raw feature to unsigned integer class codes and store the
  // indicator matrix built with the caller-supplied number of classes:
  const arma::Row<unsigned int> class_codes =
    arma::conv_to<arma::Row<unsigned int>>::from(data_source->getData());
  data_target->sparse_data_mat = helper::binaryMat(class_codes, this->n_classes);
}


/**
* \brief Create new `BaselearnerCategorical` object
*
Expand All @@ -362,7 +406,7 @@ std::shared_ptr<blearner::Baselearner> BaselearnerCategoricalFactory::createBase
// Create new categorical baselearner. This one will be returned by the
// factory:
std::shared_ptr<blearner::Baselearner> sh_ptr_blearner = std::make_shared<blearner::BaselearnerCategorical>(data_target, identifier,
learning_rate, penalty, iters);
learning_rate, penalty, iters, n_classes);
sh_ptr_blearner->setBaselearnerType(blearner_type);

return sh_ptr_blearner;
Expand Down Expand Up @@ -403,7 +447,7 @@ arma::mat BaselearnerCategoricalFactory::getData () const
// Transform new raw data into the design matrix of this factory: the entries of `newdata`
// are converted to unsigned integer class codes, expanded to an indicator matrix by
// `helper::binaryMat` (the call with `n_classes` passes the factory's number of classes),
// and returned as a dense `arma::mat`.
// NOTE(review): the conversion to `arma::Row` presumably expects `newdata` to be vector-shaped — confirm.
arma::mat BaselearnerCategoricalFactory::instantiateData (const arma::mat& newdata) const
{
arma::Row<unsigned int> temp = arma::conv_to<arma::Row<unsigned int>>::from(newdata);
arma::mat out (helper::binaryMat(temp));
arma::mat out (helper::binaryMat(temp, n_classes));
return out;
}

Expand Down
8 changes: 8 additions & 0 deletions src/baselearner_factory.h
Expand Up @@ -176,14 +176,22 @@ class BaselearnerCategoricalFactory : public BaselearnerFactory
/// Iteration used for optimization:
const unsigned int iters;

/// Number of classes, required to always create data matrix with correct dimensions:
unsigned int n_classes;

// /// Hashmap of levels:
// const std::map<unsigned int, std::string>;

public:
/// Constructor of class `BaselearnerCategoricalFactory` taking the number of classes explicitly
BaselearnerCategoricalFactory (const std::string&, std::shared_ptr<data::Data>, std::shared_ptr<data::Data>,
const double&, const double&, const unsigned int&, const unsigned int&);

/// Constructor of class `BaselearnerCategoricalFactory` without an explicit number of classes
BaselearnerCategoricalFactory (const std::string&, std::shared_ptr<data::Data>, std::shared_ptr<data::Data>,
const double&, const double&, const unsigned int&);


/// Create new `BaselearnerCategorical` object
std::shared_ptr<blearner::Baselearner> createBaselearner (const std::string&);

Expand Down
10 changes: 5 additions & 5 deletions src/helper.cpp
Expand Up @@ -237,15 +237,15 @@ double matrixQuantile (const arma::mat& X, const double& quantile)
}

// Build a sparse indicator matrix from integer class codes: observation i contributes a
// single entry of value one at row i and column `classes(i) - 1`. With the `n_classes`
// argument the result is explicitly sized `classes.size()` x `n_classes`.
//
// Classes must start with 1! The codes are unsigned, so a code of 0 would wrap around
// in `classes - 1` instead of producing a negative index.
arma::sp_mat binaryMat (const arma::Row<unsigned int>& classes)
arma::sp_mat binaryMat (const arma::Row<unsigned int>& classes, const unsigned int& n_classes)
{
const unsigned int n = classes.size();
const unsigned int n_rows = classes.size();

const arma::Row<double> ones (n, arma::fill::ones);
const arma::Row<unsigned int> ones_int (n, arma::fill::ones);
const arma::Row<double> ones (n_rows, arma::fill::ones);
const arma::Row<unsigned int> ones_int (n_rows, arma::fill::ones);

// Two-row locations matrix: the first row counts 0, 1, 2, ... (cumulative sum of ones minus one),
// the second row holds the zero-based class codes.
// NOTE(review): relies on the sparse batch constructor reading row 0 as row indices and row 1 as column indices — confirm against the Armadillo docs.
arma::umat indices = arma::join_cols(arma::cumsum(ones_int) - 1, classes - 1);
arma::sp_mat sp_out(indices, ones);
arma::sp_mat sp_out(indices, ones, n_rows, n_classes);

return sp_out;
}
Expand Down
2 changes: 1 addition & 1 deletion src/helper.h
Expand Up @@ -40,7 +40,7 @@ void checkForBinaryClassif (const std::vector<std::string>&);
void checkMatrixDim (const arma::mat&, const arma::mat&);
bool checkTracePrinter (const unsigned int&, const unsigned int&);
double matrixQuantile (const arma::mat&, const double&);
arma::sp_mat binaryMat (const arma::Row<unsigned int>&);
arma::sp_mat binaryMat (const arma::Row<unsigned int>&, const unsigned int&);
} // namespace helper

# endif // HELPER_H_

0 comments on commit 11101f9

Please sign in to comment.