Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ShogunBoard] Add a ParameterObserver (interface and implementation) …
…and add TBOutputFormat. Add interface and ParameterObserverScalar which are the classes needed to add parameter watchers to an algorithm. Add also TBOutputFormat class which converts an ObservedValue into a tensorflow::Event object. Other: * Add unit tests (ParameterObserverScalar and TBOutputFormat); * Add CMake switch to find TFLogger (or download and build it if not present); * Add SWIG code to make ParameterObserver visible to interfaces;
- Loading branch information
Showing
14 changed files
with
483 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
GetCompilers() | ||
|
||
include(ExternalProject) | ||
ExternalProject_Add( | ||
rxcpp | ||
PREFIX ${CMAKE_BINARY_DIR}/tflogger | ||
DOWNLOAD_DIR ${THIRD_PARTY_DIR}/tflogger | ||
URL https://github.com/shogun-toolbox/tflogger/archive/master.zip | ||
CMAKE_ARGS | ||
-DCMAKE_INSTALL_PREFIX:STRING=${CMAKE_BINARY_DIR}/src/shogun/lib/external | ||
-DCMAKE_C_COMPILER:STRING=${C_COMPILER} | ||
-DCMAKE_CXX_COMPILER:STRING=${CXX_COMPILER} | ||
BUILD_COMMAND "" | ||
) | ||
|
||
add_dependencies(libshogun tflogger) | ||
|
||
set(TFLogger_INCLUDE_DIR ${THIRD_PARTY_INCLUDE_DIR}) | ||
|
||
UNSET(C_COMPILER) | ||
UNSET(CXX_COMPILER) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
%include "std_vector.i" | ||
%include "std_string.i" | ||
%template(ParameterList) std::vector<std::string>; | ||
|
||
%{ | ||
#include <shogun/lib/ParameterObserverInterface.h> | ||
#include <shogun/lib/ParameterObserverScalar.h> | ||
%} | ||
|
||
%include <shogun/lib/ParameterObserverInterface.h> | ||
%include <shogun/lib/ParameterObserverScalar.h> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
/* | ||
* Written (W) 2017 Giovanni De Toni | ||
*/ | ||
|
||
#include <chrono> | ||
#include <shogun/io/TBOutputFormat.h> | ||
#include <shogun/lib/common.h> | ||
|
||
using namespace shogun; | ||
|
||
#define CHECK_TYPE(type)\ | ||
else if (\ | ||
value.second.type_info().hash_code() == typeid(type).hash_code())\ | ||
{\ | ||
summaryValue->set_simple_value(recall_type<type>(value.second));\ | ||
} | ||
|
||
TBOutputFormat::TBOutputFormat(){}; | ||
|
||
TBOutputFormat::~TBOutputFormat(){}; | ||
|
||
tensorflow::Event TBOutputFormat::convert_scalar( | ||
const int64_t& event_step, const std::pair<std::string, Any>& value, | ||
std::string& node_name) | ||
{ | ||
auto millisec = std::chrono::duration_cast<std::chrono::milliseconds>( | ||
std::chrono::system_clock::now().time_since_epoch()) | ||
.count(); | ||
|
||
tensorflow::Event e; | ||
e.set_wall_time(millisec); | ||
e.set_step(event_step); | ||
|
||
tensorflow::Summary* summary = e.mutable_summary(); | ||
auto summaryValue = summary->add_value(); | ||
summaryValue->set_tag(value.first); | ||
summaryValue->set_node_name(node_name); | ||
|
||
if (value.second.type_info().hash_code() == typeid(int8_t).hash_code()) | ||
{ | ||
summaryValue->set_simple_value(recall_type<int8_t>(value.second)); | ||
} | ||
CHECK_TYPE(uint8_t) | ||
CHECK_TYPE(int16_t) | ||
CHECK_TYPE(uint16_t) | ||
CHECK_TYPE(int32_t) | ||
CHECK_TYPE(uint32_t) | ||
CHECK_TYPE(int64_t) | ||
CHECK_TYPE(uint64_t) | ||
CHECK_TYPE(float32_t) | ||
CHECK_TYPE(float64_t) | ||
CHECK_TYPE(floatmax_t) | ||
CHECK_TYPE(char) | ||
else { | ||
SG_ERROR("Unsupported type %s", value.second.type_info().name()); | ||
} | ||
|
||
return e; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Written (W) 2017 Giovanni De Toni | ||
*/ | ||
|
||
#ifndef SHOGUN_OUTPUTFORMAT_H | ||
#define SHOGUN_OUTPUTFORMAT_H | ||
|
||
#include <shogun/base/SGObject.h> | ||
#include <shogun/lib/any.h> | ||
#include <tflogger/event.pb.h> | ||
|
||
#include <utility> | ||
|
||
namespace shogun | ||
{ | ||
/** | ||
* Convert an std::pair<std::string, Any> to a tensorflow::Event, | ||
* which can be written to file and used with tools like Tensorboard. | ||
*/ | ||
class TBOutputFormat : public CSGObject | ||
{ | ||
|
||
public: | ||
TBOutputFormat(); | ||
~TBOutputFormat(); | ||
|
||
/** | ||
* Generate a tensorflow::Event object give some informations | ||
* @param event_step the current event step | ||
* @param value the value which will be converted to tensorflow::Event | ||
* @param node_name the node name (default: node) | ||
* @return the newly created tensorflow::Event | ||
*/ | ||
tensorflow::Event convert_scalar( | ||
const int64_t& event_step, const std::pair<std::string, Any>& value, | ||
std::string& node_name); | ||
|
||
virtual const char * get_name() const | ||
{ | ||
return "TFLogger"; | ||
} | ||
}; | ||
} | ||
|
||
#endif // SHOGUN_OUTPUTFORMAT_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#include <shogun/lib/ParameterObserverInterface.h> | ||
|
||
using namespace shogun; | ||
|
||
ParameterObserverInterface::ParameterObserverInterface() | ||
: m_parameters(), m_writer("shogun") | ||
{ | ||
} | ||
|
||
ParameterObserverInterface::ParameterObserverInterface( | ||
std::vector<std::string>& parameters) | ||
: m_parameters(parameters), m_writer("shogun") | ||
{ | ||
} | ||
|
||
ParameterObserverInterface::ParameterObserverInterface( | ||
const std::string& filename, std::vector<std::string>& parameters) | ||
: m_parameters(parameters), m_writer(filename.c_str()) | ||
{ | ||
} | ||
|
||
ParameterObserverInterface::~ParameterObserverInterface() | ||
{ | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#ifndef SHOGUN_PARAMETEROBSERVERINTERFACE_H | ||
#define SHOGUN_PARAMETEROBSERVERINTERFACE_H | ||
|
||
#include <stdexcept> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include <rxcpp/rx-observable.hpp> | ||
#include <shogun/lib/any.h> | ||
#include <tflogger/tensorflow_logger.h> | ||
|
||
namespace shogun | ||
{ | ||
/** | ||
* Interface for the parameter observer classes | ||
*/ | ||
class ParameterObserverInterface | ||
{ | ||
|
||
public: | ||
|
||
/* One observed value, composed of: | ||
* - step (for the graph x axis); | ||
* - a pair composed of: parameter's name + parameter's value | ||
*/ | ||
typedef std::pair<int64_t, std::pair<std::string, Any>> ObservedValue; | ||
|
||
/** | ||
* Default constructor | ||
*/ | ||
ParameterObserverInterface(); | ||
|
||
/** | ||
* Constructor | ||
* @param parameters list of parameters which we want to watch over | ||
*/ | ||
ParameterObserverInterface(std::vector<std::string>& parameters); | ||
|
||
/** | ||
* Constructor | ||
* @param filename name of the generated output file | ||
* @param parameters list of parameters which we want to watch over | ||
*/ | ||
ParameterObserverInterface( | ||
const std::string& filename, std::vector<std::string>& parameters); | ||
/** | ||
* Virtual destructor | ||
*/ | ||
virtual ~ParameterObserverInterface(); | ||
|
||
/** | ||
* Filter function, check if the parameter name supplied is what | ||
* we want to monitor | ||
* @param param the param name | ||
* @return true if param is found inside of m_parameters list | ||
*/ | ||
virtual bool filter(const std::string& param) = 0; | ||
|
||
/** | ||
* Method which will be called when the parameter observable emits a | ||
* value. | ||
* @param value the value emitted by the parameter observable | ||
*/ | ||
virtual void on_next(const ObservedValue& value) = 0; | ||
/** | ||
* Method which will be called on errors | ||
*/ | ||
virtual void on_error(std::exception_ptr) = 0; | ||
/** | ||
* Method which will be called on completion | ||
*/ | ||
virtual void on_complete() = 0; | ||
|
||
protected: | ||
/** | ||
* List of parameter's names we want to monitor | ||
*/ | ||
std::vector<std::string> m_parameters; | ||
/** | ||
* Writer object which will be used to write tensorflow::Event files | ||
*/ | ||
tflogger::TensorFlowLogger m_writer; | ||
}; | ||
} | ||
|
||
#endif // SHOGUN_PARAMETEROBSERVER_H |
Oops, something went wrong.