-
-
Notifications
You must be signed in to change notification settings - Fork 1k
/
TBOutputFormat.cpp
123 lines (106 loc) · 3.57 KB
/
TBOutputFormat.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*
* Written (W) 2017 Giovanni De Toni
*/
#include <shogun/lib/config.h>
#ifdef HAVE_TFLOGGER
#include <chrono>
#include <vector>
#include <shogun/io/TBOutputFormat.h>
#include <shogun/lib/common.h>
#include <shogun/lib/tfhistogram/histogram.h>
using namespace shogun;
/* CHECK_TYPE(type)
 *
 * Expands to one `else if` branch of a type-dispatch chain, so it must directly
 * follow an `if` / `else if` branch. It relies on two names from the enclosing
 * scope: `value` (a std::pair<std::string, Any>) and `summaryValue` (a
 * tensorflow::Summary::Value*). When the Any stores a `type`, the value is
 * written as the summary's scalar (simple_value).
 *
 * The std::type_info objects are compared directly: type_info::hash_code() is
 * only guaranteed to be equal for equal types, not distinct for distinct ones.
 */
#define CHECK_TYPE(type) \
	else if (value.second.type_info() == typeid(type)) \
	{ \
		summaryValue->set_simple_value(recall_type<type>(value.second)); \
	}
/* CHECK_TYPE_HISTO(type)
 *
 * Expands to one `else if` branch of a type-dispatch chain, so it must directly
 * follow an `if` / `else if` branch. It relies on two names from the enclosing
 * scope: `value` (a std::pair<std::string, Any>) and `summaryValue` (a
 * tensorflow::Summary::Value*). When the Any stores a `type` (a container of
 * numbers), every element is added to a histogram, which is encoded into the
 * summary's `histo` field.
 *
 * The std::type_info objects are compared directly: type_info::hash_code() is
 * only guaranteed to be equal for equal types, not distinct for distinct ones.
 * The histogram is encoded straight into mutable_histo(), which is owned by
 * summaryValue, so no raw owning pointer is created.
 */
#define CHECK_TYPE_HISTO(type) \
	else if (value.second.type_info() == typeid(type)) \
	{ \
		tensorflow::histogram::Histogram h; \
		const auto& v = recall_type<type>(value.second); \
		for (const auto& value_v : v) \
			h.Add(value_v); \
		h.EncodeToProto(summaryValue->mutable_histo(), true); \
	}
/** Constructs the formatter; no explicit initialisation is performed here. */
TBOutputFormat::TBOutputFormat() = default;
/** Destroys the formatter; no explicit clean-up is performed here. */
TBOutputFormat::~TBOutputFormat() = default;
/** Converts a scalar observed value into a TensorBoard event.
 *
 * The event carries a single summary value whose tag is `value.first`, whose
 * node name is `node_name` and whose simple_value is the number stored in
 * `value.second`.
 *
 * @param event_step step number recorded in the event
 * @param value (tag, value) pair; the Any must hold one of the arithmetic
 *        types dispatched on below (int8_t ... floatmax_t, char)
 * @param node_name node name attached to the summary value
 * @return the populated event
 * Unsupported stored types are reported through SG_ERROR.
 */
tensorflow::Event TBOutputFormat::convert_scalar(
	const int64_t& event_step, const std::pair<std::string, Any>& value,
	std::string& node_name)
{
	// Event::wall_time is a double holding seconds since the epoch (what
	// TensorFlow's own writers store); an integer millisecond count would
	// put every event ~1000x too far in the future.
	const double wall_time_sec =
		std::chrono::duration<double>(
			std::chrono::system_clock::now().time_since_epoch())
			.count();

	tensorflow::Event e;
	e.set_wall_time(wall_time_sec);
	e.set_step(event_step);

	tensorflow::Summary* summary = e.mutable_summary();
	auto summaryValue = summary->add_value();
	summaryValue->set_tag(value.first);
	summaryValue->set_node_name(node_name);

	// Dispatch on the stored type; each CHECK_TYPE expands to another
	// `else if` branch. std::type_info is compared directly because equal
	// hash_code() values do not guarantee equal types.
	if (value.second.type_info() == typeid(int8_t))
	{
		summaryValue->set_simple_value(recall_type<int8_t>(value.second));
	}
	CHECK_TYPE(uint8_t)
	CHECK_TYPE(int16_t)
	CHECK_TYPE(uint16_t)
	CHECK_TYPE(int32_t)
	CHECK_TYPE(uint32_t)
	CHECK_TYPE(int64_t)
	CHECK_TYPE(uint64_t)
	CHECK_TYPE(float32_t)
	CHECK_TYPE(float64_t)
	CHECK_TYPE(floatmax_t)
	CHECK_TYPE(char)
	else
	{
		SG_ERROR("Unsupported type %s", value.second.type_info().name());
	}
	return e;
}
/** Converts a vector observed value into a TensorBoard histogram event.
 *
 * The event carries a single summary value whose tag is `value.first`, whose
 * node name is `node_name` and whose `histo` field is a histogram of every
 * element of the std::vector stored in `value.second`.
 *
 * @param event_step step number recorded in the event
 * @param value (tag, value) pair; the Any must hold a std::vector of one of
 *        the element types dispatched on below (int8_t ... floatmax_t, char)
 * @param node_name node name attached to the summary value
 * @return the populated event
 * Unsupported stored types are reported through SG_ERROR.
 */
tensorflow::Event TBOutputFormat::convert_vector(
	const int64_t& event_step, const std::pair<std::string, Any>& value,
	std::string& node_name)
{
	// Event::wall_time is a double holding seconds since the epoch (what
	// TensorFlow's own writers store); an integer millisecond count would
	// put every event ~1000x too far in the future.
	const double wall_time_sec =
		std::chrono::duration<double>(
			std::chrono::system_clock::now().time_since_epoch())
			.count();

	tensorflow::Event e;
	e.set_wall_time(wall_time_sec);
	e.set_step(event_step);

	tensorflow::Summary* summary = e.mutable_summary();
	auto summaryValue = summary->add_value();
	summaryValue->set_tag(value.first);
	summaryValue->set_node_name(node_name);

	// Dispatch on the stored type; each CHECK_TYPE_HISTO expands to another
	// `else if` branch. std::type_info is compared directly because equal
	// hash_code() values do not guarantee equal types.
	if (value.second.type_info() == typeid(std::vector<int8_t>))
	{
		tensorflow::histogram::Histogram h;
		const auto& v = recall_type<std::vector<int8_t>>(value.second);
		for (const auto& value_v : v)
			h.Add(value_v);
		// Encode into the field owned by summaryValue; no naked `new` needed.
		h.EncodeToProto(summaryValue->mutable_histo(), true);
	}
	CHECK_TYPE_HISTO(std::vector<uint8_t>)
	CHECK_TYPE_HISTO(std::vector<int16_t>)
	CHECK_TYPE_HISTO(std::vector<uint16_t>)
	CHECK_TYPE_HISTO(std::vector<int32_t>)
	CHECK_TYPE_HISTO(std::vector<uint32_t>)
	CHECK_TYPE_HISTO(std::vector<int64_t>)
	CHECK_TYPE_HISTO(std::vector<uint64_t>)
	CHECK_TYPE_HISTO(std::vector<float32_t>)
	CHECK_TYPE_HISTO(std::vector<float64_t>)
	CHECK_TYPE_HISTO(std::vector<floatmax_t>)
	CHECK_TYPE_HISTO(std::vector<char>)
	else
	{
		SG_ERROR("Unsupported type %s", value.second.type_info().name());
	}
	return e;
}
#endif // HAVE_TFLOGGER