diff --git a/tests/unit/base/Serialization_unittest.cc b/tests/unit/base/Serialization_unittest.cc index ee428c95fdd..b4fd84e0e47 100644 --- a/tests/unit/base/Serialization_unittest.cc +++ b/tests/unit/base/Serialization_unittest.cc @@ -8,8 +8,12 @@ */ #include +#include #include #include +#include +#include +#include #include using namespace shogun; @@ -68,4 +72,63 @@ TEST(Serialization,multiclass_labels) SG_UNREF(labels); } +TEST(Serialization, liblinear) +{ + index_t num_samples = 50; + CMath::init_random(13); + SGMatrix data = + CDataGenerator::generate_gaussians(num_samples, 2, 2); + CDenseFeatures features(data); + + SGVector train_idx(num_samples), test_idx(num_samples); + SGVector labels(num_samples); + for (index_t i = 0, j = 0; i < data.num_cols; ++i) + { + if (i % 2 == 0) + train_idx[j] = i; + else + test_idx[j++] = i; + + labels[i/2] = (i < data.num_cols/2) ? 1.0 : -1.0; + } + + CDenseFeatures* train_feats = (CDenseFeatures*)features.copy_subset(train_idx); + CDenseFeatures* test_feats = (CDenseFeatures*)features.copy_subset(test_idx); + + CBinaryLabels* ground_truth = new CBinaryLabels(labels); + + CLibLinear* liblin = new CLibLinear(1.0, train_feats, ground_truth); + liblin->set_epsilon(1e-5); + liblin->train(); + + CBinaryLabels* pred = CBinaryLabels::obtain_from_generic(liblin->apply(test_feats)); + for (int i = 0; i < num_samples; ++i) + EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i)); + + /* save liblin */ + const char* filename="trained_liblin.txt"; + CSerializableAsciiFile* file=new CSerializableAsciiFile(filename, 'w'); + liblin->save_serializable(file); + file->close(); + SG_UNREF(file); + + /* load liblin */ + file=new CSerializableAsciiFile(filename, 'r'); + CLibLinear* liblin_loaded=new CLibLinear(); + liblin_loaded->load_serializable(file); + file->close(); + SG_UNREF(file); + + /* classify with the deserialized model */ + pred = CBinaryLabels::obtain_from_generic(liblin_loaded->apply(test_feats)); + for (int i = 0; i < num_samples; ++i) + EXPECT_EQ(ground_truth->get_int_label(i), pred->get_int_label(i)); + + SG_UNREF(liblin_loaded); + SG_UNREF(liblin); + SG_UNREF(train_feats); + SG_UNREF(test_feats); + SG_UNREF(pred); +} +