Skip to content

Commit

Permalink
Merge pull request #4676 from geektoni/feature/remove_macros_histograms
Browse files Browse the repository at this point in the history
Remove CHECK_TYPE_HISTO macro from TBOutputFormat.
  • Loading branch information
geektoni committed Jun 9, 2019
2 parents 2df3dcd + 0b7c5ae commit 2b7b795
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 67 deletions.
53 changes: 11 additions & 42 deletions src/shogun/io/TBOutputFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,28 +37,12 @@

#include <chrono>
#include <shogun/io/TBOutputFormat.h>
#include <shogun/lib/common.h>
#include <shogun/lib/observers/ObservedValueTemplated.h>
#include <shogun/lib/tfhistogram/histogram.h>
#include <shogun/lib/type_case.h>
#include <vector>

using namespace shogun;

#define CHECK_TYPE_HISTO(type) \
else if ( \
value.first->get_any().type_info().hash_code() == \
typeid(type).hash_code()) \
{ \
tensorflow::histogram::Histogram h; \
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); \
auto v = any_cast<type>(value.first->get_any()); \
for (auto value_v : v) \
h.Add(value_v); \
h.EncodeToProto(hp, true); \
summaryValue->set_allocated_histo(hp); \
}

TBOutputFormat::TBOutputFormat(){};

TBOutputFormat::~TBOutputFormat(){};
Expand Down Expand Up @@ -98,33 +82,18 @@ tensorflow::Event TBOutputFormat::convert_vector(
summaryValue->set_tag(value.first->get<std::string>("name"));
summaryValue->set_node_name(node_name);

if (value.first->get_any().type_info().hash_code() ==
typeid(std::vector<int8_t>).hash_code())
{
tensorflow::histogram::Histogram h;
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto();
auto v = any_cast<std::vector<int8_t>>(value.first->get_any());
for (auto value_v : v)
tensorflow::histogram::Histogram h;
tensorflow::HistogramProto* hp = new tensorflow::HistogramProto();

auto write_summary = [&h](auto val) {
for (auto value_v : val)
h.Add(value_v);
h.EncodeToProto(hp, true);
summaryValue->set_allocated_histo(hp);
}
CHECK_TYPE_HISTO(std::vector<uint8_t>)
CHECK_TYPE_HISTO(std::vector<int16_t>)
CHECK_TYPE_HISTO(std::vector<uint16_t>)
CHECK_TYPE_HISTO(std::vector<int32_t>)
CHECK_TYPE_HISTO(std::vector<uint32_t>)
CHECK_TYPE_HISTO(std::vector<int64_t>)
CHECK_TYPE_HISTO(std::vector<uint64_t>)
CHECK_TYPE_HISTO(std::vector<float32_t>)
CHECK_TYPE_HISTO(std::vector<float64_t>)
CHECK_TYPE_HISTO(std::vector<floatmax_t>)
CHECK_TYPE_HISTO(std::vector<char>)
else
{
SG_ERROR(
"Unsupported type %s", value.first->get_any().type_info().name());
}
};

sg_any_dispatch(value.first->get_any(), sg_all_typemap, None{}, write_summary);

h.EncodeToProto(hp, true);
summaryValue->set_allocated_histo(hp);

return e;
}
Expand Down
59 changes: 44 additions & 15 deletions src/shogun/lib/type_case.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ namespace shogun
{
typedef Types<
bool, char, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t,
int64_t, uint64_t, float32_t, float64_t, floatmax_t, SGVector<int32_t>,
SGVector<int64_t>, SGVector<float32_t>, SGVector<float64_t>,
SGVector<floatmax_t>, SGMatrix<int32_t>, SGMatrix<int64_t>,
SGMatrix<float32_t>, SGMatrix<float64_t>, SGMatrix<floatmax_t>>
int64_t, uint64_t, float32_t, float64_t, floatmax_t, SGVector<char>,
SGVector<int8_t>, SGVector<int16_t>, SGVector<uint8_t>, SGVector<uint16_t>,
SGVector<uint32_t>, SGVector<uint64_t>, SGVector<int32_t>, SGVector<int64_t>,
SGVector<float32_t>, SGVector<float64_t>, SGVector<floatmax_t>, SGMatrix<int32_t>,
SGMatrix<int64_t>, SGMatrix<float32_t>, SGMatrix<float64_t>, SGMatrix<floatmax_t>>
SG_TYPES;

enum class TYPE
Expand All @@ -44,17 +45,24 @@ namespace shogun
T_FLOATMAX = 13,
T_SGOBJECT = 14,
T_COMPLEX128 = 15,
T_SGVECTOR_FLOAT32 = 16,
T_SGVECTOR_FLOAT64 = 17,
T_SGVECTOR_FLOATMAX = 18,
T_SGVECTOR_INT32 = 19,
T_SGVECTOR_INT64 = 20,
T_SGMATRIX_FLOAT32 = 21,
T_SGMATRIX_FLOAT64 = 22,
T_SGMATRIX_FLOATMAX = 23,
T_SGMATRIX_INT32 = 24,
T_SGMATRIX_INT64 = 25,
T_UNDEFINED = 26
T_SGVECTOR_CHAR = 16,
T_SGVECTOR_FLOAT32 = 17,
T_SGVECTOR_FLOAT64 = 18,
T_SGVECTOR_FLOATMAX = 19,
T_SGVECTOR_UINT8 = 20,
T_SGVECTOR_INT8 = 21,
T_SGVECTOR_INT16 = 22,
T_SGVECTOR_UINT16 = 23,
T_SGVECTOR_INT32 = 24,
T_SGVECTOR_UINT32 = 25,
T_SGVECTOR_INT64 = 26,
T_SGVECTOR_UINT64 = 27,
T_SGMATRIX_FLOAT32 = 28,
T_SGMATRIX_FLOAT64 = 29,
T_SGMATRIX_FLOATMAX = 30,
T_SGMATRIX_INT32 = 31,
T_SGMATRIX_INT64 = 32,
T_UNDEFINED = 33
};
typedef std::unordered_map<std::type_index, TYPE> typemap;
namespace type_internal
Expand Down Expand Up @@ -128,11 +136,18 @@ namespace shogun
SG_ADD_PRIMITIVE_TYPE(float64_t, TYPE::T_FLOAT64)
SG_ADD_PRIMITIVE_TYPE(floatmax_t, TYPE::T_FLOATMAX)
SG_ADD_PRIMITIVE_TYPE(complex128_t, TYPE::T_COMPLEX128)
SG_ADD_SGVECTOR_TYPE(SGVector<char>, TYPE::T_SGVECTOR_CHAR)
SG_ADD_SGVECTOR_TYPE(SGVector<float32_t>, TYPE::T_SGVECTOR_FLOAT32)
SG_ADD_SGVECTOR_TYPE(SGVector<float64_t>, TYPE::T_SGVECTOR_FLOAT64)
SG_ADD_SGVECTOR_TYPE(SGVector<floatmax_t>, TYPE::T_SGVECTOR_FLOATMAX)
SG_ADD_SGVECTOR_TYPE(SGVector<int8_t>, TYPE::T_SGVECTOR_INT8)
SG_ADD_SGVECTOR_TYPE(SGVector<int16_t>, TYPE::T_SGVECTOR_INT16)
SG_ADD_SGVECTOR_TYPE(SGVector<uint16_t>, TYPE::T_SGVECTOR_UINT16)
SG_ADD_SGVECTOR_TYPE(SGVector<uint8_t>, TYPE::T_SGVECTOR_UINT8)
SG_ADD_SGVECTOR_TYPE(SGVector<int32_t>, TYPE::T_SGVECTOR_INT32)
SG_ADD_SGVECTOR_TYPE(SGVector<int64_t>, TYPE::T_SGVECTOR_INT64)
SG_ADD_SGVECTOR_TYPE(SGVector<uint32_t>, TYPE::T_SGVECTOR_UINT32)
SG_ADD_SGVECTOR_TYPE(SGVector<uint64_t>, TYPE::T_SGVECTOR_UINT64)
SG_ADD_SGMATRIX_TYPE(SGMatrix<float32_t>, TYPE::T_SGMATRIX_FLOAT32)
SG_ADD_SGMATRIX_TYPE(SGMatrix<float64_t>, TYPE::T_SGMATRIX_FLOAT64)
SG_ADD_SGMATRIX_TYPE(SGMatrix<floatmax_t>, TYPE::T_SGMATRIX_FLOATMAX)
Expand Down Expand Up @@ -416,23 +431,37 @@ static const typemap sg_all_typemap = {
ADD_TYPE_TO_MAP(float64_t , TYPE::T_FLOAT64)
ADD_TYPE_TO_MAP(floatmax_t , TYPE::T_FLOATMAX)
ADD_TYPE_TO_MAP(complex128_t, TYPE::T_COMPLEX128)
ADD_TYPE_TO_MAP(SGVector<char>, TYPE::T_SGVECTOR_CHAR)
ADD_TYPE_TO_MAP(SGVector<float32_t>, TYPE::T_SGVECTOR_FLOAT32)
ADD_TYPE_TO_MAP(SGVector<float64_t>, TYPE::T_SGVECTOR_FLOAT64)
ADD_TYPE_TO_MAP(SGVector<floatmax_t>, TYPE::T_SGVECTOR_FLOATMAX)
ADD_TYPE_TO_MAP(SGVector<int8_t>, TYPE::T_SGVECTOR_INT8)
ADD_TYPE_TO_MAP(SGVector<int16_t>, TYPE::T_SGVECTOR_INT16)
ADD_TYPE_TO_MAP(SGVector<uint8_t>, TYPE::T_SGVECTOR_UINT8)
ADD_TYPE_TO_MAP(SGVector<uint16_t>, TYPE::T_SGVECTOR_UINT16)
ADD_TYPE_TO_MAP(SGVector<int32_t>, TYPE::T_SGVECTOR_INT32)
ADD_TYPE_TO_MAP(SGVector<int64_t>, TYPE::T_SGVECTOR_INT64)
ADD_TYPE_TO_MAP(SGVector<uint32_t>, TYPE::T_SGVECTOR_UINT32)
ADD_TYPE_TO_MAP(SGVector<uint64_t>, TYPE::T_SGVECTOR_UINT64)
ADD_TYPE_TO_MAP(SGMatrix<float32_t>, TYPE::T_SGMATRIX_FLOAT32)
ADD_TYPE_TO_MAP(SGMatrix<float64_t>, TYPE::T_SGMATRIX_FLOAT64)
ADD_TYPE_TO_MAP(SGMatrix<floatmax_t>, TYPE::T_SGMATRIX_FLOATMAX)
ADD_TYPE_TO_MAP(SGMatrix<int32_t>, TYPE::T_SGMATRIX_INT32)
ADD_TYPE_TO_MAP(SGMatrix<int64_t>, TYPE::T_SGMATRIX_INT64)
};
static const typemap sg_vector_typemap = {
ADD_TYPE_TO_MAP(SGVector<char>, TYPE::T_SGVECTOR_CHAR)
ADD_TYPE_TO_MAP(SGVector<float32_t>, TYPE::T_SGVECTOR_FLOAT32)
ADD_TYPE_TO_MAP(SGVector<float64_t>, TYPE::T_SGVECTOR_FLOAT64)
ADD_TYPE_TO_MAP(SGVector<floatmax_t>, TYPE::T_SGVECTOR_FLOATMAX)
ADD_TYPE_TO_MAP(SGVector<int8_t>, TYPE::T_SGVECTOR_INT8)
ADD_TYPE_TO_MAP(SGVector<int16_t>, TYPE::T_SGVECTOR_INT16)
ADD_TYPE_TO_MAP(SGVector<uint8_t>, TYPE::T_SGVECTOR_UINT8)
ADD_TYPE_TO_MAP(SGVector<uint16_t>, TYPE::T_SGVECTOR_UINT16)
ADD_TYPE_TO_MAP(SGVector<int32_t>, TYPE::T_SGVECTOR_INT32)
ADD_TYPE_TO_MAP(SGVector<int64_t>, TYPE::T_SGVECTOR_INT64)
ADD_TYPE_TO_MAP(SGVector<uint32_t>, TYPE::T_SGVECTOR_UINT32)
ADD_TYPE_TO_MAP(SGVector<uint64_t>, TYPE::T_SGVECTOR_UINT64)
};
static const typemap sg_matrix_typemap = {
ADD_TYPE_TO_MAP(SGMatrix<float32_t>, TYPE::T_SGMATRIX_FLOAT32)
Expand Down
21 changes: 11 additions & 10 deletions tests/unit/io/TBOutputFormat_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include <tflogger/summary.pb.h>
#include <utility>
#include <vector>
#include <shogun/lib/SGVector.h>

using namespace shogun;

Expand Down Expand Up @@ -90,7 +91,7 @@ void test_case_scalar_error(T value_val)
}

template <class T>
void test_case_vector(std::vector<T> v)
void test_case_vector(SGVector<T> v)
{
tensorflow::Event event_ex;
auto summary = event_ex.mutable_summary();
Expand All @@ -109,7 +110,7 @@ void test_case_vector(std::vector<T> v)

time_point timestamp;
Some<ObservedValue> emitted_value = Some<ObservedValue>::from_raw(
new ObservedValueTemplated<std::vector<T>>(
new ObservedValueTemplated<SGVector<T>>(
1, "test", "test description", v));

std::string node_name = "node";
Expand All @@ -125,13 +126,13 @@ void test_case_vector(std::vector<T> v)
}

template <class T>
void test_case_vector_error(std::vector<T> v)
void test_case_vector_error(SGVector<T> v)
{
TBOutputFormat tmp;

time_point timestamp;
Some<ObservedValue> emitted_value = Some<ObservedValue>::from_raw(
new ObservedValueTemplated<std::vector<T>>(
new ObservedValueTemplated<SGVector<T>>(
1, "test", "test_description", v));

std::string node_name = "node";
Expand Down Expand Up @@ -159,17 +160,17 @@ TEST(TBOutputFormatTest, fail_convert_scalar)

TYPED_TEST(TBOutputFormatTest, convert_all_types_histo)
{
std::vector<TypeParam> v;
v.push_back((TypeParam)1);
v.push_back((TypeParam)2);
SGVector<TypeParam> v(2);
v[0] = ((TypeParam)1);
v[1] = ((TypeParam)2);
test_case_vector<TypeParam>(v);
};

TEST(TBOutputFormat, fail_convert_histo)
{
std::vector<complex128_t> v;
v.push_back((complex128_t)1);
v.push_back((complex128_t)2);
SGVector<complex128_t> v(2);
v[0] = ((complex128_t)1);
v[1] = ((complex128_t)2);
test_case_vector_error<complex128_t>(v);
}

Expand Down

0 comments on commit 2b7b795

Please sign in to comment.