Skip to content

Commit

Permalink
[wip] json serialization test
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed Jul 23, 2018
1 parent 2f50c3f commit 0d4ba5b
Show file tree
Hide file tree
Showing 7 changed files with 325 additions and 14 deletions.
6 changes: 6 additions & 0 deletions src/shogun/base/SGObject.h
Expand Up @@ -33,6 +33,8 @@
*/
namespace shogun
{
class CDeserializer;
class CSerializer;
class RefCount;
class SGIO;
class Parallel;
Expand Down Expand Up @@ -216,6 +218,10 @@ class CSGObject
*/
virtual void print_serializable(const char* prefix="");


virtual void serialize(CSerializer* s) const;
virtual void deserialize(CDeserializer* ds);

/** Save this object to file.
*
* @param file where to save the object; will be closed during
Expand Down
2 changes: 1 addition & 1 deletion src/shogun/io/fs/NullFileSystem.h
Expand Up @@ -2,7 +2,7 @@
#define __NULL_FILE_SYSTEM_H__

#include <shogun/io/fs/FileSystem.h>
#include <shogun/lib/ShogunNotImplementedException.h>
#include <shogun/lib/exception/ShogunNotImplementedException.h>

namespace shogun
{
Expand Down
201 changes: 199 additions & 2 deletions src/shogun/io/serialization/JsonDeserializer.cpp
@@ -1,12 +1,186 @@
/** This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Sergey Lisitsyn
* Authors: Sergey Lisitsyn, Viktor Gal
*/

#include <memory>

#include <shogun/base/class_list.h>
#include <shogun/base/macros.h>
#include <shogun/io/serialization/JsonDeserializer.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>

#include <rapidjson/reader.h>

using namespace shogun;

struct SGHandler : public rapidjson::BaseReaderHandler<UTF8<>, MyHandler> {
bool Null() { cout << "Null()" << endl; return true; }
bool Bool(bool b) { cout << "Bool(" << boolalpha << b << ")" << endl; return true; }
bool Int(int i) { cout << "Int(" << i << ")" << endl; return true; }
bool Uint(unsigned u) { cout << "Uint(" << u << ")" << endl; return true; }
bool Int64(int64_t i) { cout << "Int64(" << i << ")" << endl; return true; }
bool Uint64(uint64_t u) { cout << "Uint64(" << u << ")" << endl; return true; }
bool Double(double d) { cout << "Double(" << d << ")" << endl; return true; }
bool String(const char* str, SizeType length, bool copy) {
cout << "String(" << str << ", " << length << ", " << boolalpha << copy << ")" << endl;
return true;
}
bool StartObject() { cout << "StartObject()" << endl; return true; }
bool Key(const char* str, SizeType length, bool copy) {
cout << "Key(" << str << ", " << length << ", " << boolalpha << copy << ")" << endl;
return true;
}
bool EndObject(SizeType memberCount) { cout << "EndObject(" << memberCount << ")" << endl; return true; }
bool StartArray() { cout << "StartArray()" << endl; return true; }
bool EndArray(SizeType elementCount) { cout << "EndArray(" << elementCount << ")" << endl; return true; }
};


template<typename RapidJsonReader>
class JSONReaderVisitor : public AnyVisitor
{
public:
JSONReaderVisitor(RapidJsonReader& jr, rapidjson::Document* document):
AnyVisitor(), m_json_reader(jr), m_document(document) {}

virtual void on(bool* v)
{
SG_SDEBUG("reading bool")
*v = m_json_reader.GetBool();
SG_SDEBUG("%d\n", *v)
}
virtual void on(int32_t* v)
{
SG_SDEBUG("reading int32_t")
*v = m_json_reader.GetInt();
SG_SDEBUG("%d\n", *v)
}
virtual void on(int64_t* v)
{
SG_SDEBUG("reading int64_t")
*v = m_json_reader.GetInt64();
SG_SDEBUG("%d\n", *v)
}
virtual void on(float* v)
{
SG_SDEBUG("reading float: ")
*v = (float32_t)m_json_reader.GetDouble();
SG_SDEBUG("%f\n", *v)
}
virtual void on(double* v)
{
SG_SDEBUG("reading double: ")
*v = m_json_reader.GetDouble();
SG_SDEBUG("%f\n", *v)
}
virtual void on(CSGObject** v)
{
SG_SDEBUG("reading SGObject: ")
*v = m_deser->read().get();
/*
std::string object_name;
EPrimitiveType primitive_type;
m_archive(object_name, primitive_type);
SG_SDEBUG("%s %d\n", object_name.c_str(), primitive_type)
if (*v == nullptr)
SG_UNREF(*v);
*v = create(object_name.c_str(), primitive_type);
m_archive(**v);
*/
}
virtual void on(SGVector<int>* v)
{
SG_SDEBUG("reading SGVector<int>\n")
}
virtual void on(SGVector<float>* v)
{
SG_SDEBUG("reading SGVector<float>\n")
}
virtual void on(SGVector<double>* v)
{
SG_SDEBUG("reading SGVector<double>\n")
}
virtual void on(SGMatrix<int>* v)
{
SG_SDEBUG("reading SGMatrix<int>>\n")
}
virtual void on(SGMatrix<float>* v)
{
SG_SDEBUG("reading SGMatrix<float>>\n")
}
virtual void on(SGMatrix<double>* v)
{
SG_SDEBUG("reading SGMatrix<double>>\n")
}

private:
RapidJsonReader& m_json_reader;
rapidjson::Document* m_document;
};

class CIStreamAdapter
{
public:
typedef char Ch;

CIStreamAdapter(CInputStream* is): m_stream(is) {}

Ch Peek() const
{
//int c = m_stream.peek();
// return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
}

Ch Take()
{
// int c = m_stream.get();
// return c == std::char_traits<char>::eof() ? '\0' : (Ch)c;
}

size_t Tell() const
{
// return (size_t)m_stream.tellg();
}

Ch* PutBegin() { assert(false); return 0; }
void Put(Ch) { assert(false); }
void Flush() { assert(false); }
size_t PutEnd(Ch*) { assert(false); return 0; }

private:
CInputStream* m_stream;
SG_DELETE_COPY_AND_ASSIGN(CIStreamAdapter);
};


template<typename Reader>
Some<CSGObject> object_reader(Reader& reader)
{
auto reader_visitor = std::make_unique<JSONReaderVisitor<rapidjson::Document>>(obj_json);
if (!obj_json.IsObject())
throw ShogunException("JSON value is not an object!");

std::string obj_name(obj_json["name"].GetString());
EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt());
auto obj = create(obj_name.c_str(), primitive_type);
for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it)
{
auto param_name = it->name.GetString();
if (!has(param_name))
throw ShogunException(
"cannot deserialize the object from file!");

BaseTag tag(param_name);
auto parameter = obj->get_parameter(tag);
parameter.get_value().visit(reader_visitor.get());
obj->update_parameter(tag, parameter.get_value());
}
return wrap<CSGObject>(obj);
}


CJsonDeserializer::CJsonDeserializer() : CDeserializer()
{
}
Expand All @@ -17,5 +191,28 @@ CJsonDeserializer::~CJsonDeserializer()

Some<CSGObject> CJsonDeserializer::read()
{
return wrap<CSGObject>(nullptr);
CIStreamAdapter is(stream().get());
rapidjson::Document obj_json;
obj_json.ParseStream(is);
auto reader_visitor = std::make_unique<JSONReaderVisitor<rapidjson::Document>>(obj_json, &obj_json);

if (!obj_json.IsObject())
throw ShogunException("JSON value is not an object!");

std::string obj_name(obj_json["name"].GetString());
EPrimitiveType primitive_type((EPrimitiveType) obj_json["generic"].GetInt());
auto obj = create(obj_name.c_str(), primitive_type);
for (auto it = obj_json.MemberBegin(); it != obj_json.MemberEnd(); ++it)
{
auto param_name = it->name.GetString();
if (!has(param_name))
throw ShogunException(
"cannot deserialize the object from file!");

BaseTag tag(param_name);
auto parameter = obj->get_parameter(tag);
parameter.get_value().visit(reader_visitor.get());
obj->update_parameter(tag, parameter.get_value());
}
return wrap<CSGObject>(obj);
}
112 changes: 106 additions & 6 deletions src/shogun/io/serialization/JsonSerializer.cpp
@@ -1,9 +1,13 @@
/** This software is distributed under BSD 3-clause license (see LICENSE file).
*
* Authors: Sergey Lisitsyn
* Authors: Sergey Lisitsyn, Viktor Gal
*/

#include <memory>

#include <shogun/io/serialization/JsonSerializer.h>
#include <shogun/lib/SGVector.h>
#include <shogun/lib/SGMatrix.h>

#include <rapidjson/writer.h>

Expand All @@ -25,6 +29,106 @@ struct COutputStreamAdapter
COutputStream* m_stream;
};

template<typename Writer> void object_writer(Writer& writer, Some<CSGObject> object);

template<typename RapidJsonWriter>
class JSONWriterVisitor : public AnyVisitor
{
public:
JSONWriterVisitor(RapidJsonWriter& jw):
AnyVisitor(), m_json_writer(jw) {}

virtual void on(bool* v)
{
SG_SDEBUG("writing bool with value %d\n", *v)
m_json_writer.Bool(*v);
}
virtual void on(int32_t* v)
{
SG_SDEBUG("writing int32_t with value %d\n", *v)
m_json_writer.Int(*v);
}
virtual void on(int64_t* v)
{
SG_SDEBUG("writing int64_t with value %d\n", *v)
m_json_writer.Int64(*v);
}
virtual void on(float* v)
{
SG_SDEBUG("writing float with value %f\n", *v)
m_json_writer.Double(*v);
}
virtual void on(double* v)
{
SG_SDEBUG("writing double with value %f\n", *v)
m_json_writer.Double(*v);
}
virtual void on(CSGObject** v)
{
if (*v)
{
SG_SDEBUG("writing SGObject with of type\n")
object_writer(m_json_writer, wrap<CSGObject>(*v));
}
}
virtual void on(SGVector<int>* v)
{
SG_SDEBUG("writing SGVector<int>\n")
m_json_writer.StartArray();
for (const auto& i: *v)
m_json_writer.Int(i);
m_json_writer.EndArray();
}
virtual void on(SGVector<float>* v)
{
SG_SDEBUG("writing SGVector<float>\n")
}
virtual void on(SGVector<double>* v)
{
SG_SDEBUG("writing SGVector<double>\n")

}
virtual void on(SGMatrix<int>* v)
{
SG_SDEBUG("writing SGMatrix<int>\n")

}
virtual void on(SGMatrix<float>* v)
{
SG_SDEBUG("writing SGMatrix<float>\n")

}
virtual void on(SGMatrix<double>* v)
{
SG_SDEBUG("writing SGMatrix<double>\n")

}

private:
RapidJsonWriter& m_json_writer;
};

template<typename Writer>
void object_writer(Writer& writer, Some<CSGObject> object)
{
auto writer_visitor = std::make_unique<JSONWriterVisitor<Writer>>(writer);
writer.Key("name");
writer.String(object->get_name());
writer.Key("generic");
writer.Int(object->get_generic());
auto param_names = object->parameter_names();
writer.Key("parameters");
writer.StartObject();
for (auto param_name: param_names)
{
writer.Key(param_name.c_str());
BaseTag tag(param_name);
auto param = object->get_parameter(tag);
param.get_value().visit(writer_visitor.get());
}
writer.EndObject();
}

CJsonSerializer::CJsonSerializer() : CSerializer()
{
}
Expand All @@ -37,9 +141,5 @@ void CJsonSerializer::write(Some<CSGObject> object)
{
COutputStreamAdapter adapter{ .m_stream = stream().get() };
rapidjson::Writer<COutputStreamAdapter> writer(adapter);
writer.StartObject();
writer.Key("name");
writer.String(object->get_name());
//writer.String(object.get());
writer.EndObject();
object_writer(writer, object);
}
6 changes: 3 additions & 3 deletions src/shogun/io/serialization/JsonSerializer.h
Expand Up @@ -13,10 +13,10 @@ namespace shogun
{
public:
CJsonSerializer();
virtual ~CJsonSerializer();
virtual void write(Some<CSGObject> object);
~CJsonSerializer() override;
void write(Some<CSGObject> object) override;

virtual const char* get_name() const
const char* get_name() const override
{
return "JsonSerializer";
}
Expand Down

0 comments on commit 0d4ba5b

Please sign in to comment.