Skip to content

Commit

Permalink
Merge pull request #98 from schalkdaniel/general_updates
Browse files Browse the repository at this point in the history
add getter for parameter and update tests
  • Loading branch information
Daniel Schalk committed Jan 20, 2018
2 parents 5b8b8c2 + 5b1b62f commit 924d767
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 16 deletions.
8 changes: 5 additions & 3 deletions src/baselearner_track.cpp
Expand Up @@ -53,9 +53,11 @@ void BaselearnerTrack::InsertBaselearner (blearner::Baselearner* blearner,
// Insert new baselearner:
blearner_vector.push_back(blearner);

std::string insert_id = blearner->GetDataIdentifier() + ": " + blearner->GetBaselearnerType();

// Check if the baselearner is the first one. If so, the parameter
// has to be instantiated with a zero matrix:
std::map<std::string, arma::mat>::iterator it = my_parameter_map.find(blearner->GetBaselearnerType());
std::map<std::string, arma::mat>::iterator it = my_parameter_map.find(insert_id);

// Prune parameter by multiplying it with the learning rate:
arma::mat parameter_temp = learning_rate * blearner->GetParameter();
Expand All @@ -65,15 +67,15 @@ void BaselearnerTrack::InsertBaselearner (blearner::Baselearner* blearner,

// If this is the first entry, initialize it with zeros:
arma::mat init_parameter(parameter_temp.n_rows, parameter_temp.n_cols, arma::fill::zeros);
my_parameter_map.insert(std::pair<std::string, arma::mat>(blearner->GetBaselearnerType(), init_parameter));
my_parameter_map.insert(std::pair<std::string, arma::mat>(insert_id, init_parameter));

}

// Accumulating parameter. If there is a nan, then this will be ignored and
// the non nan entries are added up:
// arma::mat parameter_insert = parameter_temp + my_parameter_map.find(blearner->GetBaselearnerType())->second;
// my_parameter_map.insert(std::pair<std::string, arma::mat>(blearner->GetBaselearnerType(), parameter_insert));
my_parameter_map[ blearner->GetBaselearnerType() ] = parameter_temp + my_parameter_map.find(blearner->GetBaselearnerType())->second;
my_parameter_map[ insert_id ] = parameter_temp + my_parameter_map.find(insert_id)->second;

}

Expand Down
33 changes: 23 additions & 10 deletions src/compboost_modules.cpp
Expand Up @@ -572,16 +572,6 @@ class CompboostWrapper
// std::cout << "<<CompboostWrapper>> Create Compboost" << std::endl;
}

// Destructor:
~CompboostWrapper ()
{
// std::cout << "Call CompboostWrapper Destructor" << std::endl;
delete used_logger;
delete used_optimizer;
delete eval_data;
delete obj;
}

// Member functions
void train (bool trace)
{
Expand All @@ -606,6 +596,28 @@ class CompboostWrapper
);
}

Rcpp::List getEstimatedParameter ()
{
std::map<std::string, arma::mat> parameter = obj->GetParameter();

Rcpp::List out;

for (auto &it : parameter) {
out[it.first] = it.second;
}
return out;
}

// Destructor:
~CompboostWrapper ()
{
// std::cout << "Call CompboostWrapper Destructor" << std::endl;
// delete used_logger;
// delete used_optimizer;
// delete eval_data;
// delete obj;
}

private:

blearnerlist::BaselearnerList* blearner_list_ptr;
Expand All @@ -630,6 +642,7 @@ RCPP_MODULE (compboost_module)
.method("getPrediction", &CompboostWrapper::getPrediction, "Get prediction")
.method("getSelectedBaselearner", &CompboostWrapper::getSelectedBaselearner, "Get vector of selected baselearner")
.method("getLoggerData", &CompboostWrapper::getLoggerData, "Get data of the used logger")
.method("getEstimatedParameter", &CompboostWrapper::getEstimatedParameter, "Get the estimated paraemter")
;
}

16 changes: 16 additions & 0 deletions tests/testthat/test_compboost.R
Expand Up @@ -105,6 +105,22 @@ test_that("compboost does the same as mboost", {
# ------
expect_equal(predict(mod), cboost$getPrediction())
expect_equal(mod$xselect(), cboost.xselect)
expect_equal(
unname(
unlist(
mod$coef()[
order(
unlist(
lapply(names(unlist(mod$coef()[1:3])), function (x) {
strsplit(x, "[.]")[[1]][2]
})
)
)
]
)
),
unname(unlist(cboost$getEstimatedParameter()))
)

expect_equal(dim(cboost$getLoggerData()$logger_data), c(500, 2))
expect_equal(cboost$getLoggerData()$logger_data[, 1], 1:500)
Expand Down
6 changes: 5 additions & 1 deletion tutorials/compboost_class.R
Expand Up @@ -89,5 +89,9 @@ cboost$train(trace = TRUE)

# Get vector selected baselearner:
cboost$getSelectedBaselearner()
# cboost$getModelFrame()

# Get estimated parameter:
cboost$getEstimatedParameter()

# Get logger data:
cboost$getLoggerData()
11 changes: 9 additions & 2 deletions tutorials/compboost_vs_mboost.R
Expand Up @@ -78,9 +78,11 @@ cboost$train(FALSE)

# Get vector selected baselearner:
cboost$getSelectedBaselearner()
# cboost$GetModelFrame()

# Get logger data:
cboost$getLoggerData()

# Get parameter estimator:

# Do the same with mboost:
# ------------------------
Expand Down Expand Up @@ -113,7 +115,12 @@ all.equal(mod$xselect(), cboost.xselect)
# ------------------------------------

all.equal(predict(mod), cboost$getPrediction())
# cboost$GetParameter()

# Check if parameter are the same:
# --------------------------------

mod$coef()
cboost$getEstimatedParameter()

# Benchmark:
# ----------
Expand Down

0 comments on commit 924d767

Please sign in to comment.