diff --git a/CMakeLists.txt b/CMakeLists.txt index 2ce57a69fc5..76ceb7f07c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -403,6 +403,14 @@ IF (GDB_FOUND) SET(GDB_DEFAULT_SCRIPT ${CMAKE_CURRENT_SOURCE_DIR}/src/.gdb) ENDIF() +FIND_PACKAGE(Cereal) +IF(NOT CEREAL_FOUND) + include(external/Cereal) + LIST(APPEND INCLUDES ${CEREAL_INCLUDE_DIRS}) +ELSE() + LIST(APPEND INCLUDES ${CEREAL_INCLUDE_DIRS}) +ENDIF() + FIND_PACKAGE(Doxygen 1.8.6) IF(DOXYGEN_FOUND) SET(HAVE_DOXYGEN 1) diff --git a/cmake/FindCereal.cmake b/cmake/FindCereal.cmake new file mode 100644 index 00000000000..5ed51bd3e35 --- /dev/null +++ b/cmake/FindCereal.cmake @@ -0,0 +1,19 @@ +# - Try to find Cereal Serialization Library +# +# This sets the following variables: +# CEREAL_FOUND - True if Cereal was found. +# CEREAL_INCLUDE_DIRS - Directories containing the Cereal include files. + +find_path(CEREAL_INCLUDE_DIR cereal/cereal.hpp + HINTS "${CMAKE_SOURCE_DIR}/include" "/usr/include" "${CMAKE_BINARY_DIR}/cereal/include") + +set(CEREAL_INCLUDE_DIRS ${CEREAL_INCLUDE_DIR}) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(Cereal DEFAULT_MSG CEREAL_INCLUDE_DIR) + +mark_as_advanced(CEREAL_INCLUDE_DIR) + +if(CEREAL_FOUND) + MESSAGE(STATUS "Found Cereal: ${CEREAL_INCLUDE_DIRS}") +endif(CEREAL_FOUND) diff --git a/cmake/external/Cereal.cmake b/cmake/external/Cereal.cmake new file mode 100644 index 00000000000..d8e3603b1e9 --- /dev/null +++ b/cmake/external/Cereal.cmake @@ -0,0 +1,14 @@ +include(ExternalProject) +ExternalProject_Add( + Cereal + PREFIX ${CMAKE_BINARY_DIR}/Cereal + DOWNLOAD_DIR ${THIRD_PARTY_DIR}/Cereal + URL https://github.com/USCiLab/cereal/archive/v1.2.0.tar.gz + URL_MD5 e372c9814696481dbdb7d500e1410d2b + CMAKE_ARGS -DCMAKE_C_FLAGS:STRING=${CMAKE_C_FLAGS}${CMAKE_DEFINITIONS} + -DCMAKE_CXX_FLAGS:STRING=${CMAKE_CXX_FLAGS}${CMAKE_DEFINITIONS} + INSTALL_COMMAND "" + ) + +SET(CEREAL_INCLUDE_DIRS ${CMAKE_BINARY_DIR}/Cereal/src/Cereal/include) +LIST(APPEND SHOGUN_DEPENDS Cereal)
diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index 23a1ba44457..c5b138aff6d 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -60,6 +60,16 @@ namespace shogun return m_gradient; } + /** serialize the object using cereal + * + * @param ar Archive type + */ + template + void serialize(Archive& ar) + { + ar(m_model_selection, m_gradient); + } + private: std::string m_description; EModelSelectionAvailability m_model_selection; @@ -112,6 +122,16 @@ namespace shogun return !(*this == other); } + /** serialize the object using cereal + * + * @param ar Archive type + */ + template + void serialize(Archive& ar) + { + ar(m_value, m_properties); + } + private: Any m_value; AnyParameterProperties m_properties; diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index 140c1d90532..b2c9c48f115 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -32,6 +32,15 @@ #include #include + +#include +#include +#include +#include +#include +#include +#include + #include #include @@ -235,6 +244,74 @@ int32_t CSGObject::unref() } } +void CSGObject::save_binary(const char* filename) const +{ + std::ofstream os(filename, std::ios::binary); + cereal::BinaryOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::save_json(const char* filename) const +{ + std::ofstream os(filename); + cereal::JSONOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::save_xml(const char* filename) const +{ + std::ofstream os(filename); + cereal::XMLOutputArchive archive(os); + archive(cereal::make_nvp(this->get_name(), *this)); +} + +void CSGObject::load_binary(const char* filename) +{ + std::ifstream is(filename, std::ios::binary); + cereal::BinaryInputArchive archive(is); + archive(*this); +} + +void CSGObject::load_json(const char* filename) +{ + std::ifstream is(filename); + cereal::JSONInputArchive archive(is); + archive(*this); +} +
+void CSGObject::load_xml(const char* filename) +{ + std::ifstream is(filename); + cereal::XMLInputArchive archive(is); + archive(*this); +} + +template +void CSGObject::cereal_save(Archive& ar) const +{ + for (const auto& it : self->map) + ar(cereal::make_nvp(it.first.name(), it.second)); +} +template void CSGObject::cereal_save( + cereal::BinaryOutputArchive& ar) const; +template void CSGObject::cereal_save( + cereal::JSONOutputArchive& ar) const; +template void CSGObject::cereal_save( + cereal::XMLOutputArchive& ar) const; + +template +void CSGObject::cereal_load(Archive& ar) +{ + for (auto& it : self->map) + ar(it.second); +} +template void CSGObject::cereal_load( + cereal::BinaryInputArchive& ar); +template void +CSGObject::cereal_load(cereal::JSONInputArchive& ar); +template void +CSGObject::cereal_load(cereal::XMLInputArchive& ar); + #ifdef TRACE_MEMORY_ALLOCS #include extern CMap* sg_mallocs; diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index d76e0df399d..0f8325d2656 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -140,6 +140,58 @@ class CSGObject /** destructor */ virtual ~CSGObject(); +#ifndef SWIG // SWIG should skip this part + /** serializes the SGObject to a binary file + * + * @param filename Binary archive filename + */ + void save_binary(const char* filename) const; + + /** serializes the SGObject to a JSON file + * + * @param filename JSON archive filename + */ + void save_json(const char* filename) const; + + /** serializes the SGObject to a XML file + * + * @param filename XML archive filename + */ + void save_xml(const char* filename) const; + + /** loads SGObject from a Binary file + * + * @param filename Binary archive filename + */ + void load_binary(const char* filename); + + /** loads SGObject from a JSON file + * + * @param filename JSON archive filename + */ + void load_json(const char* filename); + + /** loads SGObject from a XML file + * + * @param filename XML archive filename + */ + 
void load_xml(const char* filename); + + /** serializes SGObject parameters to Archive with Cereal + * + * @param ar Archive + */ + template + void cereal_save(Archive& ar) const; + + /** loads SGObject parameters from Archive with Cereal + * + * @param ar Archive + */ + template + void cereal_load(Archive& ar); +#endif // #ifndef SWIG // SWIG should skip this part + /** increase reference counter * * @return reference count diff --git a/src/shogun/lib/SGMatrix.cpp b/src/shogun/lib/SGMatrix.cpp index afcc450f8f4..48e4753716a 100644 --- a/src/shogun/lib/SGMatrix.cpp +++ b/src/shogun/lib/SGMatrix.cpp @@ -1,9 +1,9 @@ /* * This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Heiko Strathmann, Soeren Sonnenburg, Soumyajit De, Thoralf Klein, - * Pan Deng, Fernando Iglesias, Sergey Lisitsyn, Viktor Gal, - * Michele Mazzoni, Yingrui Chang, Weijie Lin, Khaled Nasr, + * Authors: Heiko Strathmann, Soeren Sonnenburg, Soumyajit De, Thoralf Klein, + * Pan Deng, Fernando Iglesias, Sergey Lisitsyn, Viktor Gal, + * Michele Mazzoni, Yingrui Chang, Weijie Lin, Khaled Nasr, * Koen van de Sande, Roman Votyakov */ @@ -18,145 +18,160 @@ #include #include +#include +#include +#include +#include +#include + namespace shogun { -template -SGMatrix::SGMatrix() : SGReferencedData() -{ - init_data(); -} - -template -SGMatrix::SGMatrix(bool ref_counting) : SGReferencedData(ref_counting) -{ - init_data(); -} + template + SGMatrix::SGMatrix() : SGReferencedData() + { + init_data(); + } -template -SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, bool ref_counting) - : SGReferencedData(ref_counting), matrix(m), - num_rows(nrows), num_cols(ncols), gpu_ptr(nullptr) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(bool ref_counting) : SGReferencedData(ref_counting) + { + init_data(); + } -template -SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, index_t offset) - : SGReferencedData(false), matrix(m+offset), - 
num_rows(nrows), num_cols(ncols) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, bool ref_counting) + : SGReferencedData(ref_counting), matrix(m), num_rows(nrows), + num_cols(ncols), gpu_ptr(nullptr) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::SGMatrix(index_t nrows, index_t ncols, bool ref_counting) - : SGReferencedData(ref_counting), num_rows(nrows), num_cols(ncols), gpu_ptr(nullptr) -{ - matrix=SG_CALLOC(T, ((int64_t) nrows)*ncols); - m_on_gpu.store(false, std::memory_order_release); -} + template + SGMatrix::SGMatrix(T* m, index_t nrows, index_t ncols, index_t offset) + : SGReferencedData(false), matrix(m + offset), num_rows(nrows), + num_cols(ncols) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::SGMatrix(SGVector vec) : SGReferencedData(vec) -{ - REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); - matrix=vec.vector; - num_rows=vec.vlen; - num_cols=1; - gpu_ptr = vec.gpu_ptr; - m_on_gpu.store(vec.on_gpu(), std::memory_order_release); -} + template + SGMatrix::SGMatrix(index_t nrows, index_t ncols, bool ref_counting) + : SGReferencedData(ref_counting), num_rows(nrows), num_cols(ncols), + gpu_ptr(nullptr) + { + matrix = SG_CALLOC(T, ((int64_t)nrows) * ncols); + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::SGMatrix(SGVector vec, index_t nrows, index_t ncols) -: SGReferencedData(vec) -{ - REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); - REQUIRE(nrows>0, "Number of rows (%d) has to be a positive integer!\n", nrows); - REQUIRE(ncols>0, "Number of cols (%d) has to be a positive integer!\n", ncols); - REQUIRE(vec.vlen==nrows*ncols, "Number of elements in the matrix (%d) must " - "be the same as the number of elements in the vector (%d)!\n", - nrows*ncols, vec.vlen); + template + SGMatrix::SGMatrix(SGVector vec) : SGReferencedData(vec) + { + 
REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); + matrix = vec.vector; + num_rows = vec.vlen; + num_cols = 1; + gpu_ptr = vec.gpu_ptr; + m_on_gpu.store(vec.on_gpu(), std::memory_order_release); + } - matrix=vec.vector; - num_rows=nrows; - num_cols=ncols; - gpu_ptr = vec.gpu_ptr; - m_on_gpu.store(vec.on_gpu(), std::memory_order_release); -} + template + SGMatrix::SGMatrix(SGVector vec, index_t nrows, index_t ncols) + : SGReferencedData(vec) + { + REQUIRE((vec.vector || vec.gpu_ptr), "Vector not initialized!\n"); + REQUIRE( + nrows > 0, "Number of rows (%d) has to be a positive integer!\n", + nrows); + REQUIRE( + ncols > 0, "Number of cols (%d) has to be a positive integer!\n", + ncols); + REQUIRE( + vec.vlen == nrows * ncols, + "Number of elements in the matrix (%d) must " + "be the same as the number of elements in the vector (%d)!\n", + nrows * ncols, vec.vlen); + + matrix = vec.vector; + num_rows = nrows; + num_cols = ncols; + gpu_ptr = vec.gpu_ptr; + m_on_gpu.store(vec.on_gpu(), std::memory_order_release); + } -template -SGMatrix::SGMatrix(GPUMemoryBase* mat, index_t nrows, index_t ncols) - : SGReferencedData(true), matrix(NULL), num_rows(nrows), num_cols(ncols), - gpu_ptr(std::shared_ptr>(mat)) -{ - m_on_gpu.store(true, std::memory_order_release); -} + template + SGMatrix::SGMatrix(GPUMemoryBase* mat, index_t nrows, index_t ncols) + : SGReferencedData(true), matrix(NULL), num_rows(nrows), + num_cols(ncols), gpu_ptr(std::shared_ptr>(mat)) + { + m_on_gpu.store(true, std::memory_order_release); + } -template -SGMatrix::SGMatrix(const SGMatrix &orig) : SGReferencedData(orig) -{ - copy_data(orig); -} + template + SGMatrix::SGMatrix(const SGMatrix& orig) : SGReferencedData(orig) + { + copy_data(orig); + } -template -SGMatrix::SGMatrix(EigenMatrixXt& mat) -: SGReferencedData(false), matrix(mat.data()), - num_rows(mat.rows()), num_cols(mat.cols()), gpu_ptr(nullptr) -{ - m_on_gpu.store(false, std::memory_order_release); -} + template + 
SGMatrix::SGMatrix(EigenMatrixXt& mat) + : SGReferencedData(false), matrix(mat.data()), num_rows(mat.rows()), + num_cols(mat.cols()), gpu_ptr(nullptr) + { + m_on_gpu.store(false, std::memory_order_release); + } -template -SGMatrix::operator EigenMatrixXtMap() const -{ - assert_on_cpu(); - return EigenMatrixXtMap(matrix, num_rows, num_cols); -} + template + SGMatrix::operator EigenMatrixXtMap() const + { + assert_on_cpu(); + return EigenMatrixXtMap(matrix, num_rows, num_cols); + } -template -SGMatrix& SGMatrix::operator=(const SGMatrix& other) -{ - if(&other == this) - return *this; + template + SGMatrix& SGMatrix::operator=(const SGMatrix& other) + { + if (&other == this) + return *this; + + unref(); + copy_data(other); + copy_refcount(other); + ref(); + return *this; + } - unref(); - copy_data(other); - copy_refcount(other); - ref(); - return *this; -} + template + SGMatrix::~SGMatrix() + { + unref(); + } -template -SGMatrix::~SGMatrix() -{ - unref(); -} + template + bool SGMatrix::equals(const SGMatrix& other) const + { + // avoid comparing elements when both are same. + // the case where both matrices are uninitialized is handled here as + // well. + if (*this == other) + return true; -template -bool SGMatrix::equals(const SGMatrix& other) const -{ - // avoid comparing elements when both are same. - // the case where both matrices are uninitialized is handled here as well. 
- if (*this==other) - return true; - // both empty - if (!(num_rows || num_cols || other.num_rows || other.num_cols)) - return true; + // both empty + if (!(num_rows || num_cols || other.num_rows || other.num_cols)) + return true; - // only one empty - if (!matrix || !other.matrix) - return false; + // only one empty + if (!matrix || !other.matrix) + return false; - // different size - if (num_rows!=other.num_rows || num_cols!=other.num_cols) - return false; + // different size + if (num_rows!=other.num_rows || num_cols!=other.num_cols) + return false; - // content - return std::equal(matrix, matrix+size(), other.matrix); -} + // content + return std::equal(matrix, matrix+size(), other.matrix); + } #ifndef REAL_EQUALS #define REAL_EQUALS(real_t) \ @@ -1213,7 +1228,61 @@ void SGMatrix::save(CFile* saver) SG_SERROR("SGMatrix::save():: Not supported for complex128_t\n"); } -template +template +template +void SGMatrix::cereal_save(Archive& ar) const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("num_rows", num_rows)); + ar(cereal::make_nvp("num_cols", num_cols)); + + for (index_t i = 0; i < num_rows * num_cols; ++i) + ar(matrix[i]); +} + +template <> +template +void SGMatrix::cereal_save(Archive& ar) const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("num_rows", num_rows)); + ar(cereal::make_nvp("num_cols", num_cols)); + + float64_t* temp = reinterpret_cast(matrix); + for (index_t i = 0; i < num_rows * num_cols * 2; ++i) + ar(temp[i]); +} + +template +template +void SGMatrix::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(num_rows); + ar(num_cols); + matrix = SG_MALLOC(T, num_rows * num_cols); + for (index_t i = 0; i < num_rows * num_cols; ++i) + ar(matrix[i]); +} + +template <> +template +void SGMatrix::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(num_rows); + ar(num_cols); + matrix = SG_MALLOC(complex128_t, num_rows * num_cols); + float64_t* temp = reinterpret_cast(matrix); + for 
(index_t i = 0; i < num_rows * num_cols * 2; ++i) + ar(temp[i]); +} + +template SGVector SGMatrix::get_row_vector(index_t row) const { assert_on_cpu(); @@ -1240,18 +1309,33 @@ SGVector SGMatrix::get_diagonal_vector() const return diag; } -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; -template class SGMatrix; +#define FILL_SGMATRIX(typetype) \ + template class SGMatrix; \ + template void SGMatrix::cereal_save( \ + cereal::BinaryOutputArchive& ar) const; \ + template void SGMatrix::cereal_save( \ + cereal::JSONOutputArchive& ar) const; \ + template void SGMatrix::cereal_save( \ + cereal::XMLOutputArchive& ar) const; \ + template void SGMatrix::cereal_load( \ + cereal::BinaryInputArchive& ar); \ + template void SGMatrix::cereal_load( \ + cereal::JSONInputArchive& ar); \ + template void SGMatrix::cereal_load(cereal::XMLInputArchive& ar); + +FILL_SGMATRIX(bool) +FILL_SGMATRIX(char) +FILL_SGMATRIX(int8_t) +FILL_SGMATRIX(uint8_t) +FILL_SGMATRIX(int16_t) +FILL_SGMATRIX(uint16_t) +FILL_SGMATRIX(int32_t) +FILL_SGMATRIX(uint32_t) +FILL_SGMATRIX(int64_t) +FILL_SGMATRIX(uint64_t) +FILL_SGMATRIX(float32_t) +FILL_SGMATRIX(float64_t) +FILL_SGMATRIX(floatmax_t) +FILL_SGMATRIX(complex128_t) +#undef FILL_SGMATRIX } diff --git a/src/shogun/lib/SGMatrix.h b/src/shogun/lib/SGMatrix.h index c6dbe27c599..241737efbb9 100644 --- a/src/shogun/lib/SGMatrix.h +++ b/src/shogun/lib/SGMatrix.h @@ -459,6 +459,21 @@ template class SGMatrix : public SGReferencedData * @param saver File object via which to save data */ void save(CFile* saver); + + /** Serialize matrix with Cereal + * + * @param ar the Archive via which to save data + */ + template + void cereal_save(Archive& ar) const; + + /** Load matrix with 
Cereal + * + * @param ar the Archive via which to load data + */ + template + void cereal_load(Archive& ar); + #endif // #ifndef SWIG // SWIG should skip this part protected: diff --git a/src/shogun/lib/SGReferencedData.cpp b/src/shogun/lib/SGReferencedData.cpp index 31495bc40c5..dd282495d17 100644 --- a/src/shogun/lib/SGReferencedData.cpp +++ b/src/shogun/lib/SGReferencedData.cpp @@ -1,6 +1,12 @@ #include #include +#include +#include +#include +#include +#include + using namespace shogun; namespace shogun { @@ -112,4 +118,42 @@ int32_t SGReferencedData::unref() return c; } } + +template +void SGReferencedData::cereal_save(Archive& ar) const +{ + if (m_refcount != NULL) + ar(cereal::make_nvp("ref_counting", true), + cereal::make_nvp("refcount_number", m_refcount->ref_count())); + else + ar(cereal::make_nvp("ref_counting", false)); +} +template void SGReferencedData::cereal_save( + cereal::BinaryOutputArchive& ar) const; +template void SGReferencedData::cereal_save( + cereal::JSONOutputArchive& ar) const; +template void SGReferencedData::cereal_save( + cereal::XMLOutputArchive& ar) const; + +template +void SGReferencedData::cereal_load(Archive& ar) +{ + bool ref_counting; + ar(ref_counting); + + if (ref_counting) + { + int32_t temp; + ar(temp); + m_refcount = new RefCount(temp); + } + else + m_refcount = NULL; +} +template void SGReferencedData::cereal_load( + cereal::BinaryInputArchive& ar); +template void SGReferencedData::cereal_load( + cereal::JSONInputArchive& ar); +template void SGReferencedData::cereal_load( + cereal::XMLInputArchive& ar); } diff --git a/src/shogun/lib/SGReferencedData.h b/src/shogun/lib/SGReferencedData.h index 04c65eda264..d8e5abf9a5b 100644 --- a/src/shogun/lib/SGReferencedData.h +++ b/src/shogun/lib/SGReferencedData.h @@ -41,6 +41,14 @@ class SGReferencedData */ int32_t ref_count(); +#ifndef SWIG // SWIG should skip this part + template + void cereal_save(Archive& ar) const; + + template + void cereal_load(Archive& ar); +#endif //#ifndef
SWIG // SWIG should skip this part + protected: /** copy refcount */ void copy_refcount(const SGReferencedData &orig); diff --git a/src/shogun/lib/SGVector.cpp b/src/shogun/lib/SGVector.cpp index 099bc6aa566..159f1e9ef4e 100644 --- a/src/shogun/lib/SGVector.cpp +++ b/src/shogun/lib/SGVector.cpp @@ -8,12 +8,12 @@ * Somya Anand */ -#include -#include +#include #include -#include #include -#include +#include +#include +#include #include #include @@ -21,6 +21,12 @@ #include #include +#include +#include +#include +#include +#include + #define COMPLEX128_ERROR_NOARG(function) \ template <> \ void SGVector::function() \ @@ -968,6 +974,56 @@ void SGVector::save(CFile* saver) SG_SERROR("SGVector::save():: Not supported for complex128_t\n"); } +template +template +void SGVector::cereal_save(Archive& ar) const +{ + ar(cereal::make_nvp( + "ReferencedData", cereal::base_class(this))); + ar(cereal::make_nvp("length", vlen)); + for (index_t i = 0; i < vlen; ++i) + ar(vector[i]); +} + +template <> +template +void SGVector::cereal_save(Archive& ar) const +{ + ar(cereal::base_class(this)); + ar(cereal::make_nvp("length", vlen)); + + float64_t* temp = reinterpret_cast(vector); + for (index_t i = 0; i < vlen * 2; ++i) + ar(temp[i]); +} + +template +template +void SGVector::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(vlen); + vector = SG_MALLOC(T, vlen); + for (index_t i = 0; i < vlen; ++i) + ar(vector[i]); +} + +template <> +template +void SGVector::cereal_load(Archive& ar) +{ + unref(); + ar(cereal::base_class(this)); + + ar(vlen); + vector = SG_MALLOC(complex128_t, vlen); + float64_t* temp = reinterpret_cast(vector); + for (index_t i = 0; i < vlen * 2; ++i) + ar(temp[i]); +} + template SGVector SGVector::get_real() { assert_on_cpu(); @@ -1063,20 +1119,35 @@ UNDEFINED(get_imag, float64_t) UNDEFINED(get_imag, floatmax_t) #undef UNDEFINED -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template 
class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; -template class SGVector; +#define FILL_SGVECTOR(typetype) \ + template class SGVector; \ + template void SGVector::cereal_save( \ + cereal::BinaryOutputArchive& ar) const; \ + template void SGVector::cereal_save( \ + cereal::JSONOutputArchive& ar) const; \ + template void SGVector::cereal_save( \ + cereal::XMLOutputArchive& ar) const; \ + template void SGVector::cereal_load( \ + cereal::BinaryInputArchive& ar); \ + template void SGVector::cereal_load( \ + cereal::JSONInputArchive& ar); \ + template void SGVector::cereal_load(cereal::XMLInputArchive& ar); + +FILL_SGVECTOR(bool) +FILL_SGVECTOR(char) +FILL_SGVECTOR(int8_t) +FILL_SGVECTOR(uint8_t) +FILL_SGVECTOR(int16_t) +FILL_SGVECTOR(uint16_t) +FILL_SGVECTOR(int32_t) +FILL_SGVECTOR(uint32_t) +FILL_SGVECTOR(int64_t) +FILL_SGVECTOR(uint64_t) +FILL_SGVECTOR(float32_t) +FILL_SGVECTOR(float64_t) +FILL_SGVECTOR(floatmax_t) +FILL_SGVECTOR(complex128_t) +#undef FILL_SGVECTOR } #undef COMPLEX128_ERROR_NOARG diff --git a/src/shogun/lib/SGVector.h b/src/shogun/lib/SGVector.h index a985d7999d8..fc92719739c 100644 --- a/src/shogun/lib/SGVector.h +++ b/src/shogun/lib/SGVector.h @@ -9,12 +9,12 @@ #ifndef __SGVECTOR_H__ #define __SGVECTOR_H__ -#include - #include -#include #include +#include +#include #include + #include #include @@ -555,6 +555,13 @@ template class SGVector : public SGReferencedData * @return matrix */ static void convert_to_matrix(T*& matrix, index_t nrows, index_t ncols, const T* vector, int32_t vlen, bool fortran_order); + + template + void cereal_save(Archive& ar) const; + + template + void cereal_load(Archive& ar); + #endif // #ifndef SWIG // SWIG should skip this part protected: /** needs to be overridden to copy data */ diff --git a/src/shogun/lib/common.h b/src/shogun/lib/common.h index 
5ca3ac5c84c..e1c4ee2e1a1 100644 --- a/src/shogun/lib/common.h +++ b/src/shogun/lib/common.h @@ -7,6 +7,16 @@ #ifndef __COMMON_H__ #define __COMMON_H__ +#ifdef CEREAL_SAVE_FUNCTION_NAME +#undef CEREAL_SAVE_FUNCTION_NAME +#endif +#define CEREAL_SAVE_FUNCTION_NAME cereal_save + +#ifdef CEREAL_LOAD_FUNCTION_NAME +#undef CEREAL_LOAD_FUNCTION_NAME +#endif +#define CEREAL_LOAD_FUNCTION_NAME cereal_load + #include #include #include diff --git a/tests/unit/io/CerealObject.h b/tests/unit/io/CerealObject.h new file mode 100644 index 00000000000..e9bf75dc013 --- /dev/null +++ b/tests/unit/io/CerealObject.h @@ -0,0 +1,40 @@ +#include +#include + +namespace shogun +{ + + /** @brief Used to test the SGObject serialization */ + class CCerealObject : public CSGObject + { + public: + // Construct CCerealObject from input SGVector + CCerealObject(SGVector vec) : CSGObject() + { + m_vector = vec; + init_params(); + } + + // Default constructor + CCerealObject() : CSGObject() + { + m_vector = SGVector(5); + m_vector.set_const(0); + init_params(); + } + + const char* get_name() const + { + return "CerealObject"; + } + + protected: + // Register m_vector to parameter list with name(tag) "test_vector" + void init_params() + { + register_param("test_vector", m_vector); + } + + SGVector m_vector; + }; +} diff --git a/tests/unit/io/Cereal_unittest.cc b/tests/unit/io/Cereal_unittest.cc new file mode 100644 index 00000000000..1e7f1e09127 --- /dev/null +++ b/tests/unit/io/Cereal_unittest.cc @@ -0,0 +1,177 @@ +#include "CerealObject.h" +#include +#include + +#include +#include +#include + +#include +#include +#include + +using namespace shogun; + +#ifndef SWIG // SWIG should skip this part +TEST(Cereal, Json_SGVector_float64_load_equals_saved) +{ + const index_t size = 5; + SGVector a(size); + SGVector b; + a.range_fill(1.0); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + 
std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + EXPECT_EQ(a.ref_count(), b.ref_count()); + for (index_t i = 0; i < size; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGMatrix_float64_load_equals_saved) +{ + const index_t nrows = 2, ncols = 3; + SGMatrix a(nrows, ncols); + SGMatrix b; + + for (index_t i = 0; i < nrows * ncols; i++) + a[i] = i; + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.num_rows, b.num_rows); + EXPECT_EQ(a.num_cols, b.num_cols); + EXPECT_EQ(a.ref_count(), b.ref_count()); + for (index_t i = 0; i < nrows * ncols; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGVector_complex128_load_equals_saved) +{ + const index_t size = 5; + SGVector a(size); + SGVector b(size); + + for (index_t i = 0; i < size; ++i) + a[i] = complex128_t(i, i * 2); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + for (index_t i = 0; i < size; i++) + { + EXPECT_NEAR(i, b[i].real(), 1E-15); + EXPECT_NEAR(i * 2, b[i].imag(), 1E-15); + } + + remove(filename.c_str()); +} + +TEST(Cereal, Json_SGVector_load_equals_saved_refcounting_false) +{ + const index_t size = 5; + SGVector a(size, false); + SGVector b; + 
a.range_fill(1.0); + + std::string filename = std::tmpnam(nullptr); + + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(a); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(b); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); + + EXPECT_EQ(a.size(), b.size()); + EXPECT_EQ(-1, b.ref_count()); + for (index_t i = 0; i < size; i++) + EXPECT_NEAR(a[i], b[i], 1E-15); + + remove(filename.c_str()); +} + +TEST(Cereal, Json_CerealObject_load_equals_saved) +{ + SGVector A(5); + A.range_fill(0); + SGVector B(5); + CCerealObject obj_save(A); + CCerealObject obj_load; + + std::string filename = std::tmpnam(nullptr); + + obj_save.save_json(filename.c_str()); + + obj_load.load_json(filename.c_str()); + B = obj_load.get>("test_vector"); + + EXPECT_EQ(A.size(), B.size()); + for (index_t i = 0; i < 5; ++i) + EXPECT_NEAR(A[i], B[i], 1e-15); + + remove(filename.c_str()); +} + +#endif // ifndef SWIG