Skip to content

Commit

Permalink
[wip] add cereal visitor et al
Browse files Browse the repository at this point in the history
  • Loading branch information
vigsterkr committed Apr 17, 2018
1 parent 7bbf097 commit 67bf810
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 133 deletions.
19 changes: 18 additions & 1 deletion src/shogun/base/AnyParameter.h
Expand Up @@ -2,6 +2,7 @@
#define __ANYPARAMETER_H__

#include <shogun/lib/any.h>
#include <shogun/io/CerealVisitor.h>

#include <string>

Expand Down Expand Up @@ -67,7 +68,7 @@ namespace shogun
template<class Archive>
void serialize(Archive& ar)
{
ar(m_model_selection, m_gradient);
ar(m_description, m_model_selection, m_gradient);
}

private:
Expand Down Expand Up @@ -132,6 +133,22 @@ namespace shogun
ar(m_value, m_properties);
}

template <class Archive>
void cereal_save(Archive& ar)
{
std::unique_ptr<CerealWriterVisitor<Archive>> visitor(new CerealWriterVisitor<Archive>(ar));
m_value.visit(visitor.get());
ar(m_properties);
}

template <class Archive>
void cereal_load(Archive& ar)
{
std::unique_ptr<CerealReaderVisitor<Archive>> visitor(new CerealReaderVisitor<Archive>(ar));
m_value.visit(visitor.get());
ar(m_properties);
}

private:
Any m_value;
AnyParameterProperties m_properties;
Expand Down
96 changes: 14 additions & 82 deletions src/shogun/base/SGObject.cpp
Expand Up @@ -244,74 +244,6 @@ int32_t CSGObject::unref()
}
}

void CSGObject::save_binary(const char* filename) const
{
std::ofstream os(filename);
cereal::BinaryOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::save_json(const char* filename) const
{
std::ofstream os(filename);
cereal::JSONOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::save_xml(const char* filename) const
{
std::ofstream os(filename);
cereal::XMLOutputArchive archive(os);
archive(cereal::make_nvp(this->get_name(), *this));
}

void CSGObject::load_binary(const char* filename)
{
std::ifstream is(filename);
cereal::BinaryInputArchive archive(is);
archive(*this);
}

void CSGObject::load_json(const char* filename)
{
std::ifstream is(filename);
cereal::JSONInputArchive archive(is);
archive(*this);
}

void CSGObject::load_xml(const char* filename)
{
std::ifstream is(filename);
cereal::XMLInputArchive archive(is);
archive(*this);
}

template <class Archive>
void CSGObject::cereal_save(Archive& ar) const
{
for (const auto& it : self->map)
ar(cereal::make_nvp(it.first.name(), it.second));
}
template void CSGObject::cereal_save<cereal::BinaryOutputArchive>(
cereal::BinaryOutputArchive& ar) const;
template void CSGObject::cereal_save<cereal::JSONOutputArchive>(
cereal::JSONOutputArchive& ar) const;
template void CSGObject::cereal_save<cereal::XMLOutputArchive>(
cereal::XMLOutputArchive& ar) const;

template <class Archive>
void CSGObject::cereal_load(Archive& ar)
{
for (auto& it : self->map)
ar(it.second);
}
template void CSGObject::cereal_load<cereal::BinaryInputArchive>(
cereal::BinaryInputArchive& ar);
template void
CSGObject::cereal_load<cereal::JSONInputArchive>(cereal::JSONInputArchive& ar);
template void
CSGObject::cereal_load<cereal::XMLInputArchive>(cereal::XMLInputArchive& ar);

#ifdef TRACE_MEMORY_ALLOCS
#include <shogun/lib/Map.h>
extern CMap<void*, shogun::MemoryBlock>* sg_mallocs;
Expand Down Expand Up @@ -899,31 +831,31 @@ class ToStringVisitor : public AnyVisitor
{
}

virtual void on(const bool* v)
virtual void on(bool* v)
{
stream() << (*v ? "true" : "false");
}
virtual void on(const int32_t* v)
virtual void on(int32_t* v)
{
stream() << *v;
}
virtual void on(const int64_t* v)
virtual void on(int64_t* v)
{
stream() << *v;
}
virtual void on(const float* v)
virtual void on(float* v)
{
stream() << *v;
}
virtual void on(const double* v)
virtual void on(double* v)
{
stream() << *v;
}
virtual void on(long double* v)
{
stream() << *v;
}
virtual void on(const CSGObject** v)
virtual void on(CSGObject** v)
{
if (*v)
{
Expand All @@ -934,27 +866,27 @@ class ToStringVisitor : public AnyVisitor
stream() << "null";
}
}
virtual void on(const SGVector<int>* v)
virtual void on(SGVector<int>* v)
{
to_string(v);
}
virtual void on(const SGVector<float>* v)
virtual void on(SGVector<float>* v)
{
to_string(v);
}
virtual void on(const SGVector<double>* v)
virtual void on(SGVector<double>* v)
{
to_string(v);
}
virtual void on(const SGMatrix<int>* mat)
virtual void on(SGMatrix<int>* mat)
{
to_string(mat);
}
virtual void on(const SGMatrix<float>* mat)
virtual void on(SGMatrix<float>* mat)
{
to_string(mat);
}
virtual void on(const SGMatrix<double>* mat)
virtual void on(SGMatrix<double>* mat)
{
to_string(mat);
}
Expand All @@ -966,7 +898,7 @@ class ToStringVisitor : public AnyVisitor
}

template <class T>
void to_string(const SGMatrix<T>* m)
void to_string(SGMatrix<T>* m)
{
if (m)
{
Expand All @@ -990,7 +922,7 @@ class ToStringVisitor : public AnyVisitor
}

template <class T>
void to_string(const SGVector<T>* v)
void to_string(SGVector<T>* v)
{
if (v)
{
Expand Down
69 changes: 31 additions & 38 deletions src/shogun/base/SGObject.h
Expand Up @@ -141,55 +141,48 @@ class CSGObject
virtual ~CSGObject();

#ifndef SWIG // SWIG should skip this part
/** serializes the SGObject to a binary file
*
* @param filename Binary archive filename
*/
void save_binary(const char* filename) const;

/** serializes the SGObject to a JSON file
*
* @param filename JSON archive filename
*/
void save_json(const char* filename) const;

/** serializes the SGObject to a XML file
*
* @param filename XML archive filename
*/
void save_xml(const char* filename) const;

/** loads SGObject from a Binary file
*
* @param filename Binary archive filename
*/
void load_binary(const char* filename);

/** loads SGObject from a JSON file
*
* @param filename JSON archive filename
*/
void load_json(const char* filename);

/** loads SGObject from a XML file
*
* @param filename XML archive filename
*/
void load_xml(const char* filename);

/** serializes SGObject parameters to Archive with Cereal
*
* @param ar Archive
*/
template <class Archive>
void cereal_save(Archive& ar) const;
void cereal_save(Archive& ar) const
{
ar(m_generic);
auto param_names = parameter_names();
ar(param_names.size());
for (auto param_name: param_names)
{
ar(param_name);
BaseTag tag(param_name);
ar(get_parameter(tag));
}
}

/** loads SGObject parameters from Archive with Cereal
*
* @param ar Archive
*/
template <class Archive>
void cereal_load(Archive& ar);
void cereal_load(Archive& ar)
{
ar(m_generic);
size_t num_param_names;
ar(num_param_names);
auto param_names = parameter_names();
if (param_names.size() != num_param_names)
throw ShogunException("cannot deserialize!");

for (auto i = 0; i < num_param_names; ++i)
{
std::string param_name;
ar(param_name);
BaseTag tag(param_name);
AnyParameter parameter;
ar(parameter);
update_parameter(tag, parameter.get_value());
}
}
#endif // #ifndef SWIG // SWIG should skip this part

/** increase reference counter
Expand Down

0 comments on commit 67bf810

Please sign in to comment.