Skip to content

Commit

Permalink
Drop jinja2 in trained model serialization tests and use typed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
micmn committed Jun 4, 2018
1 parent ee3da84 commit 1961292
Show file tree
Hide file tree
Showing 4 changed files with 255 additions and 238 deletions.
43 changes: 21 additions & 22 deletions tests/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,32 +59,31 @@ if(JINJA2_IMPORT_SUCCESS)
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMENT "Generating DynamicObjectArray_unittest_generated.cc")
LIST(APPEND TEMPLATE_GENERATED_UNITTEST DynamicObjectArray_unittest_generated.cc)

IF(NOT CTAGS_FOUND)
MESSAGE("Please install Ctags for trained models serialization tests.")
ELSEIF(NOT HAVE_HDF5)
MESSAGE("Please install HDF5 for trained models serialization tests.")
ELSE()
ADD_CUSTOM_COMMAND(OUTPUT trained_model_serialization_unittest.cc
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.py
${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.jinja2
${CTAGS_FILE}
trained_model_serialization_unittest.cc
${CMAKE_BINARY_DIR}/src/shogun/lib/config.h
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.py
${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.jinja2
ctags
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMENT "Generating trained_model_serialization_unittest.cc")
LIST(APPEND SERIALIZATION_UNITTEST trained_model_serialization_unittest.cc)
ENDIF()

LIST(APPEND SERIALIZATION_UNITTEST base/main_unittest.cc)
add_unit_test_executable(shogun-serialization-unit-test serialization-unit-tests "${SERIALIZATION_UNITTEST}")
ELSE()
MESSAGE(WARNING "Please install jinja2 for automatic generated tests.")
ENDIF()

# Trained-model serialization tests.
# They are only wired in when both a ctags index (CTAGS_FOUND) and HDF5
# (HAVE_HDF5) are available; otherwise a message is printed and neither the
# test source nor the generated header is added to SERIALIZATION_UNITTEST.
IF(NOT CTAGS_FOUND)
MESSAGE("Please install Ctags for trained models serialization tests.")
ELSEIF(NOT HAVE_HDF5)
MESSAGE("Please install HDF5 for trained models serialization tests.")
ELSE()
# Generate trained_model_serialization_unittest.h in the build directory by
# running the Python script with three arguments: the ctags file, the output
# header path and the generated config.h. The command is re-run when the
# script or the `ctags` dependency changes.
ADD_CUSTOM_COMMAND(OUTPUT trained_model_serialization_unittest.h
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.py
${CTAGS_FILE}
${CMAKE_CURRENT_BINARY_DIR}/trained_model_serialization_unittest.h
${CMAKE_BINARY_DIR}/src/shogun/lib/config.h
DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc.py
ctags
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}
COMMENT "Generating trained_model_serialization_unittest.h")
# The hand-written test source lives in the source tree; the generated header
# is listed as a source as well so that the custom command above is attached
# to the test executable.
LIST(APPEND SERIALIZATION_UNITTEST ${CMAKE_CURRENT_SOURCE_DIR}/base/trained_model_serialization_unittest.cc)
LIST(APPEND SERIALIZATION_UNITTEST ${CMAKE_CURRENT_BINARY_DIR}/trained_model_serialization_unittest.h)
ENDIF()

# The serialization test executable is always declared, even when the list
# only contains the shared test main.
# NOTE(review): add_unit_test_executable is defined elsewhere; presumably it
# takes (target name, test label, source list) - confirm.
LIST(APPEND SERIALIZATION_UNITTEST base/main_unittest.cc)
add_unit_test_executable(shogun-serialization-unit-test serialization-unit-tests "${SERIALIZATION_UNITTEST}")

# Helper executable built from discover_gtest_tests.cpp. Its binary is placed
# in ${CMAKE_BINARY_DIR}/bin, both for the default output directory and for
# the Debug configuration.
add_executable (discover_gtest_tests ${CMAKE_CURRENT_SOURCE_DIR}/discover_gtest_tests.cpp)
set_target_properties (discover_gtest_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
set_target_properties (discover_gtest_tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY_DEBUG ${CMAKE_BINARY_DIR}/bin)
Expand Down
197 changes: 197 additions & 0 deletions tests/unit/base/trained_model_serialization_unittest.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#include "environments/LinearTestEnvironment.h"
#include "environments/MultiLabelTestEnvironment.h"
#include "environments/RegressionTestEnvironment.h"
#include "utils/Utils.h"
#include <gtest/gtest.h>
#include <shogun/base/some.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/io/CSVFile.h>
#include <shogun/io/SGIO.h>
#include <shogun/io/SerializableAsciiFile.h>
#include <shogun/io/SerializableHdf5File.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/RegressionLabels.h>
#include <shogun/machine/Machine.h>

using namespace shogun;

extern LinearTestEnvironment* linear_test_env;
extern MultiLabelTestEnvironment* multilabel_test_env;
extern RegressionTestEnvironment* regression_test_env;

/**
 * Typed-test fixture shared by the trained-model serialization tests.
 *
 * For the machine type T under test it creates two instances (one to train
 * and serialize, one to deserialize into), loads the train/test data that
 * matches the machine's problem type from the global test environments, and
 * offers helpers to round-trip a machine through an HDF5 file.
 *
 * @tparam T machine class under test; must be default-constructible and
 *           provide get_machine_problem_type()
 */
template <class T>
class TrainedModelSerializationFixture : public ::testing::Test
{
protected:
	/** Creates both machine instances and loads the matching data set. */
	void SetUp() override
	{
		machine = new T();
		SG_REF(machine)

		deserialized_machine = new T();
		SG_REF(deserialized_machine)

		this->load_data(this->machine->get_machine_problem_type());
	}

	/**
	 * Drops the references taken in SetUp() and load_data().
	 *
	 * All members start out as nullptr, so this stays well-defined even
	 * when load_data() left early without assigning the data pointers.
	 */
	void TearDown() override
	{
		SG_UNREF(train_feats)
		SG_UNREF(test_feats)
		SG_UNREF(train_labels)
		SG_UNREF(machine)
		SG_UNREF(deserialized_machine)
	}

	/**
	 * Fetches train/test features and train labels for the given problem
	 * type from the corresponding global test environment and takes a
	 * reference on each of them.
	 *
	 * @param pt problem type reported by the machine under test; anything
	 *           other than binary/class, multiclass or regression is
	 *           reported as an error and fails the test
	 */
	void load_data(EProblemType pt)
	{
		switch (pt)
		{
		case PT_BINARY:
		case PT_CLASS:
		{
			std::shared_ptr<GaussianCheckerboard> mock_data =
			    linear_test_env->getBinaryLabelData();
			train_feats = mock_data->get_features_train();
			test_feats = mock_data->get_features_test();
			train_labels = mock_data->get_labels_train();
			break;
		}

		case PT_MULTICLASS:
		{
			std::shared_ptr<GaussianCheckerboard> mock_data =
			    multilabel_test_env->getMulticlassFixture();
			train_feats = mock_data->get_features_train();
			test_feats = mock_data->get_features_test();
			train_labels = mock_data->get_labels_train();
			break;
		}

		case PT_REGRESSION:
			train_feats = regression_test_env->get_features_train();
			test_feats = regression_test_env->get_features_test();
			train_labels = regression_test_env->get_labels_train();
			break;

		default:
			SG_SERROR("Unsupported problem type: %d\n", pt);
			// FAIL() returns from this function; the data pointers then
			// keep their nullptr defaults.
			FAIL();
		}

		SG_REF(train_feats)
		SG_REF(test_feats)
		SG_REF(train_labels)
	}

	/**
	 * Serializes a machine into a freshly named temporary HDF5 file.
	 *
	 * @param cmachine machine to serialize
	 * @param filename [out] receives the name of the file that was written
	 * @param store_model_features forwarded to
	 *        CMachine::set_store_model_features() before saving
	 * @return result of save_serializable()
	 */
	bool serialize_machine(
	    CMachine* cmachine, std::string& filename,
	    bool store_model_features = false)
	{
		std::string class_name = cmachine->get_name();
		filename = "shogun-unittest-trained-model-serialization-" +
		           class_name + ".XXXXXX";
		// The helper needs a writable char buffer (presumably it replaces
		// the "XXXXXX" suffix in place - confirm in utils/Utils.h). Pass the
		// string's own mutable, null-terminated storage instead of casting
		// away the constness of c_str(), which must not be written through.
		generate_temp_filename(&filename[0]);

		CSerializableHdf5File* file =
		    new CSerializableHdf5File(filename.c_str(), 'w');
		SG_REF(file)
		cmachine->set_store_model_features(store_model_features);
		bool save_success = cmachine->save_serializable(file);
		file->close();
		// The file object was created with new; release it through the
		// reference count so that its destructor runs.
		SG_UNREF(file)

		return save_success;
	}

	/**
	 * Loads a machine from the given HDF5 file and deletes the file
	 * afterwards.
	 *
	 * @param cmachine machine to deserialize into
	 * @param filename file previously written by serialize_machine()
	 * @return true if loading succeeded and the file could be removed
	 */
	bool deserialize_machine(CMachine* cmachine, std::string filename)
	{
		CSerializableHdf5File* file =
		    new CSerializableHdf5File(filename.c_str(), 'r');
		SG_REF(file)
		bool load_success = cmachine->load_serializable(file);

		file->close();
		SG_UNREF(file)
		// The temporary file is removed regardless of the load result.
		int delete_success = unlink(filename.c_str());

		return load_success && (delete_success == 0);
	}

	/** Training and test features of the active data set. */
	CDenseFeatures<float64_t>* train_feats = nullptr;
	CDenseFeatures<float64_t>* test_feats = nullptr;
	/** Training labels of the active data set. */
	CLabels* train_labels = nullptr;
	/** Instance that gets trained and serialized. */
	T* machine = nullptr;
	/** Instance that the serialized model is loaded into. */
	T* deserialized_machine = nullptr;
};

#include "trained_model_serialization_unittest.h"

template <class T>
class TrainedMachineSerialization : public TrainedModelSerializationFixture<T>
{
};

TYPED_TEST_CASE(TrainedMachineSerialization, MachineTypes);

TYPED_TEST(TrainedMachineSerialization, Test)
{
this->machine->set_labels(this->train_labels);
this->machine->train(this->train_feats);

/* to avoid serialization of the data */
// machine->set_features(NULL);
// machine->set_labels(NULL);

auto predictions = wrap<CLabels>(this->machine->apply(this->test_feats));

std::string filename;
ASSERT_TRUE(this->serialize_machine(this->machine, filename));

ASSERT_TRUE(
this->deserialize_machine(this->deserialized_machine, filename));

auto deserialized_predictions =
wrap<CLabels>(this->deserialized_machine->apply(this->test_feats));

set_global_fequals_epsilon(1e-7);
ASSERT(predictions->equals(deserialized_predictions))
set_global_fequals_epsilon(0);
}

template <class T>
class TrainedKernelMachineSerialization
: public TrainedModelSerializationFixture<T>
{
};

TYPED_TEST_CASE(TrainedKernelMachineSerialization, KernelMachineTypes);

TYPED_TEST(TrainedKernelMachineSerialization, Test)
{
CGaussianKernel* kernel = new CGaussianKernel(2.0);
this->machine->set_kernel(kernel);
this->machine->set_labels(this->train_labels);

this->machine->train(this->train_feats);

auto predictions = wrap<CLabels>(this->machine->apply(this->test_feats));

for (auto store_model_features : {false, true})
{
std::string filename;
ASSERT_TRUE(
this->serialize_machine(
this->machine, filename, store_model_features));

ASSERT_TRUE(
this->deserialize_machine(this->deserialized_machine, filename));

auto deserialized_predictions =
wrap<CLabels>(this->deserialized_machine->apply(this->test_feats));

// allow for lossy serialization format
set_global_fequals_epsilon(1e-6);
ASSERT(predictions->equals(deserialized_predictions))
set_global_fequals_epsilon(0);
}
}

0 comments on commit 1961292

Please sign in to comment.