-
Notifications
You must be signed in to change notification settings - Fork 267
/
deserializing_weaver_op.cc
73 lines (65 loc) · 2.89 KB
/
deserializing_weaver_op.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/* Copyright 2017 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_fold/loom/weaver.h"
#include "tensorflow_fold/loom/weaver_op_base.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace fold {
// Registers the "DeserializingWeaver" op, declaring a single string-typed
// input named `weaver_messages`.
// NOTE(review): REGISTER_WEAVER_OP is defined outside this file (presumably in
// weaver_op_base.h); whatever outputs/attrs it contributes are not visible
// here -- confirm against its definition.
REGISTER_WEAVER_OP("DeserializingWeaver")
.Input("weaver_messages: string");
// A Weaver op which:
// 1. Reads one or more serialized WeaverMessages from `weaver_messages`, its
// input tensor.
// 2. Merges them if there are more than one, and
// 3. Creates output tensors that can drive the Loom using the resulting Weaver.
//
// (Item 3 is handled by WeaverOpBase.)
//
// Note: the reason merges are supported in this op is to allow the user to
// pre-compute many WeaverMessages (for example, one per element of the training
// set) and then group them together into random mini-batches at run-time.
//
// A second reason merges are supported is that for large input examples, and
// large batch sizes, merges done in advance of `DeserializingWeaverOp` could
// push the resulting `WeaverMessage` over the protocol buffer size limit.
// Kernel for the "DeserializingWeaver" op: fills a Weaver from the serialized
// messages held in the op's first input tensor.
class DeserializingWeaverOp : public WeaverOpBase {
 public:
  explicit DeserializingWeaverOp(tensorflow::OpKernelConstruction *c)
      : WeaverOpBase(c) {}

  // Loads the contents of input 0 into `weaver`.
  //
  // Input 0 is viewed as a flat sequence of strings. The first string is
  // handed to `Weaver::Deserialize`; each following string is handed, in
  // order, to `Weaver::MergeFromSerialized`. Processing stops at the first
  // failure.
  //
  // Returns:
  //   InvalidArgument - the input tensor holds no strings.
  //   Internal        - deserializing the first string, or merging a later
  //                     one, reported failure; the status message carries
  //                     `weaver->error_string()`.
  //   OK              - every string was consumed successfully.
  tensorflow::Status Weave(
      tensorflow::OpKernelContext *c, Weaver* weaver) override {
    auto weaver_messages = c->input(0).flat<string>();
    if (weaver_messages.size() < 1) {
      return tensorflow::errors::InvalidArgument(
          "weaver_messages must contain at least one value.");
    }

    // The first message initializes the weaver; the rest are merged into it.
    if (!weaver->Deserialize(weaver_messages(0))) {
      return tensorflow::errors::Internal(
          "Failed to deserialize WeaverMessage: ", weaver->error_string());
    }

    // Note: If necessary, this loop could be sped up by merging the messages
    // in a multi-threaded way instead of in sequence.
    for (int64 i = 1; i < weaver_messages.size(); ++i) {
      if (!weaver->MergeFromSerialized(weaver_messages(i))) {
        // The pieces are concatenated verbatim, so the separating spaces must
        // be part of the literals (matches the deserialize message above).
        return tensorflow::errors::Internal(
            "Failed to merge WeaverMessage ", i, ": ", weaver->error_string());
      }
    }
    return tensorflow::Status::OK();
  }
};
// Registers DeserializingWeaverOp as the kernel for the "DeserializingWeaver"
// op on the CPU device.
REGISTER_KERNEL_BUILDER(
Name("DeserializingWeaver").Device(tensorflow::DEVICE_CPU),
DeserializingWeaverOp);
} // namespace fold
} // namespace tensorflow