Skip to content

Commit

Permalink
[wip] add cereal visitor et al
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed Apr 17, 2018
1 parent 7bbf097 commit 2c04e59
Show file tree
Hide file tree
Showing 12 changed files with 318 additions and 254 deletions.
24 changes: 16 additions & 8 deletions src/shogun/base/AnyParameter.h
Expand Up @@ -2,7 +2,9 @@
#define __ANYPARAMETER_H__

#include <shogun/lib/any.h>
#include <shogun/io/CerealVisitor.h>

#include <memory>
#include <string>

namespace shogun
Expand Down Expand Up @@ -67,7 +69,7 @@ namespace shogun
template<class Archive>
void serialize(Archive& ar)
{
ar(m_model_selection, m_gradient);
ar(m_description, m_model_selection, m_gradient);
}

private:
Expand Down Expand Up @@ -122,14 +124,20 @@ namespace shogun
return !(*this == other);
}

/** serialize the object using cereal
*
* @param ar Archive type
*/
template<class Archive>
void serialize(Archive& ar)
template <class Archive>
void cereal_save(Archive& ar) const
{
std::unique_ptr<CerealWriterVisitor<Archive>> visitor(new CerealWriterVisitor<Archive>(ar));
m_value.visit(visitor.get());
ar(m_properties);
}

template <class Archive>
void cereal_load(Archive& ar)
{
ar(m_value, m_properties);
std::unique_ptr<CerealReaderVisitor<Archive>> visitor(new CerealReaderVisitor<Archive>(ar));
m_value.visit(visitor.get());
ar(m_properties);
}

private:
Expand Down
96 changes: 14 additions & 82 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -244,74 +244,6 @@ int32_t CSGObject::unref()
}
}

void CSGObject::save_binary(const char* filename) const
{
std::ofstream os(filename);
cereal::BinaryOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::save_json(const char* filename) const
{
std::ofstream os(filename);
cereal::JSONOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::save_xml(const char* filename) const
{
std::ofstream os(filename);
cereal::XMLOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::load_binary(const char* filename)
{
std::ifstream is(filename);
cereal::BinaryInputArchive archive(is);
archive(*this);
}

void CSGObject::load_json(const char* filename)
{
std::ifstream is(filename);
cereal::JSONInputArchive archive(is);
archive(*this);
}

void CSGObject::load_xml(const char* filename)
{
std::ifstream is(filename);
cereal::XMLInputArchive archive(is);
archive(*this);
}

template <class Archive>
void CSGObject::cereal_save(Archive& ar) const
{
for (const auto& it : self->map)
ar(cereal::make_nvp(it.first.name(), it.second));
}
template void CSGObject::cereal_save<cereal::BinaryOutputArchive>(
cereal::BinaryOutputArchive& ar) const;
template void CSGObject::cereal_save<cereal::JSONOutputArchive>(
cereal::JSONOutputArchive& ar) const;
template void CSGObject::cereal_save<cereal::XMLOutputArchive>(
cereal::XMLOutputArchive& ar) const;

template <class Archive>
void CSGObject::cereal_load(Archive& ar)
{
for (auto& it : self->map)
ar(it.second);
}
template void CSGObject::cereal_load<cereal::BinaryInputArchive>(
cereal::BinaryInputArchive& ar);
template void
CSGObject::cereal_load<cereal::JSONInputArchive>(cereal::JSONInputArchive& ar);
template void
CSGObject::cereal_load<cereal::XMLInputArchive>(cereal::XMLInputArchive& ar);

#ifdef TRACE_MEMORY_ALLOCS
#include <shogun/lib/Map.h>
extern CMap<void*, shogun::MemoryBlock>* sg_mallocs;
Expand Down Expand Up @@ -899,31 +831,31 @@ class ToStringVisitor : public AnyVisitor
{
}

virtual void on(const bool* v)
virtual void on(bool* v)
{
stream() << (*v ? "true" : "false");
}
virtual void on(const int32_t* v)
virtual void on(int32_t* v)
{
stream() << *v;
}
virtual void on(const int64_t* v)
virtual void on(int64_t* v)
{
stream() << *v;
}
virtual void on(const float* v)
virtual void on(float* v)
{
stream() << *v;
}
virtual void on(const double* v)
virtual void on(double* v)
{
stream() << *v;
}
virtual void on(long double* v)
{
stream() << *v;
}
virtual void on(const CSGObject** v)
virtual void on(CSGObject** v)
{
if (*v)
{
Expand All @@ -934,27 +866,27 @@ class ToStringVisitor : public AnyVisitor
stream() << "null";
}
}
virtual void on(const SGVector<int>* v)
virtual void on(SGVector<int>* v)
{
to_string(v);
}
virtual void on(const SGVector<float>* v)
virtual void on(SGVector<float>* v)
{
to_string(v);
}
virtual void on(const SGVector<double>* v)
virtual void on(SGVector<double>* v)
{
to_string(v);
}
virtual void on(const SGMatrix<int>* mat)
virtual void on(SGMatrix<int>* mat)
{
to_string(mat);
}
virtual void on(const SGMatrix<float>* mat)
virtual void on(SGMatrix<float>* mat)
{
to_string(mat);
}
virtual void on(const SGMatrix<double>* mat)
virtual void on(SGMatrix<double>* mat)
{
to_string(mat);
}
Expand All @@ -966,7 +898,7 @@ class ToStringVisitor : public AnyVisitor
}

template <class T>
void to_string(const SGMatrix<T>* m)
void to_string(SGMatrix<T>* m)
{
if (m)
{
Expand All @@ -990,7 +922,7 @@ class ToStringVisitor : public AnyVisitor
}

template <class T>
void to_string(const SGVector<T>* v)
void to_string(SGVector<T>* v)
{
if (v)
{
Expand Down
88 changes: 50 additions & 38 deletions src/shogun/base/SGObject.h
Expand Up @@ -141,55 +141,67 @@ class CSGObject
virtual ~CSGObject();

#ifndef SWIG // SWIG should skip this part
/** serializes the SGObject to a binary file
*
* @param filename Binary archive filename
*/
void save_binary(const char* filename) const;

/** serializes the SGObject to a JSON file
*
* @param filename JSON archive filename
*/
void save_json(const char* filename) const;

/** serializes the SGObject to a XML file
*
* @param filename XML archive filename
*/
void save_xml(const char* filename) const;

/** loads SGObject from a Binary file
*
* @param filename Binary archive filename
*/
void load_binary(const char* filename);

/** loads SGObject from a JSON file
*
* @param filename JSON archive filename
*/
void load_json(const char* filename);

/** loads SGObject from a XML file
*
* @param filename XML archive filename
*/
void load_xml(const char* filename);

/** serializes SGObject parameters to Archive with Cereal
*
* @param ar Archive
*/
template <class Archive>
void cereal_save(Archive& ar) const;
void cereal_save(Archive& ar) const
{
std::string class_name(get_name());
ar(class_name, m_generic);
auto param_names = parameter_names();
ar(param_names.size());
for (auto param_name: param_names)
{
ar(param_name);
BaseTag tag(param_name);
ar(get_parameter(tag));
}
}

/** loads SGObject parameters from Archive with Cereal
*
* @param ar Archive
*/
template <class Archive>
void cereal_load(Archive& ar);
void cereal_load(Archive& ar)
{
size_t num_param_names;
try
{
ar(num_param_names);
}
catch(const std::exception& e)
{
std::string name;
EPrimitiveType p_type;
ar(name, p_type);

if (name.compare(std::string(get_name())) != 0)
throw ShogunException(
"cannot deserialize the object from file as num parameters does not match!");
ar(num_param_names);
}
auto param_names = parameter_names();
if (param_names.size() != num_param_names)
throw ShogunException(
"cannot deserialize the object from file as num parameters does not match!");

for (size_t i = 0; i < num_param_names; ++i)
{
std::string param_name;
ar(param_name);
if (!has(param_name))
throw ShogunException(
"cannot deserialize the object from file!");

BaseTag tag(param_name);
auto parameter = get_parameter(tag);
ar(parameter);
update_parameter(tag, parameter.get_value());
}
}
#endif // #ifndef SWIG // SWIG should skip this part

/** increase reference counter
Expand Down
1 change: 0 additions & 1 deletion src/shogun/base/class_list.h
Expand Up @@ -19,7 +19,6 @@
#include <string>

namespace shogun {
class CSGObject;

/** new shogun instance
* @param sgserializable_name
Expand Down

0 comments on commit 2c04e59

Please sign in to comment.