Skip to content

Commit

Permalink
[ShogunBoard] Convert observable to observable.timestamp().
Browse files Browse the repository at this point in the history
We can now get directly from the emitted object the time when it
was produced.
  • Loading branch information
geektoni committed Jul 14, 2017
1 parent d6ba3f4 commit dee9171
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 76 deletions.
21 changes: 10 additions & 11 deletions src/shogun/base/SGObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -808,28 +808,27 @@ bool CSGObject::type_erased_has(const BaseTag& _tag) const

void CSGObject::subscribe_to_parameters(ParameterObserverInterface* obs)
{
auto sub =
rxcpp::make_subscriber<ParameterObserverInterface::ObservedValue>(
[obs](ParameterObserverInterface::ObservedValue e) {
obs->on_next(e);
},
[obs](std::exception_ptr ep) { obs->on_error(ep); },
[obs]() { obs->on_complete(); });
auto sub = rxcpp::make_subscriber<TimedObservedValue>(
[obs](TimedObservedValue e) { obs->on_next(e); },
[obs](std::exception_ptr ep) { obs->on_error(ep); },
[obs]() { obs->on_complete(); });

// Create an observable which emits values only if they are about
// parameters selected by the observable.
auto subscription =
m_observable_params
->filter([obs](ParameterObserverInterface::ObservedValue v) {
return obs->filter(v.second.first);
})
->filter([obs](ObservedValue v) { return obs->filter(v.name); })
.timestamp()
.subscribe(sub);
}

void CSGObject::observe_scalar(
const int64_t step, const std::string& name, const Any& value)
{
auto tmp = std::make_pair(step, std::make_pair(name, value));
ObservedValue tmp;
tmp.step = step;
tmp.name = name;
tmp.value = value;
m_subscriber_params->on_next(tmp);
}

Expand Down
12 changes: 4 additions & 8 deletions src/shogun/base/SGObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,16 +123,12 @@ enum EGradientAvailability
class CSGObject
{
public:
typedef rxcpp::subjects::subject<ParameterObserverInterface::ObservedValue>
SGSubject;
typedef rxcpp::observable<
ParameterObserverInterface::ObservedValue,
rxcpp::dynamic_observable<ParameterObserverInterface::ObservedValue>>
typedef rxcpp::subjects::subject<ObservedValue> SGSubject;
typedef rxcpp::observable<ObservedValue,
rxcpp::dynamic_observable<ObservedValue>>
SGObservable;
typedef rxcpp::subscriber<
ParameterObserverInterface::ObservedValue,
rxcpp::observer<ParameterObserverInterface::ObservedValue, void, void,
void, void>>
ObservedValue, rxcpp::observer<ObservedValue, void, void, void, void>>
SGSubscriber;

/** default constructor */
Expand Down
50 changes: 22 additions & 28 deletions src/shogun/io/TBOutputFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,19 @@
using namespace shogun;

#define CHECK_TYPE(type) \
else if (value.second.type_info().hash_code() == typeid(type).hash_code()) \
else if ( \
value.first.value.type_info().hash_code() == typeid(type).hash_code()) \
{ \
summaryValue->set_simple_value(recall_type<type>(value.second)); \
summaryValue->set_simple_value(recall_type<type>(value.first.value)); \
}

#define CHECK_TYPE_HISTO(type) \
else if (value.second.type_info().hash_code() == typeid(type).hash_code()) \
else if ( \
value.first.value.type_info().hash_code() == typeid(type).hash_code()) \
{ \
tensorflow::histogram::Histogram h; \
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); \
auto v = recall_type<type>(value.second); \
auto v = recall_type<type>(value.first.value); \
for (auto value_v : v) \
h.Add(value_v); \
h.EncodeToProto(hp, true); \
Expand All @@ -66,25 +68,21 @@ TBOutputFormat::TBOutputFormat(){};
TBOutputFormat::~TBOutputFormat(){};

tensorflow::Event TBOutputFormat::convert_scalar(
const int64_t& event_step, const std::pair<std::string, Any>& value,
std::string& node_name)
const TimedObservedValue& value, std::string& node_name)
{
auto millisec = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

tensorflow::Event e;
e.set_wall_time(millisec);
e.set_step(event_step);
std::time_t now_t = convert_to_time_t(value.second);
e.set_wall_time(now_t);
e.set_step(value.first.step);

tensorflow::Summary* summary = e.mutable_summary();
auto summaryValue = summary->add_value();
summaryValue->set_tag(value.first);
summaryValue->set_tag(value.first.name);
summaryValue->set_node_name(node_name);

if (value.second.type_info().hash_code() == typeid(int8_t).hash_code())
if (value.first.value.type_info().hash_code() == typeid(int8_t).hash_code())
{
summaryValue->set_simple_value(recall_type<int8_t>(value.second));
summaryValue->set_simple_value(recall_type<int8_t>(value.first.value));
}
CHECK_TYPE(uint8_t)
CHECK_TYPE(int16_t)
Expand All @@ -99,35 +97,31 @@ tensorflow::Event TBOutputFormat::convert_scalar(
CHECK_TYPE(char)
else
{
SG_ERROR("Unsupported type %s", value.second.type_info().name());
SG_ERROR("Unsupported type %s", value.first.value.type_info().name());
}

return e;
}

tensorflow::Event TBOutputFormat::convert_vector(
const int64_t& event_step, const std::pair<std::string, Any>& value,
std::string& node_name)
const TimedObservedValue& value, std::string& node_name)
{
auto millisec = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();

tensorflow::Event e;
e.set_wall_time(millisec);
e.set_step(event_step);
std::time_t now_t = convert_to_time_t(value.second);
e.set_wall_time(now_t);
e.set_step(value.first.step);

tensorflow::Summary* summary = e.mutable_summary();
auto summaryValue = summary->add_value();
summaryValue->set_tag(value.first);
summaryValue->set_tag(value.first.name);
summaryValue->set_node_name(node_name);

if (value.second.type_info().hash_code() ==
if (value.first.value.type_info().hash_code() ==
typeid(std::vector<int8_t>).hash_code())
{
tensorflow::histogram::Histogram h;
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto();
auto v = recall_type<std::vector<int8_t>>(value.second);
auto v = recall_type<std::vector<int8_t>>(value.first.value);
for (auto value_v : v)
h.Add(value_v);
h.EncodeToProto(hp, true);
Expand All @@ -146,7 +140,7 @@ tensorflow::Event TBOutputFormat::convert_vector(
CHECK_TYPE_HISTO(std::vector<char>)
else
{
SG_ERROR("Unsupported type %s", value.second.type_info().name());
SG_ERROR("Unsupported type %s", value.first.value.type_info().name());
}

return e;
Expand Down
11 changes: 5 additions & 6 deletions src/shogun/io/TBOutputFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#define SHOGUN_OUTPUTFORMAT_H

#include <shogun/base/SGObject.h>
#include <shogun/lib/ObservedValue.h>
#include <shogun/lib/any.h>
#include <tflogger/event.pb.h>

Expand All @@ -64,13 +65,11 @@ namespace shogun
* @param node_name the node name (default: node)
* @return the newly created tensorflow::Event
*/
tensorflow::Event convert_scalar(
const int64_t& event_step, const std::pair<std::string, Any>& value,
std::string& node_name);
tensorflow::Event
convert_scalar(const TimedObservedValue& value, std::string& node_name);

tensorflow::Event convert_vector(
const int64_t& event_step, const std::pair<std::string, Any>& value,
std::string& node_name);
tensorflow::Event
convert_vector(const TimedObservedValue& value, std::string& node_name);

virtual const char* get_name() const
{
Expand Down
85 changes: 85 additions & 0 deletions src/shogun/lib/ObservedValue.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* BSD 3-Clause License
*
* Copyright (c) 2017, Shogun-Toolbox e.V. <shogun-team@shogun-toolbox.org>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* * Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* * Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* * Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
* Written (W) 2017 Giovanni De Toni
*
*/

#ifndef SHOGUN_OBSERVEDVALUE_H
#define SHOGUN_OBSERVEDVALUE_H

#include <chrono>
#include <shogun/lib/any.h>
#include <utility>

/**
* Definitions of basic object with are needed by the Parameter
* Observer architecture.
*/
namespace shogun
{
/* Chrono timepoint */
typedef std::chrono::time_point<
std::chrono::_V2::steady_clock,
std::chrono::duration<long int, std::ratio<1l, 1000000000l>>>
time_point;

/* One observed value, composed of:
* - step (for the graph x axis);
* - parameter's name;
* - parameter's value (Any wrapped);
*/
struct ObservedValue
{
int64_t step;
std::string name;
Any value;
};

/**
* Observed value with a timestamp
*/
typedef std::pair<ObservedValue, time_point> TimedObservedValue;

/**
* Helper method to convert a time_point to std::time_t
* @param value time point we want to convert
* @return the time point converted to std::time_t
*/
inline std::time_t convert_to_time_t(const time_point& value)
{
return std::chrono::system_clock::to_time_t(
std::chrono::system_clock::now() +
(value - std::chrono::steady_clock::now()));
}
}

#endif // SHOGUN_OBSERVEDVALUE_H
5 changes: 2 additions & 3 deletions src/shogun/lib/ParameterObserverHistogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ ParameterObserverHistogram::~ParameterObserverHistogram()
{
}

void ParameterObserverHistogram::on_next(const ObservedValue& value)
void ParameterObserverHistogram::on_next(const TimedObservedValue& value)
{
auto node_name = std::string("node");
auto format = TBOutputFormat();
auto event_value =
format.convert_vector(value.first, value.second, node_name);
auto event_value = format.convert_vector(value, node_name);
m_writer.writeEvent(event_value);
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/lib/ParameterObserverHistogram.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace shogun

virtual bool filter(const std::string& param);

virtual void on_next(const ObservedValue& value);
virtual void on_next(const TimedObservedValue& value);
virtual void on_error(std::exception_ptr);
virtual void on_complete();
};
Expand Down
9 changes: 2 additions & 7 deletions src/shogun/lib/ParameterObserverInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
#define SHOGUN_PARAMETEROBSERVERINTERFACE_H

#include <stdexcept>
#include <utility>
#include <vector>

#include <shogun/lib/ObservedValue.h>
#include <shogun/lib/any.h>

namespace shogun
Expand All @@ -50,11 +50,6 @@ namespace shogun
{

public:
/* One observed value, composed of:
* - step (for the graph x axis);
* - a pair composed of: parameter's name + parameter's value
*/
typedef std::pair<int64_t, std::pair<std::string, Any>> ObservedValue;

/**
* Default constructor
Expand Down Expand Up @@ -92,7 +87,7 @@ namespace shogun
* value.
* @param value the value emitted by the parameter observable
*/
virtual void on_next(const ObservedValue& value) = 0;
virtual void on_next(const TimedObservedValue& value) = 0;
/**
* Method which will be called on errors
*/
Expand Down
5 changes: 2 additions & 3 deletions src/shogun/lib/ParameterObserverScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ ParameterObserverScalar::~ParameterObserverScalar()
{
}

void ParameterObserverScalar::on_next(const ObservedValue& value)
void ParameterObserverScalar::on_next(const TimedObservedValue& value)
{
auto node_name = std::string("node");
auto format = TBOutputFormat();
auto event_value =
format.convert_scalar(value.first, value.second, node_name);
auto event_value = format.convert_scalar(value, node_name);
m_writer.writeEvent(event_value);
}

Expand Down
2 changes: 1 addition & 1 deletion src/shogun/lib/ParameterObserverScalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ namespace shogun

virtual bool filter(const std::string& param);

virtual void on_next(const ObservedValue& value);
virtual void on_next(const TimedObservedValue& value);
virtual void on_error(std::exception_ptr);
virtual void on_complete();
};
Expand Down
Loading

0 comments on commit dee9171

Please sign in to comment.