Skip to content

Commit

Permalink
Add serialization test for classification machine
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed May 10, 2013
1 parent db8e225 commit b1c6bb0
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions tests/unit/base/Serialization_unittest.cc
Expand Up @@ -8,8 +8,12 @@
*/

#include <shogun/base/init.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/classifier/svm/LibLinear.h>
#include <shogun/features/DataGenerator.h>
#include <shogun/features/DenseFeatures.h>
#include <gtest/gtest.h>

using namespace shogun;
Expand Down Expand Up @@ -68,4 +72,63 @@ TEST(Serialization,multiclass_labels)
SG_UNREF(labels);
}

TEST(Serialization, liblinear)
{
index_t num_samples = 50;
CMath::init_random(13);
SGMatrix<float64_t> data =
CDataGenerator::generate_gaussians(num_samples, 2, 2);
CDenseFeatures<float64_t> features(data);

SGVector<index_t> train_idx(num_samples), test_idx(num_samples);
SGVector<float64_t> labels(num_samples);
for (index_t i = 0, j = 0; i < data.num_cols; ++i)
{
if (i % 2 == 0)
train_idx[j] = i;
else
test_idx[j++] = i;

labels[i/2] = (i < data.num_cols/2) ? 1.0 : -1.0;
}

CDenseFeatures<float64_t>* train_feats = (CDenseFeatures<float64_t>*)features.copy_subset(train_idx);
CDenseFeatures<float64_t>* test_feats = (CDenseFeatures<float64_t>*)features.copy_subset(test_idx);

CBinaryLabels* ground_truth = new CBinaryLabels(labels);

CLibLinear* liblin = new CLibLinear(1.0, train_feats, ground_truth);
liblin->set_epsilon(1e-5);
liblin->train();

CBinaryLabels* pred = CBinaryLabels::obtain_from_generic(liblin->apply(test_feats));
for (int i = 0; i < num_samples; ++i)
EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i));

/* save liblin */
const char* filename="trained_liblin.txt";
CSerializableAsciiFile* file=new CSerializableAsciiFile(filename, 'w');
liblin->save_serializable(file);
file->close();
SG_UNREF(file);

/* load liblin */
file=new CSerializableAsciiFile(filename, 'r');
CLibLinear* liblin_loaded=new CLibLinear();
liblin_loaded->load_serializable(file);
file->close();
SG_UNREF(file);

/* classify with the deserialized model */
pred = CBinaryLabels::obtain_from_generic(liblin_loaded->apply(test_feats));
for (int i = 0; i < num_samples; ++i)
EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i));

SG_UNREF(liblin_loaded);
SG_UNREF(liblin);
SG_UNREF(train_feats);
SG_UNREF(test_feats);
SG_UNREF(pred);
}


0 comments on commit b1c6bb0

Please sign in to comment.