Skip to content

Commit

Permalink
Added id3 tree pruning
Browse files Browse the repository at this point in the history
  • Loading branch information
mazumdarparijat committed Apr 11, 2014
1 parent 59e9f13 commit 390f6a5
Show file tree
Hide file tree
Showing 3 changed files with 301 additions and 75 deletions.
229 changes: 155 additions & 74 deletions src/shogun/multiclass/tree/ID3ClassifierTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/multiclass/tree/ID3ClassifierTree.h>
#include <shogun/mathematics/Statistics.h>
#include <shogun/evaluation/MulticlassAccuracy.h>

using namespace shogun;

Expand All @@ -49,51 +50,20 @@ CMulticlassLabels* CID3ClassifierTree::apply_multiclass(CFeatures* data)
{
REQUIRE(data, "Data required for classification in apply_multiclass\n")

CDenseFeatures<float64_t>* feats = dynamic_cast<CDenseFeatures<float64_t>*>(data);
int32_t num_vecs = feats->get_num_vectors();
SGVector<float64_t> labels = SGVector<float64_t>(num_vecs);

for (int32_t i=0; i<num_vecs; i++)
{
SGVector<float64_t> sample = feats->get_feature_vector(i);
node_t* node = get_root();
CDynamicObjectArray* children = node->get_children();

while (children->get_num_elements())
{
int32_t flag = 0;
for (int32_t j=0; j<children->get_num_elements(); j++)
{
node_t* child = dynamic_cast<node_t*>(children->get_element(j));
if (child->data.transit_if_feature_value
== sample[node->data.attribute_id])
{
flag = 1;
node_t* current = get_root();
CMulticlassLabels* ret = apply_multiclass_from_current_node((CDenseFeatures<float64_t>*) data, current);

SG_UNREF(node);
node = child;

SG_UNREF(children);
children = node->get_children();

break;
}

SG_UNREF(child);
}
SG_UNREF(current);
return ret;
}

if (!flag)
break;
}

labels[i] = node->data.class_label;
bool CID3ClassifierTree::prune_tree(CDenseFeatures<float64_t>* validation_data, CMulticlassLabels* validation_labels)
{
node_t* current = get_root();
prune_tree_machine(validation_data, validation_labels, current);

SG_UNREF(node);
SG_UNREF(children);
}

CMulticlassLabels* ret = new CMulticlassLabels(labels);
return ret;
SG_UNREF(current);
return true;
}

bool CID3ClassifierTree::train_machine(CFeatures* data)
Expand All @@ -117,44 +87,40 @@ CTreeMachineNode<id3TreeNodeData>* CID3ClassifierTree::id3train(CFeatures* data,
CDenseFeatures<float64_t>* feats = dynamic_cast<CDenseFeatures<float64_t>*>(data);
int32_t num_vecs = feats->get_num_vectors();

// if all samples belong to the same class
if (class_labels->get_unique_labels().size() == 1)
{
node->data.class_label = class_labels->get_unique_labels()[0];
return node;
}

// if no feature is left
if (feature_id_vector.vlen == 0)
// set class_label for the node as the mode of occurring multiclass labels
SGVector<float64_t> labels = class_labels->get_labels_copy();
labels.qsort();

int32_t most_label = labels[0];
int32_t most_num = 1;
int32_t count = 1;

for (int32_t i=1; i<labels.vlen; i++)
{
// decide label - label occurring max times
SGVector<float64_t> labels = class_labels->get_labels();
labels.qsort();

int32_t most_label = labels[0];
int32_t most_num = 1;
int32_t count = 1;

for (int32_t i=1; i<labels.vlen; i++)
while ((labels[i] == labels[i-1]) && (i<labels.vlen))
{
while ((labels[i] == labels[i-1]) && (i<labels.vlen))
{
count++;
i++;
}
count++;
i++;
}

if (count>most_num)
{
most_num = count;
most_label = labels[i-1];
}

count = 1;
}

if (count>most_num)
{
most_num = count;
most_label = labels[i-1];
}
node->data.class_label = most_label;

count = 1;
}
// if all samples belong to the same class
if (most_num == labels.vlen)
return node;

node->data.class_label = most_label;
// if no feature is left
if (feature_id_vector.vlen == 0)
return node;
}

// else get the feature with the highest informational gain
float64_t max = 0;
Expand Down Expand Up @@ -316,3 +282,118 @@ float64_t CID3ClassifierTree::entropy(CMulticlassLabels* labels)

return CStatistics::entropy(log_ratios.vector, log_ratios.vlen);
}

/** Recursive worker for tree pruning.
 *
 * Every child subtree of current is pruned first, using only those
 * vectors of feats whose value at current's attribute_id equals the child's
 * transit_if_feature_value. Afterwards the accuracy of the (already pruned)
 * subtree on feats is compared with the accuracy obtained by answering
 * current's class_label for every vector; if the constant answer is strictly
 * more accurate, current's children are replaced by a freshly constructed
 * CDynamicObjectArray.
 *
 * @param feats feature set to use for pruning
 * @param gnd_truth ground truth labels, indexed like the vectors of feats
 * @param current root of the subtree being pruned
 */
void CID3ClassifierTree::prune_tree_machine(CDenseFeatures<float64_t>* feats, CMulticlassLabels* gnd_truth, node_t* current)
{
	SGMatrix<float64_t> all_vectors = feats->get_feature_matrix();
	CDynamicObjectArray* branches = current->get_children();

	for (int32_t b=0; b<branches->get_num_elements(); b++)
	{
		node_t* branch = dynamic_cast<node_t*>(branches->get_element(b));

		// pass 1: number of vectors (matrix columns) routed into this branch
		int32_t num_routed = 0;
		for (int32_t col=0; col<all_vectors.num_cols; col++)
		{
			if (all_vectors(current->data.attribute_id,col) == branch->data.transit_if_feature_value)
				num_routed++;
		}

		// pass 2: copy the routed columns and their labels into compact storage
		SGMatrix<float64_t> subset_matrix(all_vectors.num_rows, num_routed);
		SGVector<float64_t> subset_truth(num_routed);
		int32_t filled = 0;

		for (int32_t col=0; col<all_vectors.num_cols; col++)
		{
			if (all_vectors(current->data.attribute_id,col) != branch->data.transit_if_feature_value)
				continue;

			// one vector occupies num_rows consecutive entries of the matrix buffer
			memcpy(subset_matrix.matrix+filled*all_vectors.num_rows,
				all_vectors.matrix+col*all_vectors.num_rows,
				all_vectors.num_rows*sizeof(float64_t));

			subset_truth[filled] = gnd_truth->get_label(col);
			filled++;
		}

		CDenseFeatures<float64_t>* subset_feats = new CDenseFeatures<float64_t>(subset_matrix);
		CMulticlassLabels* subset_labels = new CMulticlassLabels(subset_truth);

		// bottom-up: the branch is pruned before current itself is judged
		prune_tree_machine(subset_feats, subset_labels, branch);

		SG_UNREF(subset_feats);
		SG_UNREF(subset_labels);
		SG_UNREF(branch);
	}

	SG_UNREF(branches);

	// predictions of the subtree as it stands after pruning the branches
	CMulticlassLabels* subtree_prediction = apply_multiclass_from_current_node(feats, current);

	// predictions if current answered its own class_label for every vector
	const int32_t num_vecs = feats->get_num_vectors();
	SGVector<float64_t> leaf_labels(num_vecs);
	for (int32_t v=0; v<num_vecs; v++)
		leaf_labels[v] = current->data.class_label;

	CMulticlassLabels* leaf_prediction = new CMulticlassLabels(leaf_labels);

	CMulticlassAccuracy* scorer = new CMulticlassAccuracy();
	const float64_t subtree_accuracy = scorer->evaluate(subtree_prediction, gnd_truth);
	const float64_t leaf_accuracy = scorer->evaluate(leaf_prediction, gnd_truth);

	// prune only on a strict improvement; ties keep the subtree
	if (leaf_accuracy>subtree_accuracy)
		current->set_children(new CDynamicObjectArray());

	SG_UNREF(scorer);
	SG_UNREF(leaf_prediction);
	SG_UNREF(subtree_prediction);
}

/** Classifies data using the subtree rooted at current.
 *
 * For each vector the walk starts at current and repeatedly moves to the
 * first child whose transit_if_feature_value equals the vector's value at
 * the present node's attribute_id. The walk ends when the present node has
 * no children or none of them matches; the class_label of the node reached
 * is the prediction for that vector.
 *
 * @param feats data to be classified
 * @param current root of the subtree used for classification
 * @return newly allocated labels, one per vector of feats
 */
CMulticlassLabels* CID3ClassifierTree::apply_multiclass_from_current_node(CDenseFeatures<float64_t>* feats, node_t* current)
{
	const int32_t total = feats->get_num_vectors();
	SGVector<float64_t> predictions(total);

	for (int32_t v=0; v<total; v++)
	{
		SGVector<float64_t> vec = feats->get_feature_vector(v);

		// take a reference on the start node; it is dropped again when the
		// walk moves on to a child or when this vector is finished
		node_t* position = current;
		SG_REF(position);
		CDynamicObjectArray* branches = position->get_children();

		// keep descending as long as the last step found a matching child
		bool descended = true;
		while (descended && branches->get_num_elements())
		{
			descended = false;

			for (int32_t b=0; b<branches->get_num_elements(); b++)
			{
				node_t* candidate = dynamic_cast<node_t*>(branches->get_element(b));

				if (vec[position->data.attribute_id] != candidate->data.transit_if_feature_value)
				{
					SG_UNREF(candidate);
					continue;
				}

				// step down: the candidate becomes the new position
				descended = true;

				SG_UNREF(position);
				position = candidate;

				SG_UNREF(branches);
				branches = position->get_children();

				break;
			}
		}

		predictions[v] = position->data.class_label;

		SG_UNREF(position);
		SG_UNREF(branches);
	}

	return new CMulticlassLabels(predictions);
}
25 changes: 24 additions & 1 deletion src/shogun/multiclass/tree/ID3ClassifierTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,14 @@ class CID3ClassifierTree : public CTreeMachine<id3TreeNodeData>
*/
virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);

/** prune id3 decision tree
* @param validation_data feature vectors from validation dataset
* @param validation_labels multiclass labels from validation dataset
*
* @return true if pruning successful
*/
virtual bool prune_tree(CDenseFeatures<float64_t>* validation_data, CMulticlassLabels* validation_labels);

protected:

/** train machine - build ID3 Tree from training data
Expand Down Expand Up @@ -121,7 +129,22 @@ class CID3ClassifierTree : public CTreeMachine<id3TreeNodeData>
* @return entropy
*/
float64_t entropy(CMulticlassLabels* labels);


/** recursive tree pruning method - called within prune_tree method
*
* @param feats feature set to use for pruning
* @param gnd_truth ground truth labels
* @param current root of current subtree
*/
void prune_tree_machine(CDenseFeatures<float64_t>* feats, CMulticlassLabels* gnd_truth, node_t* current);

/** uses current subtree to classify data
*
* @param feats data to be classified
* @param current root of current subtree
* @return classification labels of input data
*/
CMulticlassLabels* apply_multiclass_from_current_node(CDenseFeatures<float64_t>* feats, node_t* current);
};
} /* namespace shogun */

Expand Down
Loading

0 comments on commit 390f6a5

Please sign in to comment.