diff --git a/src/shogun/base/SGObject.h b/src/shogun/base/SGObject.h index 8464880e781..124533c2e0b 100644 --- a/src/shogun/base/SGObject.h +++ b/src/shogun/base/SGObject.h @@ -33,6 +33,8 @@ */ namespace shogun { +class CDeserializer; +class CSerializer; class RefCount; class SGIO; class Parallel; @@ -216,6 +218,10 @@ class CSGObject */ virtual void print_serializable(const char* prefix=""); + + virtual void serialize(CSerializer* s) const; + virtual void deserialize(CDeserializer* ds); + /** Save this object to file. * * @param file where to save the object; will be closed during diff --git a/src/shogun/io/fs/NullFileSystem.h b/src/shogun/io/fs/NullFileSystem.h index 1d6acb93a9b..6d3dc8c1b0b 100644 --- a/src/shogun/io/fs/NullFileSystem.h +++ b/src/shogun/io/fs/NullFileSystem.h @@ -2,7 +2,7 @@ #define __NULL_FILE_SYSTEM_H__ #include -#include +#include namespace shogun { diff --git a/src/shogun/io/serialization/JsonDeserializer.cpp b/src/shogun/io/serialization/JsonDeserializer.cpp index ac335eebccf..8d9e2e7d582 100644 --- a/src/shogun/io/serialization/JsonDeserializer.cpp +++ b/src/shogun/io/serialization/JsonDeserializer.cpp @@ -1,12 +1,186 @@ /** This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Sergey Lisitsyn + * Authors: Sergey Lisitsyn, Viktor Gal */ +#include + +#include +#include #include +#include +#include + +#include using namespace shogun; +struct SGHandler : public rapidjson::BaseReaderHandler, MyHandler> { + bool Null() { cout << "Null()" << endl; return true; } + bool Bool(bool b) { cout << "Bool(" << boolalpha << b << ")" << endl; return true; } + bool Int(int i) { cout << "Int(" << i << ")" << endl; return true; } + bool Uint(unsigned u) { cout << "Uint(" << u << ")" << endl; return true; } + bool Int64(int64_t i) { cout << "Int64(" << i << ")" << endl; return true; } + bool Uint64(uint64_t u) { cout << "Uint64(" << u << ")" << endl; return true; } + bool Double(double d) { cout << "Double(" << d << ")" << endl; return true; } + bool String(const char* str, SizeType length, bool copy) { + cout << "String(" << str << ", " << length << ", " << boolalpha << copy << ")" << endl; + return true; + } + bool StartObject() { cout << "StartObject()" << endl; return true; } + bool Key(const char* str, SizeType length, bool copy) { + cout << "Key(" << str << ", " << length << ", " << boolalpha << copy << ")" << endl; + return true; + } + bool EndObject(SizeType memberCount) { cout << "EndObject(" << memberCount << ")" << endl; return true; } + bool StartArray() { cout << "StartArray()" << endl; return true; } + bool EndArray(SizeType elementCount) { cout << "EndArray(" << elementCount << ")" << endl; return true; } +}; + + +template +class JSONReaderVisitor : public AnyVisitor +{ +public: + JSONReaderVisitor(RapidJsonReader& jr, rapidjson::Document* document): + AnyVisitor(), m_json_reader(jr), m_document(document) {} + + virtual void on(bool* v) + { + SG_SDEBUG("reading bool") + *v = m_json_reader.GetBool(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(int32_t* v) + { + SG_SDEBUG("reading int32_t") + *v = m_json_reader.GetInt(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(int64_t* v) + { + SG_SDEBUG("reading int64_t") + *v = m_json_reader.GetInt64(); + SG_SDEBUG("%d\n", *v) + } + virtual void on(float* v) + { + SG_SDEBUG("reading float: ") + *v = (float32_t)m_json_reader.GetDouble(); + SG_SDEBUG("%f\n", *v) + } + virtual void on(double* v) + { + SG_SDEBUG("reading double: ") + *v = m_json_reader.GetDouble(); + SG_SDEBUG("%f\n", *v) + } + virtual void on(CSGObject** v) + { + SG_SDEBUG("reading SGObject: ") + *v = m_deser->read().get(); + /* + std::string object_name; + EPrimitiveType primitive_type; + m_archive(object_name, primitive_type); + SG_SDEBUG("%s %d\n", object_name.c_str(), primitive_type) + if (*v == nullptr) + SG_UNREF(*v); + *v = create(object_name.c_str(), primitive_type); + m_archive(**v); + */ + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + } + virtual void on(SGVector* v) + { + SG_SDEBUG("reading SGVector\n") + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("reading SGMatrix>\n") + } + +private: + RapidJsonReader& m_json_reader; + rapidjson::Document* m_document; +}; + +class CIStreamAdapter +{ +public: + typedef char Ch; + + CIStreamAdapter(CInputStream* is): m_stream(is) {} + + Ch Peek() const + { + //int c = m_stream.peek(); + // return c == std::char_traits::eof() ? '\0' : (Ch)c; + } + + Ch Take() + { + // int c = m_stream.get(); + // return c == std::char_traits::eof() ? '\0' : (Ch)c; + } + + size_t Tell() const + { + // return (size_t)m_stream.tellg(); + } + + Ch* PutBegin() { assert(false); return 0; } + void Put(Ch) { assert(false); } + void Flush() { assert(false); } + size_t PutEnd(Ch*) { assert(false); return 0; } + +private: + CInputStream* m_stream; + SG_DELETE_COPY_AND_ASSIGN(CIStreamAdapter); +}; + + +template +Some object_reader(Reader& reader) +{ + auto reader_visitor = std::make_unique>(obj_json); + if (!obj_json.IsObject()) + throw ShogunException("JSON value is not an object!"); + + std::string obj_name(obj_json["name"].GetString()); + EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt()); + auto obj = create(obj_name.c_str(), primitive_type); + for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it) + { + auto param_name = it->name.GetString(); + if (!has(param_name)) + throw ShogunException( + "cannot deserialize the object from file!"); + + BaseTag tag(param_name); + auto parameter = obj->get_parameter(tag); + parameter.get_value().visit(reader_visitor.get()); + obj->update_parameter(tag, parameter.get_value()); + } + return wrap(obj); +} + + CJsonDeserializer::CJsonDeserializer() : CDeserializer() { } @@ -17,5 +191,28 @@ CJsonDeserializer::~CJsonDeserializer() Some CJsonDeserializer::read() { - return wrap(nullptr); + CIStreamAdapter is(stream().get()); + rapidjson::Document obj_json; + obj_json.ParseStream(is); + auto reader_visitor = std::make_unique>(obj_json, &obj_json); + + if (!obj_json.IsObject()) + throw ShogunException("JSON value is not an object!"); + + std::string obj_name(obj_json["name"].GetString()); + EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt()); + auto obj = create(obj_name.c_str(), primitive_type); + for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it) + { + auto param_name = it->name.GetString(); + if (!has(param_name)) + throw ShogunException( + "cannot deserialize the object from file!"); + + BaseTag tag(param_name); + auto parameter = obj->get_parameter(tag); + parameter.get_value().visit(reader_visitor.get()); + obj->update_parameter(tag, parameter.get_value()); + } + return wrap(obj); } diff --git a/src/shogun/io/serialization/JsonSerializer.cpp b/src/shogun/io/serialization/JsonSerializer.cpp index 822a5ac053f..1e22fa4fd2b 100644 --- a/src/shogun/io/serialization/JsonSerializer.cpp +++ b/src/shogun/io/serialization/JsonSerializer.cpp @@ -1,9 +1,13 @@ /** This software is distributed under BSD 3-clause license (see LICENSE file). * - * Authors: Sergey Lisitsyn + * Authors: Sergey Lisitsyn, Viktor Gal */ +#include + #include +#include +#include #include @@ -25,6 +29,106 @@ struct COutputStreamAdapter COutputStream* m_stream; }; +template void object_writer(Writer& writer, Some object); + +template +class JSONWriterVisitor : public AnyVisitor +{ +public: + JSONWriterVisitor(RapidJsonWriter& jw): + AnyVisitor(), m_json_writer(jw) {} + + virtual void on(bool* v) + { + SG_SDEBUG("writing bool with value %d\n", *v) + m_json_writer.Bool(*v); + } + virtual void on(int32_t* v) + { + SG_SDEBUG("writing int32_t with value %d\n", *v) + m_json_writer.Int(*v); + } + virtual void on(int64_t* v) + { + SG_SDEBUG("writing int64_t with value %d\n", *v) + m_json_writer.Int64(*v); + } + virtual void on(float* v) + { + SG_SDEBUG("writing float with value %f\n", *v) + m_json_writer.Double(*v); + } + virtual void on(double* v) + { + SG_SDEBUG("writing double with value %f\n", *v) + m_json_writer.Double(*v); + } + virtual void on(CSGObject** v) + { + if (*v) + { + SG_SDEBUG("writing SGObject with of type\n") + object_writer(m_json_writer, wrap(*v)); + } + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + m_json_writer.StartArray(); + for (const auto& i: *v) + m_json_writer.Int(i); + m_json_writer.EndArray(); + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + } + virtual void on(SGVector* v) + { + SG_SDEBUG("writing SGVector\n") + + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + + } + virtual void on(SGMatrix* v) + { + SG_SDEBUG("writing SGMatrix\n") + + } + +private: + RapidJsonWriter& m_json_writer; +}; + +template +void object_writer(Writer& writer, Some object) +{ + auto writer_visitor = std::make_unique>(writer); + writer.Key("name"); + writer.String(object->get_name()); + writer.Key("generic"); + writer.Int(object->get_generic()); + auto param_names = object->parameter_names(); + writer.Key("parameters"); + writer.StartObject(); + for (auto param_name: param_names) + { + writer.Key(param_name.c_str()); + BaseTag tag(param_name); + auto param = object->get_parameter(tag); + param.get_value().visit(writer_visitor.get()); + } + writer.EndObject(); +} + CJsonSerializer::CJsonSerializer() : CSerializer() { } @@ -37,9 +141,5 @@ void CJsonSerializer::write(Some object) { COutputStreamAdapter adapter{ .m_stream = stream().get() }; rapidjson::Writer writer(adapter); - writer.StartObject(); - writer.Key("name"); - writer.String(object->get_name()); - //writer.String(object.get()); - writer.EndObject(); + object_writer(writer, object); } diff --git a/src/shogun/io/serialization/JsonSerializer.h b/src/shogun/io/serialization/JsonSerializer.h index 137a46b753c..61de4166470 100644 --- a/src/shogun/io/serialization/JsonSerializer.h +++ b/src/shogun/io/serialization/JsonSerializer.h @@ -13,10 +13,10 @@ namespace shogun { public: CJsonSerializer(); - virtual ~CJsonSerializer(); - virtual void write(Some object); + ~CJsonSerializer() override; + void write(Some object) override; - virtual const char* get_name() const + const char* get_name() const override { return "JsonSerializer"; } diff --git a/src/shogun/lib/ShogunNotImplementedException.h b/src/shogun/lib/exception/ShogunNotImplementedException.h similarity index 100% rename from src/shogun/lib/ShogunNotImplementedException.h rename to src/shogun/lib/exception/ShogunNotImplementedException.h diff --git a/tests/unit/io/JsonSerialization_unittest.cc b/tests/unit/io/JsonSerialization_unittest.cc index 6e4313b6a86..1c22a4d4aac 100644 --- a/tests/unit/io/JsonSerialization_unittest.cc +++ b/tests/unit/io/JsonSerialization_unittest.cc @@ -1,10 +1,12 @@ +#include + #include #include #include #include -#include +#include "base/MockObject.h" using namespace shogun; @@ -34,8 +36,14 @@ class CDummyOutputStream : public COutputStream TEST(JsonSerialization, basic_serializer) { + auto obj = some(); + auto child = some(); + obj->put("watched_int", 10); + auto serializer = some(); auto stream = some(); serializer->attach(stream); - serializer->write(serializer); + serializer->write(obj); + + //EXPECT_EQ(); }