Skip to content

Commit

Permalink
add check if oob data is present
Browse files Browse the repository at this point in the history
  • Loading branch information
schalkdaniel committed Mar 7, 2019
1 parent 74b07d1 commit 47f134f
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/logger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ std::string LoggerInbagRisk::printLoggerStatus () const
*/

LoggerOobRisk::LoggerOobRisk (const std::string& logger_id0, const bool& is_a_stopper0, std::shared_ptr<loss::Loss> sh_ptr_loss,
const double& eps_for_break, const unsigned int& patience, std::map<std::string, std::shared_ptr<data::Data>> oob_data,
const double& eps_for_break, const unsigned int& patience, std::map<std::string, std::shared_ptr<data::Data>> oob_data,
std::shared_ptr<response::Response> oob_response)
: sh_ptr_loss ( sh_ptr_loss ),
eps_for_break ( eps_for_break ),
Expand Down Expand Up @@ -396,15 +396,21 @@ void LoggerOobRisk::logStep (const unsigned int& current_iteration, std::shared_
sh_ptr_oob_response->constantInitialization(sh_ptr_response->getInitialization());
sh_ptr_oob_response->initializePrediction();
}

std::string blearner_id = sh_ptr_blearner->getDataIdentifier();

// Get data of corresponding selected baselearner. E.g. iteration 100 linear
// baselearner of feature x_7, then get the data of feature x_7:
std::shared_ptr<data::Data> oob_blearner_data = oob_data.find(blearner_id)->second;

// Predict this data using the selected baselearner:
arma::mat temp_oob_prediction = sh_ptr_blearner->predict(oob_blearner_data);
sh_ptr_oob_response->updatePrediction(learning_rate, step_size, temp_oob_prediction);
// Check, whether the data object is present or not:
std::map<std::string, std::shared_ptr<data::Data>>::iterator it_oob_data = oob_data.find(blearner_id);
if (it_oob_data != oob_data.end()) {
std::shared_ptr<data::Data> oob_blearner_data = it_oob_data->second;

// Predict this data using the selected baselearner:
arma::mat temp_oob_prediction = sh_ptr_blearner->predict(oob_blearner_data);
sh_ptr_oob_response->updatePrediction(learning_rate, step_size, temp_oob_prediction);
}


/* *****************************************************************************************************************************
Expand Down

0 comments on commit 47f134f

Please sign in to comment.