diff --git a/src/shogun/base/AnyParameter.h b/src/shogun/base/AnyParameter.h index c5b138aff6d..17b83b660b4 100644 --- a/src/shogun/base/AnyParameter.h +++ b/src/shogun/base/AnyParameter.h @@ -2,7 +2,9 @@ #define __ANYPARAMETER_H__ #include +#include +#include #include namespace shogun @@ -67,7 +69,7 @@ namespace shogun template void serialize(Archive& ar) { - ar(m_model_selection, m_gradient); + ar(m_description, m_model_selection, m_gradient); } private: @@ -122,14 +124,20 @@ namespace shogun return !(*this == other); } - /** serialize the object using cereal - * - * @param ar Archive type - */ - template - void serialize(Archive& ar) + template + void cereal_save(Archive& ar) const + { + std::unique_ptr> visitor(new CerealWriterVisitor(ar)); + m_value.visit(visitor.get()); + ar(m_properties); + } + + template + void cereal_load(Archive& ar) { - ar(m_value, m_properties); + std::unique_ptr> visitor(new CerealReaderVisitor(ar)); + m_value.visit(visitor.get()); + ar(m_properties); } private: diff --git a/src/shogun/base/SGObject.cpp b/src/shogun/base/SGObject.cpp index b2c9c48f115..94cc1114f1b 100644 --- a/src/shogun/base/SGObject.cpp +++ b/src/shogun/base/SGObject.cpp @@ -244,74 +244,6 @@ int32_t CSGObject::unref() } } -void CSGObject::save_binary(const char* filename) const -{ - std::ofstream os(filename); - cereal::BinaryOutputArchive archive(os); - archive(cereal::make_nvp(this->get_name(), *this)); -} - -void CSGObject::save_json(const char* filename) const -{ - std::ofstream os(filename); - cereal::JSONOutputArchive archive(os); - archive(cereal::make_nvp(this->get_name(), *this)); -} - -void CSGObject::save_xml(const char* filename) const -{ - std::ofstream os(filename); - cereal::XMLOutputArchive archive(os); - archive(cereal::make_nvp(this->get_name(), *this)); -} - -void CSGObject::load_binary(const char* filename) -{ - std::ifstream is(filename); - cereal::BinaryInputArchive archive(is); - archive(*this); -} - -void CSGObject::load_json(const char* filename) -{ - std::ifstream is(filename); - cereal::JSONInputArchive archive(is); - archive(*this); -} - -void CSGObject::load_xml(const char* filename) -{ - std::ifstream is(filename); - cereal::XMLInputArchive archive(is); - archive(*this); -} - -template -void CSGObject::cereal_save(Archive& ar) const -{ - for (const auto& it : self->map) - ar(cereal::make_nvp(it.first.name(), it.second)); -} -template void CSGObject::cereal_save( - cereal::BinaryOutputArchive& ar) const; -template void CSGObject::cereal_save( - cereal::JSONOutputArchive& ar) const; -template void CSGObject::cereal_save( - cereal::XMLOutputArchive& ar) const; - -template -void CSGObject::cereal_load(Archive& ar) -{ - for (auto& it : self->map) - ar(it.second); -} -template void CSGObject::cereal_load( - cereal::BinaryInputArchive& ar); -template void -CSGObject::cereal_load(cereal::JSONInputArchive& ar); -template void -CSGObject::cereal_load(cereal::XMLInputArchive& ar); - #ifdef TRACE_MEMORY_ALLOCS #include extern CMap* sg_mallocs; @@ -899,23 +831,23 @@ class ToStringVisitor : public AnyVisitor { } - virtual void on(const bool* v) + virtual void on(bool* v) { stream() << (*v ? "true" : "false"); } - virtual void on(const int32_t* v) + virtual void on(int32_t* v) { stream() << *v; } - virtual void on(const int64_t* v) + virtual void on(int64_t* v) { stream() << *v; } - virtual void on(const float* v) + virtual void on(float* v) { stream() << *v; } - virtual void on(const double* v) + virtual void on(double* v) { stream() << *v; } @@ -923,7 +855,7 @@ class ToStringVisitor : public AnyVisitor { stream() << *v; } - virtual void on(const CSGObject** v) + virtual void on(CSGObject** v) { if (*v) { @@ -934,27 +866,27 @@ class ToStringVisitor : public AnyVisitor stream() << "null"; } } - virtual void on(const SGVector* v) + virtual void on(SGVector* v) { to_string(v); } - virtual void on(const SGVector* v) + virtual void on(SGVector* v) { to_string(v); } - virtual void on(const SGVector* v) + virtual void on(SGVector* v) { to_string(v); } - virtual void on(const SGMatrix* mat) + virtual void on(SGMatrix* mat) { to_string(mat); } - virtual void on(const SGMatrix* mat) + virtual void on(SGMatrix* mat) { to_string(mat); } - virtual void on(const SGMatrix* mat) + virtual void on(SGMatrix* mat) { to_string(mat); } @@ -966,7 +898,7 @@ class ToStringVisitor : public AnyVisitor } template - void to_string(const SGMatrix* m) + void to_string(SGMatrix* m) { if (m) { @@ -990,7 +922,7 @@ class ToStringVisitor : public AnyVisitor } template - void to_string(const SGVector* v) + void to_string(SGVector* v) { if (v) { diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 0f8325d2656..8b82639a55e 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -141,55 +141,67 @@ class CSGObject virtual ~CSGObject(); #ifndef SWIG // SWIG should skip this part - /** serializes the SGObject to a binary file - * - * @param filename Binary archive filename - */ - void save_binary(const char* filename) const; - - /** serializes the SGObject to a JSON file - * - * @param filename JSON archive filename - */ - void save_json(const char* filename) const; - - /** serializes the SGObject to a XML file - * - * @param filename XML archive filename - */ - void save_xml(const char* filename) const; - - /** loads SGObject from a Binary file - * - * @param filename Binary archive filename - */ - void load_binary(const char* filename); - - /** loads SGObject from a JSON file - * - * @param filename JSON archive filename - */ - void load_json(const char* filename); - - /** loads SGObject from a XML file - * - * @param filename XML archive filename - */ - void load_xml(const char* filename); - /** serializes SGObject parameters to Archive with Cereal * * @param ar Archive */ template - void cereal_save(Archive& ar) const; + void cereal_save(Archive& ar) const + { + std::string class_name(get_name()); + ar(class_name, m_generic); + auto param_names = parameter_names(); + ar(param_names.size()); + for (auto param_name: param_names) + { + ar(param_name); + BaseTag tag(param_name); + ar(get_parameter(tag)); + } + } /** loads SGObject parameters from Archive with Cereal * * @param ar Archive */ template - void cereal_load(Archive& ar); + void cereal_load(Archive& ar) + { + size_t num_param_names; + try + { + ar(num_param_names); + } + catch(const std::exception& e) + { + std::string name; + EPrimitiveType p_type; + ar(name, p_type); + + if (name.compare(std::string(get_name())) != 0) + throw ShogunException( + "cannot deserialize the object from file as num parameters does not match!"); + ar(num_param_names); + } + auto param_names = parameter_names(); + if (param_names.size() != num_param_names) + throw ShogunException( + "cannot deserialize the object from file as num parameters does not match!"); + + for (size_t i = 0; i < num_param_names; ++i) + { + std::string param_name; + ar(param_name); + if (!has(param_name)) + throw ShogunException( + "cannot deserialize the object from file!"); + + BaseTag tag(param_name); + auto parameter = get_parameter(tag); + ar(parameter); + update_parameter(tag, parameter.get_value()); + } + } #endif // #ifndef SWIG // SWIG should skip this part /** increase reference counter diff --git a/src/shogun/base/class_list.h b/src/shogun/base/class_list.h index 2f52468ed5c..accd7e19f27 100644 --- a/src/shogun/base/class_list.h +++ b/src/shogun/base/class_list.h @@ -19,7 +19,6 @@ #include namespace shogun { - class CSGObject; /** new shogun instance * @param sgserializable_name diff --git a/src/shogun/io/CerealVisitor.h b/src/shogun/io/CerealVisitor.h new file mode 100644 index 00000000000..e2a59fd0c34 --- /dev/null +++ b/src/shogun/io/CerealVisitor.h @@ -0,0 +1,184 @@ +#ifndef __CEREAL_VISITOR_H__ +#define __CEREAL_VISITOR_H__ + +#include +#include +#include +#include + +namespace shogun +{ + +template +class CerealWriterVisitor : public AnyVisitor +{ +public: + CerealWriterVisitor(Archive& ar) : AnyVisitor(), m_archive(ar) + { + } + + virtual void on(bool* v) + { + SG_SDEBUG("writing bool with value %d\n", *v) + m_archive(*v); + } + virtual void on(int32_t* v) + { + SG_SDEBUG("writing int32_t with value %d\n", *v) + m_archive(*v); + } + virtual void on(int64_t* v) + { + SG_SDEBUG("writing int64_t with value %d\n", *v) + m_archive(*v); + } + virtual void on(float* v) + { + SG_SDEBUG("writing float with value %f\n", *v) + m_archive(*v); + } + virtual void on(double* v) + { + SG_SDEBUG("writing double with value %f\n", *v) + m_archive(*v); + } + virtual void on(CSGObject** v) + { + if (*v) + { + SG_SDEBUG("writing SGObject with of type: %s\n", (*v)->get_name()) + m_archive(**v); + } + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + m_archive(*v); + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + m_archive(*v); + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + m_archive(*v); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + m_archive(*v); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + m_archive(*v); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + m_archive(*v); + } + +private: + Archive& m_archive; +}; + +template +class CerealReaderVisitor : public AnyVisitor +{ +public: + CerealReaderVisitor(Archive& ar) : AnyVisitor(), m_archive(ar) + { + } + + virtual void on(bool* v) + { + SG_SDEBUG("reading bool") + *v = deserialize(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(int32_t* v) + { + SG_SDEBUG("reading int32_t") + *v = deserialize(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(int64_t* v) + { + SG_SDEBUG("reading int64_t") + *v = deserialize(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(float* v) + { + SG_SDEBUG("reading float: ") + *v = deserialize(); + SG_SDEBUG("%f\n", *v) + } + virtual void on(double* v) + { + SG_SDEBUG("reading double: ") + *v = deserialize(); + SG_SDEBUG("%f\n", *v) + } + virtual void on(CSGObject** v) + { + if (*v) + { + SG_SDEBUG("reading SGObject: ") + std::string object_name; + EPrimitiveType primitive_type; + m_archive(object_name, primitive_type); + SG_SDEBUG("%s %d\n", object_name, primitive_type) + *v = create(object_name, primitive_type); + } + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + *v = deserialize>(); + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + *v = deserialize>(); + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + *v = deserialize>(); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + *v = deserialize>(); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + *v = deserialize>(); + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + *v = deserialize>(); + } + +private: + template + T deserialize() + { + T value; + m_archive(value); + return value; + } + +private: + Archive& m_archive; +}; + +} + +#endif diff --git a/src/shogun/lib/SGMatrix.cpp b/src/shogun/lib/SGMatrix.cpp index 48e4753716a..5d243ad5b4a 100644 --- a/src/shogun/lib/SGMatrix.cpp +++ b/src/shogun/lib/SGMatrix.cpp @@ -12,7 +12,6 @@ #include #include #include -#include #include #include #include @@ -24,6 +23,7 @@ #include #include +#include namespace shogun { @@ -1232,7 +1232,6 @@ template template void SGMatrix::cereal_save(Archive& ar) const { - ar(cereal::base_class(this)); ar(cereal::make_nvp("num_rows", num_rows)); ar(cereal::make_nvp("num_cols", num_cols)); @@ -1244,7 +1243,6 @@ template <> template void SGMatrix::cereal_save(Archive& ar) const { - ar(cereal::base_class(this)); ar(cereal::make_nvp("num_rows", num_rows)); ar(cereal::make_nvp("num_cols", num_cols)); @@ -1257,9 +1255,6 @@ template template void SGMatrix::cereal_load(Archive& ar) { - unref(); - ar(cereal::base_class(this)); - ar(num_rows); ar(num_cols); matrix = SG_MALLOC(T, num_rows * num_cols); @@ -1271,9 +1266,6 @@ template <> template void SGMatrix::cereal_load(Archive& ar) { - unref(); - ar(cereal::base_class(this)); - ar(num_rows); ar(num_cols); matrix = SG_MALLOC(complex128_t, num_rows * num_cols); diff --git a/src/shogun/lib/SGMatrix.h b/src/shogun/lib/SGMatrix.h index 241737efbb9..573a8496e4d 100644 --- a/src/shogun/lib/SGMatrix.h +++ b/src/shogun/lib/SGMatrix.h @@ -1,9 +1,9 @@ /* * This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Soeren Sonnenburg, Heiko Strathmann, Soumyajit De, Sergey Lisitsyn, - * Pan Deng, Khaled Nasr, Michele Mazzoni, Viktor Gal, - * Fernando Iglesias, Thoralf Klein, Chiyuan Zhang, Koen van de Sande, + * Authors: Soeren Sonnenburg, Heiko Strathmann, Soumyajit De, Sergey Lisitsyn, + * Pan Deng, Khaled Nasr, Michele Mazzoni, Viktor Gal, + * Fernando Iglesias, Thoralf Klein, Chiyuan Zhang, Koen van de Sande, * Roman Votyakov */ #ifndef __SGMATRIX_H__ diff --git a/src/shogun/lib/SGReferencedData.cpp b/src/shogun/lib/SGReferencedData.cpp index dd282495d17..76f6b9939fb 100644 --- a/src/shogun/lib/SGReferencedData.cpp +++ b/src/shogun/lib/SGReferencedData.cpp @@ -119,41 +119,4 @@ int32_t SGReferencedData::unref() } } -template -void SGReferencedData::cereal_save(Archive& ar) const -{ - if (m_refcount != NULL) - ar(cereal::make_nvp("ref_counting", true), - cereal::make_nvp("refcount number", m_refcount->ref_count())); - else - ar(cereal::make_nvp("ref_counting", false)); -} -template void SGReferencedData::cereal_save( - cereal::BinaryOutputArchive& ar) const; -template void SGReferencedData::cereal_save( - cereal::JSONOutputArchive& ar) const; -template void SGReferencedData::cereal_save( - cereal::XMLOutputArchive& ar) const; - -template -void SGReferencedData::cereal_load(Archive& ar) -{ - bool ref_counting; - ar(ref_counting); - - if (ref_counting) - { - int32_t temp; - ar(temp); - m_refcount = new RefCount(temp); - } - else - m_refcount = NULL; -} -template void SGReferencedData::cereal_load( - cereal::BinaryInputArchive& ar); -template void SGReferencedData::cereal_load( - cereal::JSONInputArchive& ar); -template void SGReferencedData::cereal_load( - cereal::XMLInputArchive& ar); -} +} // namespace shogun diff --git a/src/shogun/lib/SGReferencedData.h b/src/shogun/lib/SGReferencedData.h index d8e5abf9a5b..ad6093911be 100644 --- a/src/shogun/lib/SGReferencedData.h +++ b/src/shogun/lib/SGReferencedData.h @@ -1,7 +1,7 @@ /* * This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Soeren Sonnenburg, Sergey Lisitsyn, Yuyu Zhang, + * Authors: Soeren Sonnenburg, Sergey Lisitsyn, Yuyu Zhang, * Evangelos Anagnostopoulos */ #ifndef __SGREFERENCED_DATA_H__ @@ -41,14 +41,6 @@ class SGReferencedData */ int32_t ref_count(); -#ifndef SWIG // SWIG should skip this part - template - void cereal_save(Archive& ar) const; - - template - void cereal_load(Archive& ar); -#endif //#ifndef SWIG // SWIG should skip this part - protected: /** copy refcount */ void copy_refcount(const SGReferencedData &orig); diff --git a/src/shogun/lib/SGVector.cpp b/src/shogun/lib/SGVector.cpp index 159f1e9ef4e..679565062b5 100644 --- a/src/shogun/lib/SGVector.cpp +++ b/src/shogun/lib/SGVector.cpp @@ -1,10 +1,10 @@ /* * This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Soeren Sonnenburg, Viktor Gal, Heiko Strathmann, Fernando Iglesias, - * Sanuj Sharma, Pan Deng, Sergey Lisitsyn, Thoralf Klein, - * Soumyajit De, Michele Mazzoni, Evgeniy Andreev, Chiyuan Zhang, - * Bjoern Esser, Weijie Lin, Khaled Nasr, Koen van de Sande, + * Authors: Soeren Sonnenburg, Viktor Gal, Heiko Strathmann, Fernando Iglesias, + * Sanuj Sharma, Pan Deng, Sergey Lisitsyn, Thoralf Klein, + * Soumyajit De, Michele Mazzoni, Evgeniy Andreev, Chiyuan Zhang, + * Bjoern Esser, Weijie Lin, Khaled Nasr, Koen van de Sande, * Somya Anand */ @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -27,6 +26,8 @@ #include #include +#include + #define COMPLEX128_ERROR_NOARG(function) \ template <> \ void SGVector::function() \ @@ -978,8 +979,6 @@ template template void SGVector::cereal_save(Archive& ar) const { - ar(cereal::make_nvp( - "ReferencedData", cereal::base_class(this))); ar(cereal::make_nvp("length", vlen)); for (index_t i = 0; i < vlen; ++i) ar(vector[i]); @@ -989,9 +988,7 @@ template <> template void SGVector::cereal_save(Archive& ar) const { - ar(cereal::base_class(this)); ar(cereal::make_nvp("length", vlen)); - float64_t* temp = reinterpret_cast(vector); for (index_t i = 0; i < vlen * 2; ++i) ar(temp[i]); @@ -1001,9 +998,6 @@ template template void SGVector::cereal_load(Archive& ar) { - unref(); - ar(cereal::base_class(this)); - ar(vlen); vector = SG_MALLOC(T, vlen); for (index_t i = 0; i < vlen; ++i) @@ -1014,9 +1008,6 @@ template <> template void SGVector::cereal_load(Archive& ar) { - unref(); - ar(cereal::base_class(this)); - ar(vlen); vector = SG_MALLOC(complex128_t, vlen); float64_t* temp = reinterpret_cast(vector); diff --git a/src/shogun/lib/any.h b/src/shogun/lib/any.h index 0db9a51e74e..2de233062f1 100644 --- a/src/shogun/lib/any.h +++ b/src/shogun/lib/any.h @@ -170,18 +170,18 @@ namespace shogun public: virtual ~AnyVisitor() = default; - virtual void on(const bool*) = 0; - virtual void on(const int32_t*) = 0; - virtual void on(const int64_t*) = 0; - virtual void on(const float*) = 0; - virtual void on(const double*) = 0; - virtual void on(const CSGObject**) = 0; - virtual void on(const SGVector*) = 0; - virtual void on(const SGVector*) = 0; - virtual void on(const SGVector*) = 0; - virtual void on(const SGMatrix*) = 0; - virtual void on(const SGMatrix*) = 0; - virtual void on(const SGMatrix*) = 0; + virtual void on(bool*) = 0; + virtual void on(int32_t*) = 0; + virtual void on(int64_t*) = 0; + virtual void on(float*) = 0; + virtual void on(double*) = 0; + virtual void on(CSGObject**) = 0; + virtual void on(SGVector*) = 0; + virtual void on(SGVector*) = 0; + virtual void on(SGVector*) = 0; + virtual void on(SGMatrix*) = 0; + virtual void on(SGMatrix*) = 0; + virtual void on(SGMatrix*) = 0; void on(Empty*) { @@ -590,7 +590,7 @@ namespace shogun */ virtual void visit(void* storage, AnyVisitor* visitor) const { - visitor->on(typed_pointer(storage)); + visitor->on(const_cast(typed_pointer(storage))); } }; diff --git a/tests/unit/io/Cereal_unittest.cc b/tests/unit/io/Cereal_unittest.cc index 1e7f1e09127..0e621ac16e0 100644 --- a/tests/unit/io/Cereal_unittest.cc +++ b/tests/unit/io/Cereal_unittest.cc @@ -162,9 +162,22 @@ TEST(Cereal, Json_CerealObject_load_equals_saved) std::string filename = std::tmpnam(nullptr); - obj_save.save_json(filename.c_str()); + try + { + { + std::ofstream os(filename.c_str()); + cereal::JSONOutputArchive archive(os); + archive(obj_save); + } + + { + std::ifstream is(filename.c_str()); + cereal::JSONInputArchive archive(is); + archive(obj_load); + } + } + catch (std::exception& e) SG_SINFO("Error code: %s \n", e.what()); - obj_load.load_json(filename.c_str()); B = obj_load.get>("test_vector"); EXPECT_EQ(A.size(), B.size());