From 87ca78d072e4bb71fed8913b3c77de4e94365c81 Mon Sep 17 00:00:00 2001 From: Parijat Mazumdar Date: Tue, 11 Mar 2014 17:31:03 +0530 Subject: [PATCH] entropy method updated --- .../multiclass/tree/ID3ClassifierTree.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/shogun/multiclass/tree/ID3ClassifierTree.cpp b/src/shogun/multiclass/tree/ID3ClassifierTree.cpp index 6b3742896f1..02f4d63964c 100644 --- a/src/shogun/multiclass/tree/ID3ClassifierTree.cpp +++ b/src/shogun/multiclass/tree/ID3ClassifierTree.cpp @@ -32,8 +32,7 @@ #include #include #include -#include -using namespace std; +#include using namespace shogun; @@ -99,7 +98,7 @@ CMulticlassLabels* CID3ClassifierTree::apply_multiclass(CFeatures* data) bool CID3ClassifierTree::train_machine(CFeatures* data) { - REQUIRE(data,"data required for training") + REQUIRE(data,"Data required for training") REQUIRE(data->get_feature_class()==C_DENSE, "Dense data required for training") int32_t num_features = ((CDenseFeatures*) data)->get_num_features(); @@ -298,22 +297,24 @@ float64_t CID3ClassifierTree::informational_gain_attribute(int32_t attr_no, CFea float64_t CID3ClassifierTree::entropy(CMulticlassLabels* labels) { - float64_t entr = 0; + SGVector log_ratios = SGVector + (labels->get_unique_labels().size()); for(int32_t i=0;iget_unique_labels().size();i++) { int32_t count = 0; + for(int32_t j=0;jget_num_labels();j++) { if(labels->get_unique_labels()[i] == labels->get_label(j)) count++; } - float64_t ratio = (count-0.f)/(labels->get_num_labels()-0.f); + log_ratios[i] = (count-0.f)/(labels->get_num_labels()-0.f); - if(ratio != 0) - entr -= ratio*(CMath::log2(ratio)); + if(log_ratios[i] != 0) + log_ratios[i] = CMath::log(log_ratios[i]); } - return entr; + return CStatistics::entropy(log_ratios.vector, log_ratios.vlen); }