Skip to content

Commit

Permalink
Update SplittingStrategy tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
Saurabh7 committed Dec 27, 2013
1 parent 88e2f02 commit b40ff33
Showing 1 changed file with 104 additions and 2 deletions.
Expand Up @@ -12,11 +12,13 @@
#include <shogun/labels/RegressionLabels.h>
#include <shogun/evaluation/StratifiedCrossValidationSplitting.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/evaluation/LOOCrossValidationSplitting.h>
#include <gtest/gtest.h>


using namespace shogun;

TEST(Crossvalidation,standard)
TEST(SplittingStrategy,standard)
{
index_t fold_sizes;
index_t num_labels;
Expand Down Expand Up @@ -84,7 +86,7 @@ TEST(Crossvalidation,standard)

}

TEST(Crossvalidation,stratified)
TEST(SplittingStrategy,stratified_subsets_disjoint_cover)
{
index_t num_labels, num_classes, num_subsets, fold_sizes;
index_t runs=50;
Expand Down Expand Up @@ -158,6 +160,48 @@ TEST(Crossvalidation,stratified)

EXPECT_EQ(flag,0);

/* clean up */
SG_UNREF(splitting);
}
}

TEST(SplittingStrategy,stratified_subset_label_ratio)
{
index_t num_labels, num_classes, num_subsets, fold_sizes;
index_t runs=50;

while (runs-->0)
{
num_labels=CMath::random(11, 100);
num_classes=CMath::random(2, 10);
num_subsets=CMath::random(1, 10);

/* build labels */
CMulticlassLabels* labels=new CMulticlassLabels(num_labels);
for (index_t i=0; i<num_labels; ++i)
labels->set_label(i, CMath::random()%num_classes);

/*No. of labels belonging to one class*/
SGVector<index_t> class_labels(num_classes);
SGVector<index_t>::fill_vector(class_labels.vector, class_labels.vlen, 0);

/*check total no. of class labels*/
for (index_t i=0; i<num_classes; ++i)
{
for(index_t j=0; j<num_labels; ++j)
{
if ((int32_t)labels->get_label(j)==i)
++class_labels.vector[i];
}
}


/* build splitting strategy */
CStratifiedCrossValidationSplitting* splitting=
new CStratifiedCrossValidationSplitting(labels, num_subsets);

splitting->build_subsets();

/* check whether number of labels in every subset is nearly equal */
for (index_t i=0; i<num_classes; ++i)
{
Expand Down Expand Up @@ -188,6 +232,64 @@ TEST(Crossvalidation,stratified)
}
EXPECT_EQ(total_count,class_labels.vector[i]);
}
}
}


TEST(SplittingStrategy,LOO)
{
index_t num_labels, fold_sizes;
index_t runs=10;

while (runs-->0)
{
fold_sizes=0;
num_labels=CMath::random(10, 50);

/* build labels */
CRegressionLabels* labels=new CRegressionLabels(num_labels);
for (index_t i=0; i<num_labels; ++i)
labels->set_label(i, CMath::random(-10.0, 10.0));

/* build Leave one out splitting strategy */
CLOOCrossValidationSplitting* splitting=
new CLOOCrossValidationSplitting(labels);

splitting->build_subsets();

SGVector<index_t> total(num_labels);
SGVector<index_t>::fill_vector(total.vector, total.vlen,(index_t)-1);

for (index_t i=0; i<num_labels; ++i)
{
SGVector<index_t> subset=splitting->generate_subset_indices(i);
SGVector<index_t> inverse=splitting->generate_subset_inverse(i);

for(index_t j=0;j<subset.vlen;++j)
{
/*check if fold indices are disjoint*/
SGVector<index_t> temp=total.find((index_t)subset.vector[j]);
EXPECT_EQ(temp.vlen,0);

total.vector[j+fold_sizes]=subset.vector[j];
}

EXPECT_EQ(subset.vlen+inverse.vlen, num_labels);
fold_sizes+=subset.vlen;
}

EXPECT_EQ(fold_sizes, num_labels);

index_t flag=0;
/*check if indices in all folds cover available indices*/
for (index_t i=0;i<num_labels;++i)
{
SGVector<index_t> temp=total.find((index_t)i);
if(temp.vlen == 0)
flag = 1;
}

EXPECT_EQ(flag,0);

/* clean up */
SG_UNREF(splitting);
Expand Down

0 comments on commit b40ff33

Please sign in to comment.