Skip to content

Commit

Permalink
Added class probabilities for BaggingMachine (and RandomForest) (#3954)
Browse files Browse the repository at this point in the history
  • Loading branch information
olinguyen authored and karlnapf committed Aug 16, 2017
1 parent 832e8fd commit b537cc8
Show file tree
Hide file tree
Showing 6 changed files with 573 additions and 386 deletions.
74 changes: 65 additions & 9 deletions src/shogun/machine/BaggingMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@
* Copyright (C) 2013 Viktor Gal
*/

#include <shogun/machine/BaggingMachine.h>
#include <shogun/ensemble/CombinationRule.h>
#include <shogun/ensemble/MeanRule.h>
#include <shogun/machine/BaggingMachine.h>
#include <shogun/mathematics/linalg/LinalgNamespace.h>

#include <shogun/evaluation/Evaluation.h>

using namespace shogun;
Expand Down Expand Up @@ -44,17 +47,63 @@ CBaggingMachine::~CBaggingMachine()

CBinaryLabels* CBaggingMachine::apply_binary(CFeatures* data)
{
SGVector<float64_t> combined_vector = apply_get_outputs(data);
SGMatrix<float64_t> output = apply_outputs_without_combination(data);

CMeanRule* mean_rule = new CMeanRule();

SGVector<float64_t> labels = m_combination_rule->combine(output);
SGVector<float64_t> probabilities = mean_rule->combine(output);

float64_t threshold = 0.5;
CBinaryLabels* pred = new CBinaryLabels(probabilities, threshold);

SG_UNREF(mean_rule);

CBinaryLabels* pred = new CBinaryLabels(combined_vector);
return pred;
}

CMulticlassLabels* CBaggingMachine::apply_multiclass(CFeatures* data)
{
SGVector<float64_t> combined_vector = apply_get_outputs(data);
SGMatrix<float64_t> bagged_outputs =
apply_outputs_without_combination(data);

REQUIRE(m_labels, "Labels not set.\n");

auto labels_multiclass = dynamic_cast<CMulticlassLabels*>(m_labels);
REQUIRE(
labels_multiclass, "Labels (%s) are not compatible with multiclass.\n",
m_labels->get_name());

auto num_samples = bagged_outputs.size() / m_num_bags;
auto num_classes = labels_multiclass->get_num_classes();

CMulticlassLabels* pred = new CMulticlassLabels(num_samples);
pred->allocate_confidences_for(num_classes);

SGMatrix<float64_t> class_probabilities(num_samples, num_classes);
class_probabilities.zero();

for (auto i = 0; i < num_samples; ++i)
{
for (auto j = 0; j < m_num_bags; ++j)
{
int32_t class_idx = bagged_outputs(i, j);
class_probabilities(i, class_idx) += 1;
}
}

float64_t alpha = 1.0 / m_num_bags;
class_probabilities = linalg::scale(class_probabilities, alpha);

for (auto i = 0; i < num_samples; ++i)
{
auto confidences = class_probabilities.get_row_vector(i);
auto y_pred = CMath::arg_max(confidences.vector, 1, confidences.vlen);

pred->set_label(i, y_pred);
pred->set_multiclass_confidences(i, confidences);
}

CMulticlassLabels* pred = new CMulticlassLabels(combined_vector);
return pred;
}

Expand All @@ -71,12 +120,21 @@ SGVector<float64_t> CBaggingMachine::apply_get_outputs(CFeatures* data)
{
ASSERT(data != NULL);
REQUIRE(m_combination_rule != NULL, "Combination rule is not set!");

SGMatrix<float64_t> output = apply_outputs_without_combination(data);
SGVector<float64_t> combined = m_combination_rule->combine(output);

return combined;
}

SGMatrix<float64_t>
CBaggingMachine::apply_outputs_without_combination(CFeatures* data)
{
ASSERT(m_num_bags == m_bags->get_num_elements());

SGMatrix<float64_t> output(data->get_num_vectors(), m_num_bags);
output.zero();


#pragma omp parallel for
for (int32_t i = 0; i < m_num_bags; ++i)
{
Expand All @@ -95,9 +153,7 @@ SGVector<float64_t> CBaggingMachine::apply_get_outputs(CFeatures* data)
SG_UNREF(m);
}

SGVector<float64_t> combined = m_combination_rule->combine(output);

return combined;
return output;
}

bool CBaggingMachine::train_machine(CFeatures* data)
Expand Down
38 changes: 24 additions & 14 deletions src/shogun/machine/BaggingMachine.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,20 +139,30 @@ namespace shogun
*/
SGVector<float64_t> apply_get_outputs(CFeatures* data);

/** Register paramaters */
void register_parameters();

/** Initialize the members with default values */
void init();

/**
* get the vector of indices for feature vectors that are out of bag
*
* @param in_bag vector of indices that are in bag.
* NOTE: in_bag is a randomly generated with replacement
* @return the vector of indices
*/
CDynamicArray<index_t>* get_oob_indices(const SGVector<index_t>& in_bag);
/** helper function for the apply_{binary,..} functions that
* computes the output probabilities without combination rules
*
* @param data the data to compute the output for
* @return predictions
*/
SGMatrix<float64_t>
apply_outputs_without_combination(CFeatures* data);

/** Register paramaters */
void register_parameters();

/** Initialize the members with default values */
void init();

/**
* get the vector of indices for feature vectors that are out of bag
*
* @param in_bag vector of indices that are in bag.
* NOTE: in_bag is a randomly generated with replacement
* @return the vector of indices
*/
CDynamicArray<index_t>*
get_oob_indices(const SGVector<index_t>& in_bag);

protected:
/** bags array */
Expand Down
Loading

0 comments on commit b537cc8

Please sign in to comment.