diff --git a/CMakeLists.txt b/CMakeLists.txt index c6f209fbe90..14e70d8b01e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -392,6 +392,14 @@ IF (GDB_FOUND) SET(GDB_DEFAULT_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/.gdb) ENDIF() +FIND_PACKAGE(Cereal) +IF(NOT CEREAL_FOUND) + include(external/Cereal) + LIST(APPEND INCLUDES ${CEREAL_INCLUDE_DIRS}) +ELSE() + LIST(APPEND INCLUDES ${CEREAL_INCLUDE_DIRS}) +ENDIF() + FIND_PACKAGE(Doxygen 1.8.6) IF(DOXYGEN_FOUND) SET(HAVE_DOXYGEN 1) diff --git a/cmake/FindCereal.cmake b/cmake/FindCereal.cmake new file mode 100644 index 00000000000..5ed51bd3e35 --- /dev/null +++ b/cmake/FindCereal.cmake @@ -0,0 +1,19 @@ +# - Try to find Cereal Serialization Library +# +# This sets the following variables: +# CEREAL_FOUND - True if Cereal was found. +# CEREAL_INCLUDE_DIRS - Directories containing the Cereal include files. + +find_path(CEREAL_INCLUDE_DIR cereal + HINTS "${CMAKE_SOURCE_DIR}/include" "/usr/include" "${CMAKE_BINARY_DIR}/cereal/include") + +set(CEREAL_INCLUDE_DIRS ${CEREAL_INCLUDE_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Cereal DEFAULT_MSG CEREAL_INCLUDE_DIR) + +mark_as_advanced(CEREAL_INCLUDE_DIR) + +if(CEREAL_FOUND) + MESSAGE(STATUS "Found Cereal: ${CEREAL_INCLUDE_DIRS}") +endif(CEREAL_FOUND) diff --git a/cmake/external/Cereal.cmake b/cmake/external/Cereal.cmake new file mode 100644 index 00000000000..d8e3603b1e9 --- /dev/null +++ b/cmake/external/Cereal.cmake @@ -0,0 +1,14 @@ +include(ExternalProject) +ExternalProject_Add( + Cereal + PREFIX ${CMAKE_BINARY_DIR}/Cereal + DOWNLOAD_DIR ${THIRD_PARTY_DIR}/Cereal + URL https://github.com/USCiLab/cereal/archive/v1.2.0.tar.gz + URL_MD5 e372c9814696481dbdb7d500e1410d2b + CMAKE_ARGS -DCMAKE_C_FLAGS:STRING=${CMAKE_C_FLAGS}${CMAKE_DEFINITIONS} + -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}${CMAKE_DEFINITIONS} + INSTALL_COMMAND "" + ) + +SET(CEREAL_INCLUDE_DIRS ${CMAKE_BINARY_DIR}/Cereal/src/Cereal/include) +LIST(APPEND SHOGUN_DEPENDS Cereal) 
diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index e363bd70240..b60a6125358 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -50,6 +50,16 @@ namespace shogun return m_gradient; } + /** serialize the object using cereal + * + * @param ar Archive type + */ + template + void serialize(Archive& ar) + { + ar(m_model_selection, m_gradient); + } + private: EModelSelectionAvailability m_model_selection; EGradientAvailability m_gradient; @@ -88,6 +98,16 @@ namespace shogun return m_properties; } + /** serialize the object using cereal + * + * @param ar Archive type + */ + template + void serialize(Archive& ar) + { + ar(m_value, m_properties); + } + private: Any m_value; AnyParameterProperties m_properties; diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index 8418c2f508b..b8a9db3e927 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -32,6 +32,14 @@ #include #include +#include +#include +#include +#include +#include +#include +#include + #include namespace shogun @@ -229,6 +237,74 @@ int32_t CSGObject::unref() } } +void CSGObject::save_binary(const char* filename) const +{ + std::ofstream os(filename, std::ios::binary); /* binary mode: archive bytes must not go through text-mode newline translation */ + cereal::BinaryOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::save_json(const char* filename) const +{ + std::ofstream os(filename); + cereal::JSONOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::save_xml(const char* filename) const +{ + std::ofstream os(filename); + cereal::XMLOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::load_binary(const char* filename) +{ + std::ifstream is(filename, std::ios::binary); /* must match the binary mode used by save_binary */ + cereal::BinaryInputArchive archive(is); + archive(*this); +} + +void CSGObject::load_json(const char* filename) +{ + std::ifstream is(filename); + cereal::JSONInputArchive archive(is); + archive(*this); +} 
+ +void CSGObject::load_xml(const char* filename) +{ + std::ifstream is(filename); + cereal::XMLInputArchive archive(is); + archive(*this); +} + +template +void CSGObject::cereal_save(Archive& ar) const +{ + for (const auto& it : self->map) + ar(cereal::make_nvp(it.first.name(), it.second)); +} +template void CSGObject::cereal_save( + cereal::BinaryOutputArchive& ar) const; +template void CSGObject::cereal_save( + cereal::JSONOutputArchive& ar) const; +template void CSGObject::cereal_save( + cereal::XMLOutputArchive& ar) const; + +template +void CSGObject::cereal_load(Archive& ar) +{ + for (auto& it : self->map) + ar(it.second); +} +template void CSGObject::cereal_load( + cereal::BinaryInputArchive& ar); +template void +CSGObject::cereal_load(cereal::JSONInputArchive& ar); +template void +CSGObject::cereal_load(cereal::XMLInputArchive& ar); + #ifdef TRACE_MEMORY_ALLOCS #include extern CMap* sg_mallocs; diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index b6bc550137d..c89367c5120 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -137,6 +137,58 @@ class CSGObject /** destructor */ virtual ~CSGObject(); +#ifndef SWIG // SWIG should skip this part + /** serializes the SGObject to a binary file + * + * @param filename Binary archive filename + */ + void save_binary(const char* filename) const; + + /** serializes the SGObject to a JSON file + * + * @param filename JSON archive filename + */ + void save_json(const char* filename) const; + + /** serializes the SGObject to a XML file + * + * @param filename XML archive filename + */ + void save_xml(const char* filename) const; + + /** loads SGObject from a Binary file + * + * @param filename Binary archive filename + */ + void load_binary(const char* filename); + + /** loads SGObject from a JSON file + * + * @param filename JSON archive filename + */ + void load_json(const char* filename); + + /** loads SGObject from a XML file + * + * @param filename XML archive filename + */ 
+ void load_xml(const char* filename); + + /** serializes SGObject parameters to Archive with Cereal + * + * @param ar Archive + */ + template + void cereal_save(Archive& ar) const; + + /** loads SGObject parameters from Archive with Cereal + * + * @param ar Archive + */ + template + void cereal_load(Archive& ar); +#endif // #ifndef SWIG // SWIG should skip this part + /** increase reference counter * * @return reference count diff --git a/src/shogun/lib/SGMatrix.cpp b/src/shogun/lib/SGMatrix.cpp index 12a2df5bd84..b69829b33cd 100644 --- a/src/shogun/lib/SGMatrix.cpp +++ b/src/shogun/lib/SGMatrix.cpp @@ -24,139 +24,154 @@ #include #include -namespace shogun -{ +#include +#include +#include +#include +#include -template -SGMatrix::SGMatrix() : SGReferencedData() -{ - init_data(); -} - -template -SGMatrix::SGMatrix(bool ref_counting) : SGReferencedData(ref_counting) +namespace shogun { - init_data(); -} -template -SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, bool ref_counting) - : SGReferencedData(ref_counting), matrix(m), - num_rows(nrows), num_cols(ncols), gpu_ptr(nullptr) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix() : SGReferencedData() + { + init_data(); + } -template -SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, index_t offset) - : SGReferencedData(false), matrix(m+offset), - num_rows(nrows), num_cols(ncols) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(bool ref_counting) : SGReferencedData(ref_counting) + { + init_data(); + } -template -SGMatrix::SGMatrix(index_t nrows, index_t ncols, bool ref_counting) - : SGReferencedData(ref_counting), num_rows(nrows), num_cols(ncols), gpu_ptr(nullptr) -{ - matrix=SG_CALLOC(T, ((int64_t) nrows)*ncols); - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, bool ref_counting) + : SGReferencedData(ref_counting), matrix(m), num_rows(nrows), + 
num_cols(ncols), gpu_ptr(nullptr) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::SGMatrix(SGVector vec) : SGReferencedData(vec) -{ - REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); - matrix=vec.vector; - num_rows=vec.vlen; - num_cols=1; - gpu_ptr = vec.gpu_ptr; - m_on_gpu.store(vec.on_gpu(), std::memory_order_release); -} + template + SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, index_t offset) + : SGReferencedData(false), matrix(m + offset), num_rows(nrows), + num_cols(ncols) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::SGMatrix(SGVector vec, index_t nrows, index_t ncols) -: SGReferencedData(vec) -{ - REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); - REQUIRE(nrows>0, "Number of rows (%d) has to be a positive integer!\n", nrows); - REQUIRE(ncols>0, "Number of cols (%d) has to be a positive integer!\n", ncols); - REQUIRE(vec.vlen==nrows*ncols, "Number of elements in the matrix (%d) must " - "be the same as the number of elements in the vector (%d)!\n", - nrows*ncols, vec.vlen); + template + SGMatrix::SGMatrix(index_t nrows, index_t ncols, bool ref_counting) + : SGReferencedData(ref_counting), num_rows(nrows), num_cols(ncols), + gpu_ptr(nullptr) + { + matrix = SG_CALLOC(T, ((int64_t)nrows) * ncols); + m_on_gpu.store(false, std::memory_order_release); + } - matrix=vec.vector; - num_rows=nrows; - num_cols=ncols; - gpu_ptr = vec.gpu_ptr; - m_on_gpu.store(vec.on_gpu(), std::memory_order_release); -} + template + SGMatrix::SGMatrix(SGVector vec) : SGReferencedData(vec) + { + REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); + matrix = vec.vector; + num_rows = vec.vlen; + num_cols = 1; + gpu_ptr = vec.gpu_ptr; + m_on_gpu.store(vec.on_gpu(), std::memory_order_release); + } -template -SGMatrix::SGMatrix(GPUMemoryBase* mat, index_t nrows, index_t ncols) - : SGReferencedData(true), matrix(NULL), num_rows(nrows), num_cols(ncols), - 
gpu_ptr(std::shared_ptr>(mat)) -{ - m_on_gpu.store(true, std::memory_order_release); -} + template + SGMatrix::SGMatrix(SGVector vec, index_t nrows, index_t ncols) + : SGReferencedData(vec) + { + REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); + REQUIRE( + nrows > 0, "Number of rows (%d) has to be a positive integer!\n", + nrows); + REQUIRE( + ncols > 0, "Number of cols (%d) has to be a positive integer!\n", + ncols); + REQUIRE( + vec.vlen == nrows * ncols, + "Number of elements in the matrix (%d) must " + "be the same as the number of elements in the vector (%d)!\n", + nrows * ncols, vec.vlen); + + matrix = vec.vector; + num_rows = nrows; + num_cols = ncols; + gpu_ptr = vec.gpu_ptr; + m_on_gpu.store(vec.on_gpu(), std::memory_order_release); + } -template -SGMatrix::SGMatrix(const SGMatrix &orig) : SGReferencedData(orig) -{ - copy_data(orig); -} + template + SGMatrix::SGMatrix(GPUMemoryBase* mat, index_t nrows, index_t ncols) + : SGReferencedData(true), matrix(NULL), num_rows(nrows), + num_cols(ncols), gpu_ptr(std::shared_ptr>(mat)) + { + m_on_gpu.store(true, std::memory_order_release); + } -template -SGMatrix::SGMatrix(EigenMatrixXt& mat) -: SGReferencedData(false), matrix(mat.data()), - num_rows(mat.rows()), num_cols(mat.cols()), gpu_ptr(nullptr) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(const SGMatrix& orig) : SGReferencedData(orig) + { + copy_data(orig); + } -template -SGMatrix::operator EigenMatrixXtMap() const -{ - assert_on_cpu(); - return EigenMatrixXtMap(matrix, num_rows, num_cols); -} + template + SGMatrix::SGMatrix(EigenMatrixXt& mat) + : SGReferencedData(false), matrix(mat.data()), num_rows(mat.rows()), + num_cols(mat.cols()), gpu_ptr(nullptr) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix& SGMatrix::operator=(const SGMatrix& other) -{ - if(&other == this) - return *this; + template + SGMatrix::operator EigenMatrixXtMap() const + { + assert_on_cpu(); + 
return EigenMatrixXtMap(matrix, num_rows, num_cols); + } - unref(); - copy_data(other); - copy_refcount(other); - ref(); - return *this; -} + template + SGMatrix& SGMatrix::operator=(const SGMatrix& other) + { + if (&other == this) + return *this; + + unref(); + copy_data(other); + copy_refcount(other); + ref(); + return *this; + } -template -SGMatrix::~SGMatrix() -{ - unref(); -} + template + SGMatrix::~SGMatrix() + { + unref(); + } -template -bool SGMatrix::equals(const SGMatrix& other) const -{ - // avoid comparing elements when both are same. - // the case where both matrices are uninitialized is handled here as well. - if (*this==other) - return true; + template + bool SGMatrix::equals(const SGMatrix& other) const + { + // avoid comparing elements when both are same. + // the case where both matrices are uninitialized is handled here as + // well. + if (*this == other) + return true; - // avoid uninitialized memory read in case the matrices are not initialized - if (matrix==nullptr || other.matrix==nullptr) - return false; + // avoid uninitialized memory read in case the matrices are not + // initialized + if (matrix == nullptr || other.matrix == nullptr) + return false; - if (num_rows!=other.num_rows || num_cols!=other.num_cols) - return false; + if (num_rows != other.num_rows || num_cols != other.num_cols) + return false; - return std::equal(matrix, matrix+size(), other.matrix); -} + return std::equal(matrix, matrix + size(), other.matrix); + } #ifndef REAL_EQUALS #define REAL_EQUALS(real_t) \ @@ -1207,7 +1222,61 @@ void SGMatrix::save(CFile* saver) SG_SERROR("SGMatrix::save():: Not supported for complex128_t\n"); } -template +template +template +void SGMatrix::cereal_save(Archive& ar) const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("num_rows", num_rows)); + ar(cereal::make_nvp("num_cols", num_cols)); + + for (index_t i = 0; i < num_rows * num_cols; ++i) + ar(matrix[i]); +} + +template <> +template +void SGMatrix::cereal_save(Archive& ar) 
const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("num_rows", num_rows)); + ar(cereal::make_nvp("num_cols", num_cols)); + + float64_t* temp = reinterpret_cast(matrix); + for (index_t i = 0; i < num_rows * num_cols * 2; ++i) + ar(temp[i]); +} + +template +template +void SGMatrix::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(num_rows); + ar(num_cols); + matrix = SG_MALLOC(T, num_rows * num_cols); + for (index_t i = 0; i < num_rows * num_cols; ++i) + ar(matrix[i]); +} + +template <> +template +void SGMatrix::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(num_rows); + ar(num_cols); + matrix = SG_MALLOC(complex128_t, num_rows * num_cols); + float64_t* temp = reinterpret_cast(matrix); + for (index_t i = 0; i < num_rows * num_cols * 2; ++i) + ar(temp[i]); +} + +template SGVector SGMatrix::get_row_vector(index_t row) const { assert_on_cpu(); @@ -1234,18 +1303,33 @@ SGVector SGMatrix::get_diagonal_vector() const return diag; } -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; +#define FILL_SGMATRIX(typetype) \ + template class SGMatrix; \ + template void SGMatrix::cereal_save( \ + cereal::BinaryOutputArchive& ar) const; \ + template void SGMatrix::cereal_save( \ + cereal::JSONOutputArchive& ar) const; \ + template void SGMatrix::cereal_save( \ + cereal::XMLOutputArchive& ar) const; \ + template void SGMatrix::cereal_load( \ + cereal::BinaryInputArchive& ar); \ + template void SGMatrix::cereal_load( \ + cereal::JSONInputArchive& ar); \ + template void SGMatrix::cereal_load(cereal::XMLInputArchive& ar); + +FILL_SGMATRIX(bool) +FILL_SGMATRIX(char) +FILL_SGMATRIX(int8_t) +FILL_SGMATRIX(uint8_t) 
+FILL_SGMATRIX(int16_t) +FILL_SGMATRIX(uint16_t) +FILL_SGMATRIX(int32_t) +FILL_SGMATRIX(uint32_t) +FILL_SGMATRIX(int64_t) +FILL_SGMATRIX(uint64_t) +FILL_SGMATRIX(float32_t) +FILL_SGMATRIX(float64_t) +FILL_SGMATRIX(floatmax_t) +FILL_SGMATRIX(complex128_t) +#undef FILL_SGMATRIX } diff --git a/src/shogun/lib/SGMatrix.h b/src/shogun/lib/SGMatrix.h index 1c33e106678..3841fbdfc84 100644 --- a/src/shogun/lib/SGMatrix.h +++ b/src/shogun/lib/SGMatrix.h @@ -463,6 +463,21 @@ template class SGMatrix : public SGReferencedData * @param saver File object via which to save data */ void save(CFile* saver); + + /** Serialize matrix with Cereal + * + * @param ar the Archive via which to save data + */ + template + void cereal_save(Archive& ar) const; + + /** Load matrix with Cereal + * + * @param ar the Archive via which to load data + */ + template + void cereal_load(Archive& ar); + #endif // #ifndef SWIG // SWIG should skip this part protected: diff --git a/src/shogun/lib/SGReferencedData.cpp b/src/shogun/lib/SGReferencedData.cpp index 31495bc40c5..dd282495d17 100644 --- a/src/shogun/lib/SGReferencedData.cpp +++ b/src/shogun/lib/SGReferencedData.cpp @@ -1,6 +1,12 @@ #include #include +#include +#include +#include +#include +#include + using namespace shogun; namespace shogun { @@ -112,4 +118,42 @@ int32_t SGReferencedData::unref() return c; } } + +template +void SGReferencedData::cereal_save(Archive& ar) const +{ + if (m_refcount != NULL) + ar(cereal::make_nvp("ref_counting", true), + cereal::make_nvp("refcount number", m_refcount->ref_count())); + else + ar(cereal::make_nvp("ref_counting", false)); +} +template void SGReferencedData::cereal_save( + cereal::BinaryOutputArchive& ar) const; +template void SGReferencedData::cereal_save( + cereal::JSONOutputArchive& ar) const; +template void SGReferencedData::cereal_save( + cereal::XMLOutputArchive& ar) const; + +template +void SGReferencedData::cereal_load(Archive& ar) +{ + bool ref_counting; + ar(ref_counting); + + if (ref_counting) 
+ { + int32_t temp; + ar(temp); + m_refcount = new RefCount(temp); + } + else + m_refcount = NULL; +} +template void SGReferencedData::cereal_load( + cereal::BinaryInputArchive& ar); +template void SGReferencedData::cereal_load( + cereal::JSONInputArchive& ar); +template void SGReferencedData::cereal_load( + cereal::XMLInputArchive& ar); } diff --git a/src/shogun/lib/SGReferencedData.h b/src/shogun/lib/SGReferencedData.h index 7facd56d585..a747fce1e01 100644 --- a/src/shogun/lib/SGReferencedData.h +++ b/src/shogun/lib/SGReferencedData.h @@ -43,6 +43,14 @@ class SGReferencedData */ int32_t ref_count(); +#ifndef SWIG // SWIG should skip this part + template + void cereal_save(Archive& ar) const; + + template + void cereal_load(Archive& ar); +#endif //#ifndef SWIG // SWIG should skip this part + protected: /** copy refcount */ void copy_refcount(const SGReferencedData &orig); diff --git a/src/shogun/lib/SGVector.cpp b/src/shogun/lib/SGVector.cpp index e7f68e1b567..7474974f75c 100644 --- a/src/shogun/lib/SGVector.cpp +++ b/src/shogun/lib/SGVector.cpp @@ -13,18 +13,24 @@ * Copyright (C) 2012 Soeren Sonnenburg */ -#include -#include +#include #include -#include #include -#include +#include +#include +#include #include #include #include #include +#include +#include +#include +#include +#include + #define COMPLEX128_ERROR_NOARG(function) \ template <> \ void SGVector::function() \ @@ -937,6 +943,56 @@ void SGVector::save(CFile* saver) SG_SERROR("SGVector::save():: Not supported for complex128_t\n"); } +template +template +void SGVector::cereal_save(Archive& ar) const +{ + ar(cereal::make_nvp( + "ReferencedData", cereal::base_class(this))); + ar(cereal::make_nvp("length", vlen)); + for (index_t i = 0; i < vlen; ++i) + ar(vector[i]); +} + +template <> +template +void SGVector::cereal_save(Archive& ar) const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("length", vlen)); + + float64_t* temp = reinterpret_cast(vector); + for (index_t i = 0; i < vlen * 2; ++i) + 
ar(temp[i]); +} + +template +template +void SGVector::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(vlen); + vector = SG_MALLOC(T, vlen); + for (index_t i = 0; i < vlen; ++i) + ar(vector[i]); +} + +template <> +template +void SGVector::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(vlen); + vector = SG_MALLOC(complex128_t, vlen); + float64_t* temp = reinterpret_cast(vector); + for (index_t i = 0; i < vlen * 2; ++i) + ar(temp[i]); +} + template SGVector SGVector::get_real() { assert_on_cpu(); @@ -1032,20 +1088,35 @@ UNDEFINED(get_imag, float64_t) UNDEFINED(get_imag, floatmax_t) #undef UNDEFINED -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; +#define FILL_SGVECTOR(typetype) \ + template class SGVector; \ + template void SGVector::cereal_save( \ + cereal::BinaryOutputArchive& ar) const; \ + template void SGVector::cereal_save( \ + cereal::JSONOutputArchive& ar) const; \ + template void SGVector::cereal_save( \ + cereal::XMLOutputArchive& ar) const; \ + template void SGVector::cereal_load( \ + cereal::BinaryInputArchive& ar); \ + template void SGVector::cereal_load( \ + cereal::JSONInputArchive& ar); \ + template void SGVector::cereal_load(cereal::XMLInputArchive& ar); + +FILL_SGVECTOR(bool) +FILL_SGVECTOR(char) +FILL_SGVECTOR(int8_t) +FILL_SGVECTOR(uint8_t) +FILL_SGVECTOR(int16_t) +FILL_SGVECTOR(uint16_t) +FILL_SGVECTOR(int32_t) +FILL_SGVECTOR(uint32_t) +FILL_SGVECTOR(int64_t) +FILL_SGVECTOR(uint64_t) +FILL_SGVECTOR(float32_t) +FILL_SGVECTOR(float64_t) +FILL_SGVECTOR(floatmax_t) +FILL_SGVECTOR(complex128_t) +#undef FILL_SGVECTOR } #undef COMPLEX128_ERROR_NOARG diff --git 
a/src/shogun/lib/SGVector.h b/src/shogun/lib/SGVector.h index 980f5be0031..c0db4067a17 100644 --- a/src/shogun/lib/SGVector.h +++ b/src/shogun/lib/SGVector.h @@ -14,12 +14,12 @@ #ifndef __SGVECTOR_H__ #define __SGVECTOR_H__ -#include - #include -#include #include +#include +#include #include + #include #include @@ -540,6 +540,13 @@ template class SGVector : public SGReferencedData * @return matrix */ static void convert_to_matrix(T*& matrix, index_t nrows, index_t ncols, const T* vector, int32_t vlen, bool fortran_order); + + template + void cereal_save(Archive& ar) const; + + template + void cereal_load(Archive& ar); + #endif // #ifndef SWIG // SWIG should skip this part protected: /** needs to be overridden to copy data */ diff --git a/src/shogun/lib/any.h b/src/shogun/lib/any.h index 689b72d99cf..750c76207f1 100644 --- a/src/shogun/lib/any.h +++ b/src/shogun/lib/any.h @@ -35,6 +35,7 @@ #ifndef _ANY_H_ #define _ANY_H_ +#include #include #include #include @@ -45,6 +46,78 @@ namespace shogun { +#ifndef SWIG // SWIG should skip this part + namespace serial + { + enum EnumContainerType + { + CT_UNDEFINED, + CT_PRIMITIVE, + CT_SGVECTOR, + CT_SGMATRIX + }; + + enum EnumPrimitiveType + { + PT_UNDEFINED, + PT_INT_32, + PT_FLOAT_64, + }; + + /** cast data type to EnumContainerType and EnumPrimitiveType */ + template + struct Type2Enum + { + static constexpr EnumContainerType e_containertype = CT_UNDEFINED; + static constexpr EnumPrimitiveType e_primitivetype = PT_UNDEFINED; + }; + + template <> + struct Type2Enum + { + static constexpr EnumContainerType e_containertype = CT_PRIMITIVE; + static constexpr EnumPrimitiveType e_primitivetype = PT_INT_32; + }; + + template <> + struct Type2Enum + { + static constexpr EnumContainerType e_containertype = CT_PRIMITIVE; + static constexpr EnumPrimitiveType e_primitivetype = PT_FLOAT_64; + }; + + template <> + struct Type2Enum> + { + static constexpr EnumContainerType e_containertype = CT_SGVECTOR; + static constexpr 
EnumPrimitiveType e_primitivetype = PT_INT_32; + }; + + template <> + struct Type2Enum> + { + static constexpr EnumContainerType e_containertype = CT_SGVECTOR; + static constexpr EnumPrimitiveType e_primitivetype = PT_FLOAT_64; + }; + + /** @brief data structure that saves the EnumContainerType + * and EnumPrimitiveType information of Any object + */ + struct DataType + { + EnumContainerType e_containertype; + EnumPrimitiveType e_primitivetype; + + template + void set() + { + e_containertype = Type2Enum::e_containertype; + e_primitivetype = Type2Enum::e_primitivetype; + } + }; + } +#endif // SWIG + /** Converts compiler-dependent name of class to * something human readable. * @return human readable name of class @@ -455,6 +528,139 @@ namespace shogun policy->clear(&storage); } +#ifndef SWIG // SWIG should skip this part + /** Cast storage data to selected type and save the data to Archive + * + * @param ar Archive type + */ + template + void cereal_save_helper(Archive& ar) const + { + ar(*(reinterpret_cast(storage))); + } + + /** save data with cereal save method + * + * @param ar Archive type + */ + template + void cereal_save(Archive& ar) const + { + ar(m_datatype.e_containertype); + ar(m_datatype.e_primitivetype); + switch (m_datatype.e_containertype) + { + case serial::EnumContainerType::CT_PRIMITIVE: + switch (m_datatype.e_primitivetype) + { + case serial::EnumPrimitiveType::PT_INT_32: + cereal_save_helper(ar); + break; + case serial::EnumPrimitiveType::PT_FLOAT_64: + cereal_save_helper(ar); + break; + case serial::EnumPrimitiveType::PT_UNDEFINED: + SG_SERROR( + "Type error: undefined data type cannot be " + "serialized.\n"); + break; + } + break; + case serial::EnumContainerType::CT_SGVECTOR: + switch (m_datatype.e_primitivetype) + { + case serial::EnumPrimitiveType::PT_INT_32: + cereal_save_helper>(ar); + break; + case serial::EnumPrimitiveType::PT_FLOAT_64: + cereal_save_helper>(ar); + break; + case serial::EnumPrimitiveType::PT_UNDEFINED: + SG_SERROR( + 
"Type error: undefined data type cannot be " + "serialized.\n"); + break; + } + break; + case serial::EnumContainerType::CT_SGMATRIX: + SG_SWARNING("SGMatrix serializatino method not implemented.\n"); + break; + case serial::EnumContainerType::CT_UNDEFINED: + SG_SERROR( + "Type error: undefined container type cannot be " + "serialized.\n"); + break; + } + } + + /** Load data from archive and cast to Any type + * + * @param ar Archive type + */ + template + void cereal_load_helper(Archive& ar) + { + Type temp; + ar(temp); + policy->clear(&storage); + policy->set(&storage, &temp); + } + + /** load data from archive with cereal load method + * + * @param ar Archive type + */ + template + void cereal_load(Archive& ar) + { + ar(m_datatype.e_containertype); + ar(m_datatype.e_primitivetype); + switch (m_datatype.e_containertype) + { + case serial::EnumContainerType::CT_PRIMITIVE: + switch (m_datatype.e_primitivetype) + { + case serial::EnumPrimitiveType::PT_INT_32: + cereal_load_helper(ar); + break; + case serial::EnumPrimitiveType::PT_FLOAT_64: + cereal_load_helper(ar); + break; + default: + SG_SERROR("Error: undefined data type cannot be loaded.\n"); + break; + } + break; + + case serial::EnumContainerType::CT_SGVECTOR: + switch (m_datatype.e_primitivetype) + { + case serial::EnumPrimitiveType::PT_INT_32: + cereal_load_helper>(ar); + break; + case serial::EnumPrimitiveType::PT_FLOAT_64: + cereal_load_helper>(ar); + break; + case serial::EnumPrimitiveType::PT_UNDEFINED: + SG_SERROR( + "Type error: undefined data type cannot be " + "serialized.\n"); + break; + } + break; + + case serial::EnumContainerType::CT_SGMATRIX: + SG_SWARNING("SGMatrix serializatino method not implemented.\n") + + default: + SG_SERROR( + "Error: undefined container type cannot be serialize " + "loaded.\n"); + break; + } + } +#endif // SWIG + /** Casts hidden value to provided type, fails otherwise. 
* @return type-casted value */ @@ -516,6 +722,11 @@ namespace shogun private: BaseAnyPolicy* policy; void* storage; + +#ifndef SWIG // SWIG should skip this part + /** Enum structure that saves the type information of Any */ + serial::DataType m_datatype; +#endif //#ifndef SWIG }; inline bool operator==(const Any& lhs, const Any& rhs) diff --git a/src/shogun/lib/common.h b/src/shogun/lib/common.h index e8787f816ce..5a55ada438d 100644 --- a/src/shogun/lib/common.h +++ b/src/shogun/lib/common.h @@ -15,6 +15,16 @@ #ifndef __COMMON_H__ #define __COMMON_H__ +#ifdef CEREAL_SAVE_FUNCTION_NAME +#undef CEREAL_SAVE_FUNCTION_NAME +#endif +#define CEREAL_SAVE_FUNCTION_NAME cereal_save + +#ifdef CEREAL_LOAD_FUNCTION_NAME +#undef CEREAL_LOAD_FUNCTION_NAME +#endif +#define CEREAL_LOAD_FUNCTION_NAME cereal_load + #include #include #include diff --git a/tests/unit/io/CerealObject.h b/tests/unit/io/CerealObject.h new file mode 100644 index 00000000000..e9bf75dc013 --- /dev/null +++ b/tests/unit/io/CerealObject.h @@ -0,0 +1,40 @@ +#include +#include + +namespace shogun +{ + + /** @brief Used to test the SGObject serialization */ + class CCerealObject : public CSGObject + { + public: + // Construct CCerealObject from input SGVector + CCerealObject(SGVector vec) : CSGObject() + { + m_vector = vec; + init_params(); + } + + // Default constructor + CCerealObject() : CSGObject() + { + m_vector = SGVector(5); + m_vector.set_const(0); + init_params(); + } + + const char* get_name() const + { + return "CerealObject"; + } + + protected: + // Register m_vector to parameter list with name(tag) "test_vector" + void init_params() + { + register_param("test_vector", m_vector); + } + + SGVector m_vector; + }; +} diff --git a/tests/unit/io/Cereal_unittest.cc b/tests/unit/io/Cereal_unittest.cc new file mode 100644 index 00000000000..1e7f1e09127 --- /dev/null +++ b/tests/unit/io/Cereal_unittest.cc @@ -0,0 +1,177 @@ +#include "CerealObject.h" +#include +#include + +#include +#include +#include + 
+#include +#include +#include + +using namespace shogun; + +#ifndef SWIG // SWIG should skip this part +TEST(Cereal, Json_SGVector_float64_load_equals_saved) +{ + const index_t size = 5; + SGVector a(size); + SGVector b; + a.range_fill(1.0); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + EXPECT_EQ(a.ref_count(), b.ref_count()); + for (index_t i = 0; i < size; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGMatrix_float64_load_equals_saved) +{ + const index_t nrows = 2, ncols = 3; + SGMatrix a(nrows, ncols); + SGMatrix b; + + for (index_t i = 0; i < nrows * ncols; i++) + a[i] = i; + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.num_rows, b.num_rows); + EXPECT_EQ(a.num_cols, b.num_cols); + EXPECT_EQ(a.ref_count(), b.ref_count()); + for (index_t i = 0; i < nrows * ncols; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGVector_complex128_load_equals_saved) +{ + const index_t size = 5; + SGVector a(size); + SGVector b(size); + + for (index_t i = 0; i < size; ++i) + a[i] = complex128_t(i, i * 2); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } 
+ catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + for (index_t i = 0; i < size; i++) + { + EXPECT_NEAR(i, b[i].real(), 1E-15); + EXPECT_NEAR(i * 2, b[i].imag(), 1E-15); + } + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGVector_load_equals_saved_refcounting_false) +{ + const index_t size = 5; + SGVector a(size, false); + SGVector b; + a.range_fill(1.0); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + EXPECT_EQ(-1, b.ref_count()); + for (index_t i = 0; i < size; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_CerealObject_load_equals_saved) +{ + SGVector A(5); + A.range_fill(0); + SGVector B(5); + CCerealObject obj_save(A); + CCerealObject obj_load; + + std::string filename = std::tmpnam(nullptr); + + obj_save.save_json(filename.c_str()); + + obj_load.load_json(filename.c_str()); + B = obj_load.get>("test_vector"); + + EXPECT_EQ(A.size(), B.size()); + for (index_t i = 0; i < 5; ++i) + EXPECT_NEAR(A[i], B[i], 1e-15); + + remove(filename.c_str()); +} + +#endif // ifndef SWIG