-
Notifications
You must be signed in to change notification settings - Fork 74k
/
auto_shard.cc
767 lines (672 loc) · 30.6 KB
/
auto_shard.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/auto_shard.h"

#include <array>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace grappler {
namespace {
using tensorflow::data::AutoShardPolicy;
// Op names that the rewrite compares node ops against or creates.
constexpr char kAssertCardinalityDatasetOpName[] = "AssertCardinalityDataset";
constexpr char kShardDatasetOpName[] = "ShardDataset";
constexpr char kShuffleDatasetOpName[] = "ShuffleDataset";
constexpr char kShuffleDatasetV2OpName[] = "ShuffleDatasetV2";
constexpr char kShuffleDatasetV3OpName[] = "ShuffleDatasetV3";
constexpr char kPrefetchDatasetOpName[] = "PrefetchDataset";
constexpr char kFinalizeDatasetOpName[] = "FinalizeDataset";
constexpr char kOptionsDatasetOpName[] = "OptionsDataset";
constexpr char kRebatchDatasetOpName[] = "RebatchDataset";
constexpr char kRebatchDatasetV2OpName[] = "RebatchDatasetV2";
constexpr char kTensorDatasetOpName[] = "TensorDataset";
constexpr char kTensorSliceDatasetOpName[] = "TensorSliceDataset";
constexpr char kPlaceholderOpName[] = "Placeholder";
constexpr char kConstOpName[] = "Const";
// Keys looked up in the RewriterConfig parameter map by AutoShard::Init.
constexpr char kNumWorkersAttrName[] = "num_workers";
constexpr char kNumReplicasAttrName[] = "num_replicas";
constexpr char kIndexAttrName[] = "index";
constexpr char kAutoShardPolicyAttrName[] = "auto_shard_policy";
// NodeDef attribute names read from / written to dataset nodes.
constexpr char kReshuffleEachIteration[] = "reshuffle_each_iteration";
constexpr char kOutputShapes[] = "output_shapes";
constexpr char kOutputTypes[] = "output_types";
// The op-name lists below are matched with IsDatasetNodeOfType, i.e. via
// tensorflow::data::MatchesAnyVersion (presumably so that versioned variants
// of each name also match).
// clang-format off
// Reader datasets: FILE sharding inserts a ShardDataset on their input(0).
constexpr std::array<const char*, 5> kReaderDatasetOps = {
"FixedLengthRecordDataset",
"RecordIODataset",
"SSTableDataset",
"TextLineDataset",
"TFRecordDataset"
};
// Datasets with several inputs; FILE sharding recurses into every input.
constexpr std::array<const char*, 2> kMultipleInputsDatasetOps = {
"ConcatenateDataset",
"ZipDataset"
};
// Ops that FILE sharding walks through (via input(0)) while searching upstream
// for a shardable source.
constexpr std::array<const char*, 30> kPassThroughOps = {
"_Retval",
"AssertNextDataset",
"BatchDataset",
"CacheDataset",
"ExperimentalMapAndBatchDataset",
"ExperimentalParseExampleDataset",
"ExperimentalRebatchDataset",
"FilterDataset",
"FinalizeDataset",
"Identity",
"MapAndBatchDataset",
"MapDataset",
"MaxIntraOpParallelismDataset",
"ModelDataset",
"OptimizeDataset",
"OptionsDataset",
"PaddedBatchDataset",
"ParallelMapDataset",
"ParseExampleDataset",
"PrefetchDataset",
"PrivateThreadPoolDataset",
"ReduceDataset",
"RebatchDataset",
"RepeatDataset",
"ShardDataset",
"ShuffleAndRepeatDataset",
"ShuffleDataset",
"SkipDataset",
"TakeDataset",
"WindowDataset",
};
// Datasets that apply a function (their "f" attr); the function body is
// searched for reader datasets by ReaderOpInFunction.
// TODO(frankchn): Process functions within kFuncDatasetOps as well.
constexpr std::array<const char*, 5> kFuncDatasetOps = {
"ExperimentalParallelInterleaveDataset",
"FlatMapDataset",
"InterleaveDataset",
"LegacyParallelInterleaveDataset",
"ParallelInterleaveDataset",
};
// Source datasets that cannot be sharded by file; reaching one makes FILE
// sharding fail with NotFound.
constexpr std::array<const char*, 5> kUnshardableSourceDatasetOps = {
"GeneratorDataset",
"RangeDataset",
"SparseTensorsSliceDataset",
"TensorDataset",
"TensorSliceDataset",
};
// clang-format on
// Forward declaration; the definition appears later in this file.
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
AutoShardPolicy policy, int64 num_replicas,
GraphDef* output);
template <std::size_t SIZE>
bool IsDatasetNodeOfType(const NodeDef& node,
const std::array<const char*, SIZE>& arr) {
for (const auto& dataset_op_name : arr) {
if (tensorflow::data::MatchesAnyVersion(/*op_prefix=*/dataset_op_name,
/*op_to_match=*/node.op())) {
return true;
}
}
return false;
}
// Inserts a ShardDataset node between `add_before` and the node producing its
// first input, so that `add_before` consumes
// `ShardDataset(input, num_workers, index, require_non_empty=true)`.
//
// Returns NotFound without modifying the graph if `add_before` has no inputs,
// if its first input cannot be found by name, or if that input's op name does
// not contain "Dataset" (so output types and shapes cannot be inferred).
Status AddShardNode(MutableGraphView* graph, const NodeDef& add_before,
                    int64 num_workers, int64 index) {
  // Validate before adding anything. On failure the AUTO policy may go on to
  // rewrite this same graph, so it must not be left with orphaned Const nodes.
  if (add_before.input_size() == 0) {
    return errors::NotFound("Unable to shard this input. Node ",
                            add_before.name(), " has no inputs.");
  }
  NodeDef* add_after = graph->GetNode(add_before.input(0));
  if (add_after == nullptr) {
    return errors::NotFound("Unable to shard this input. Could not find node ",
                            add_before.input(0), ", the first input of ",
                            add_before.name(), ".");
  }
  if (!absl::StrContains(add_after->op(), "Dataset")) {
    // TODO(frankchn): Make this work for datasets where input(0) is a Const,
    // and we need to shard the Const.
    // This is probably not a dataset, so we bail because we can't infer the
    // output types and shape.
    return errors::NotFound(
        "Unable to shard this input. You may need to wrap the inputs to your "
        "reader dataset in a TensorSliceDataset. Input node is ",
        add_after->DebugString());
  }

  // Same mutation order as before (name, then Const nodes) so generated node
  // names are unchanged.
  NodeDef new_node;
  new_node.set_op(kShardDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShardDatasetOpName, graph->graph(),
                                      &new_node);

  // Scalar Const inputs for `num_shards` and `index`.
  NodeDef* num_shards_node =
      graph_utils::AddScalarConstNode<int64>(num_workers, graph);
  NodeDef* index_node = graph_utils::AddScalarConstNode<int64>(index, graph);

  new_node.add_input(add_before.input(0));
  new_node.add_input(num_shards_node->name());
  new_node.add_input(index_node->name());

  // Ensure that each shard will have at least one element.
  (*(new_node.mutable_attr()))["require_non_empty"].set_b(true);

  // Datasets such as TFRecordDataset may not carry output_types /
  // output_shapes. In that case default to DT_STRING and a single
  // unknown-rank shape.
  if (add_after->attr().count(kOutputShapes) > 0) {
    graph_utils::CopyAttribute(kOutputShapes, *add_after, &new_node);
  } else {
    tensorflow::TensorShapeProto* shape =
        (*(new_node.mutable_attr()))[kOutputShapes]
            .mutable_list()
            ->add_shape();
    shape->set_unknown_rank(true);
  }
  if (add_after->attr().count(kOutputTypes) > 0) {
    graph_utils::CopyAttribute(kOutputTypes, *add_after, &new_node);
  } else if (add_after->attr().count("Toutput_types") > 0) {
    (*(new_node.mutable_attr()))[kOutputTypes] =
        add_after->attr().at("Toutput_types");
  } else {
    (*(new_node.mutable_attr()))[kOutputTypes].mutable_list()->add_type(
        tensorflow::DataType::DT_STRING);
  }

  // Add the new node and redirect every consumer of `add_after` to it.
  NodeDef* new_node_graph = graph->AddNode(std::move(new_node));
  TF_RETURN_IF_ERROR(
      graph->UpdateFanouts(add_after->name(), new_node_graph->name()));
  return Status::OK();
}
// Inserts a ShuffleDataset between `add_before` and its first input.
// `buffer_size_node`, `seed_node` and `seed2_node` are names of existing nodes
// used as the shuffle's extra inputs. `output_shapes` / `output_types` are
// copied from the upstream node. Every consumer of the upstream node is then
// redirected to the new shuffle.
Status AddShuffleDataset(MutableGraphView* graph, const NodeDef& add_before,
                         const string& buffer_size_node,
                         const string& seed_node, const string& seed2_node,
                         bool reshuffle_each_iteration) {
  const string& upstream_name = add_before.input(0);
  NodeDef* upstream = graph->GetNode(upstream_name);

  NodeDef shuffle;
  shuffle.set_op(kShuffleDatasetOpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetOpName, graph->graph(),
                                      &shuffle);
  // Inputs: upstream dataset, buffer size, seed, seed2.
  shuffle.add_input(upstream_name);
  shuffle.add_input(buffer_size_node);
  shuffle.add_input(seed_node);
  shuffle.add_input(seed2_node);

  graph_utils::CopyAttribute(kOutputShapes, *upstream, &shuffle);
  graph_utils::CopyAttribute(kOutputTypes, *upstream, &shuffle);
  (*shuffle.mutable_attr())[kReshuffleEachIteration].set_b(
      reshuffle_each_iteration);

  const NodeDef* inserted = graph->AddNode(std::move(shuffle));
  TF_RETURN_IF_ERROR(graph->UpdateFanouts(upstream->name(), inserted->name()));
  return Status::OK();
}
// Inserts a ShuffleDatasetV2 between `add_before` and its first input, fed by
// the existing `buffer_size_node` and `seed_generator_node`. Output attrs are
// copied from the upstream node, and the upstream node's consumers are
// redirected to the new shuffle.
Status AddShuffleDatasetV2(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_generator_node) {
  const string& upstream_name = add_before.input(0);
  NodeDef* upstream = graph->GetNode(upstream_name);

  NodeDef shuffle;
  shuffle.set_op(kShuffleDatasetV2OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV2OpName, graph->graph(),
                                      &shuffle);
  // Inputs: upstream dataset, buffer size, seed generator.
  shuffle.add_input(upstream_name);
  shuffle.add_input(buffer_size_node);
  shuffle.add_input(seed_generator_node);

  graph_utils::CopyAttribute(kOutputShapes, *upstream, &shuffle);
  graph_utils::CopyAttribute(kOutputTypes, *upstream, &shuffle);

  const NodeDef* inserted = graph->AddNode(std::move(shuffle));
  TF_RETURN_IF_ERROR(graph->UpdateFanouts(upstream->name(), inserted->name()));
  return Status::OK();
}
// Inserts a ShuffleDatasetV3 between `add_before` and its first input, fed by
// the existing buffer-size, seed, seed2 and seed-generator nodes. Output attrs
// are copied from the upstream node, `reshuffle_each_iteration` is set, and
// the upstream node's consumers are redirected to the new shuffle.
Status AddShuffleDatasetV3(MutableGraphView* graph, const NodeDef& add_before,
                           const string& buffer_size_node,
                           const string& seed_node, const string& seed2_node,
                           const string& seed_generator_node,
                           bool reshuffle_each_iteration) {
  const string& upstream_name = add_before.input(0);
  NodeDef* upstream = graph->GetNode(upstream_name);

  NodeDef shuffle;
  shuffle.set_op(kShuffleDatasetV3OpName);
  graph_utils::SetUniqueGraphNodeName(kShuffleDatasetV3OpName, graph->graph(),
                                      &shuffle);
  // Inputs: upstream dataset, buffer size, seed, seed2, seed generator.
  shuffle.add_input(upstream_name);
  shuffle.add_input(buffer_size_node);
  shuffle.add_input(seed_node);
  shuffle.add_input(seed2_node);
  shuffle.add_input(seed_generator_node);

  graph_utils::CopyAttribute(kOutputShapes, *upstream, &shuffle);
  graph_utils::CopyAttribute(kOutputTypes, *upstream, &shuffle);
  (*shuffle.mutable_attr())[kReshuffleEachIteration].set_b(
      reshuffle_each_iteration);

  const NodeDef* inserted = graph->AddNode(std::move(shuffle));
  TF_RETURN_IF_ERROR(graph->UpdateFanouts(upstream->name(), inserted->name()));
  return Status::OK();
}
bool ReaderOpInFunction(const NodeDef& node,
const FunctionLibraryDefinition& flib) {
const FunctionDef* func = flib.Find(node.attr().at("f").func().name());
for (int i = 0; i < func->node_def_size(); i++) {
NodeDef node_in_func = func->node_def(i);
if (IsDatasetNodeOfType(node_in_func, kReaderDatasetOps) &&
node_in_func.input_size() > 0 &&
absl::StartsWith(node_in_func.input(0), "args_0")) {
return true;
}
if (IsDatasetNodeOfType(func->node_def(i), kFuncDatasetOps) &&
ReaderOpInFunction(func->node_def(i), flib)) {
return true;
}
}
return false;
}
// Walks upstream from `node` (over the fanins returned by GetFanins with
// `true`, presumably including control inputs) and bypasses every
// ShuffleDataset it meets:
//   * consumers are rewired to the shuffle's input;
//   * the shuffle's name is added to `nodes_to_delete`.
// The op name, input node names and reshuffle flag of the shuffle visited
// last are written to the out-params. They stay untouched if none is found.
Status RemoveShuffleDataset(MutableGraphView* graph, const NodeDef& node,
                            absl::flat_hash_set<string>* nodes_to_delete,
                            string* op_name, string* buffer_size_node,
                            string* seed_node, string* seed2_node,
                            bool* reshuffle_each_iteration) {
  const bool is_shuffle = node.op() == kShuffleDatasetOpName;
  if (is_shuffle) {
    // Remember the shuffle's configuration so the caller can re-insert it.
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    // Bypass it: whatever read from the shuffle now reads from its input.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  const auto upstream_nodes = graph->GetFanins(node, true);
  for (const auto& upstream : upstream_nodes) {
    TF_RETURN_IF_ERROR(RemoveShuffleDataset(
        graph, *upstream.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, reshuffle_each_iteration));
  }
  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}
// Walks upstream from `node` (over the fanins returned by GetFanins with
// `true`) and bypasses every ShuffleDatasetV2 it meets:
//   * consumers are rewired to its input;
//   * its name is added to `nodes_to_delete`.
// The op name and the buffer-size / seed-generator input names of the shuffle
// visited last are written to the out-params.
Status RemoveShuffleDatasetV2(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_generator_node) {
  const bool is_shuffle = node.op() == kShuffleDatasetV2OpName;
  if (is_shuffle) {
    // Remember the shuffle's configuration so the caller can re-insert it.
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_generator_node = node.input(2);
    // Bypass it: whatever read from the shuffle now reads from its input.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  const auto upstream_nodes = graph->GetFanins(node, true);
  for (const auto& upstream : upstream_nodes) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV2(graph, *upstream.node,
                                              nodes_to_delete, op_name,
                                              buffer_size_node,
                                              seed_generator_node));
  }
  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}
// Walks upstream from `node` (over the fanins returned by GetFanins with
// `true`) and bypasses every ShuffleDatasetV3 it meets:
//   * consumers are rewired to its input;
//   * its name is added to `nodes_to_delete`.
// The op name, input node names and reshuffle flag of the shuffle visited
// last are written to the out-params.
Status RemoveShuffleDatasetV3(MutableGraphView* graph, const NodeDef& node,
                              absl::flat_hash_set<string>* nodes_to_delete,
                              string* op_name, string* buffer_size_node,
                              string* seed_node, string* seed2_node,
                              string* seed_generator_node,
                              bool* reshuffle_each_iteration) {
  const bool is_shuffle = node.op() == kShuffleDatasetV3OpName;
  if (is_shuffle) {
    // Remember the shuffle's configuration so the caller can re-insert it.
    *op_name = node.op();
    *buffer_size_node = node.input(1);
    *seed_node = node.input(2);
    *seed2_node = node.input(3);
    *seed_generator_node = node.input(4);
    *reshuffle_each_iteration = node.attr().at(kReshuffleEachIteration).b();
    // Bypass it: whatever read from the shuffle now reads from its input.
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    nodes_to_delete->insert(node.name());
  }

  const auto upstream_nodes = graph->GetFanins(node, true);
  for (const auto& upstream : upstream_nodes) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, *upstream.node, nodes_to_delete, op_name, buffer_size_node,
        seed_node, seed2_node, seed_generator_node, reshuffle_each_iteration));
  }
  // TODO(frankchn): Traverse functions too.
  return Status::OK();
}
// Shards the input of `node` by inserting ShardDataset(num_workers, index)
// before it. Then looks upstream for a shuffle, trying ShuffleDataset, then
// ShuffleDatasetV2, then ShuffleDatasetV3 (a later kind is only searched for
// if no earlier kind was found). Shuffles of that kind are bypassed and
// scheduled for deletion in `nodes_to_delete`, and one shuffle with the
// recorded configuration is re-inserted directly in front of `node`, i.e.
// after the new shard.
Status ProcessDatasetSourceNode(MutableGraphView* graph, const NodeDef& node,
                                absl::flat_hash_set<string>* nodes_to_delete,
                                int64 num_workers, int64 index) {
  string shuffle_op_name;
  string buffer_size_node;
  string seed_node;
  string seed2_node;
  string seed_generator_node;
  // Only written when a V1/V3 shuffle is found. Initialized so the local is
  // never left indeterminate.
  bool reshuffle_each_iteration = false;

  TF_RETURN_IF_ERROR(AddShardNode(graph, node, num_workers, index));

  TF_RETURN_IF_ERROR(RemoveShuffleDataset(
      graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
      &seed_node, &seed2_node, &reshuffle_each_iteration));
  if (shuffle_op_name.empty()) {
    TF_RETURN_IF_ERROR(
        RemoveShuffleDatasetV2(graph, node, nodes_to_delete, &shuffle_op_name,
                               &buffer_size_node, &seed_generator_node));
  }
  if (shuffle_op_name.empty()) {
    TF_RETURN_IF_ERROR(RemoveShuffleDatasetV3(
        graph, node, nodes_to_delete, &shuffle_op_name, &buffer_size_node,
        &seed_node, &seed2_node, &seed_generator_node,
        &reshuffle_each_iteration));
  }

  if (shuffle_op_name == kShuffleDatasetOpName) {
    TF_RETURN_IF_ERROR(AddShuffleDataset(graph, node, buffer_size_node,
                                         seed_node, seed2_node,
                                         reshuffle_each_iteration));
  } else if (shuffle_op_name == kShuffleDatasetV2OpName) {
    TF_RETURN_IF_ERROR(AddShuffleDatasetV2(graph, node, buffer_size_node,
                                           seed_generator_node));
  } else if (shuffle_op_name == kShuffleDatasetV3OpName) {
    TF_RETURN_IF_ERROR(AddShuffleDatasetV3(
        graph, node, buffer_size_node, seed_node, seed2_node,
        seed_generator_node, reshuffle_each_iteration));
  }
  return Status::OK();
}
// Walks upstream from `node` through pass-through datasets (kPassThroughOps).
// It looks for a function dataset (kFuncDatasetOps, e.g. FlatMapDataset) whose
// input(0) is a TensorSliceDataset or TensorDataset that is itself fed by a
// Placeholder.
//
// Returns that function-dataset node, or nullptr when:
//   * the walk reaches any other op;
//   * `node` is null;
//   * a needed input cannot be resolved.
//
// `num_workers`, `index`, `flib` and `nodes_to_delete` are currently unused.
// They are only forwarded through the recursion.
const NodeDef* FindFuncAndTensorSliceDataset(
    const NodeDef* node, int64 num_workers, int64 index,
    FunctionLibraryDefinition* flib, MutableGraphView* graph,
    absl::flat_hash_set<string>* nodes_to_delete) {
  // GetInputNode returns nullptr for unresolvable inputs; stop the walk there.
  if (node == nullptr) return nullptr;

  if (IsDatasetNodeOfType(*node, kFuncDatasetOps)) {
    const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
    if (input_node != nullptr &&
        (input_node->op() == kTensorSliceDatasetOpName ||
         input_node->op() == kTensorDatasetOpName)) {
      const NodeDef* next_input_node =
          graph_utils::GetInputNode(*input_node, *graph, 0);
      if (next_input_node != nullptr &&
          next_input_node->op() == kPlaceholderOpName) {
        return node;
      }
    }
  }

  if (!IsDatasetNodeOfType(*node, kPassThroughOps)) {
    return nullptr;
  }

  // Sometimes there are other nodes between the last InterleaveDataset and the
  // second to last FlatMapDataset, so we need to skip over those.
  const NodeDef* input_node = graph_utils::GetInputNode(*node, *graph, 0);
  return FindFuncAndTensorSliceDataset(input_node, num_workers, index, flib,
                                       graph, nodes_to_delete);
}
// Walks the dataset graph upstream from `node` and applies FILE sharding to
// the sources it reaches:
//   * AssertCardinalityDataset is bypassed and scheduled for deletion.
//   * An unshardable source (kUnshardableSourceDatasetOps) yields NotFound.
//   * Multi-input datasets (kMultipleInputsDatasetOps) are handled input by
//     input; control inputs are skipped.
//   * For a function dataset fed (through pass-through ops) by a FlatMap-style
//     dataset over a Placeholder-backed Tensor(Slice)Dataset, a shard is
//     inserted right after that FlatMap-style dataset. See the pipeline
//     example below.
//   * A function dataset whose function reads files itself, and a reader
//     dataset, are sharded on their input(0).
//   * Known pass-through datasets are walked through via input(0).
// Anything else, including an input that cannot be resolved to a node, yields
// NotFound, which lets the AUTO policy fall back to DATA sharding.
Status RecursivelyHandleOp(const NodeDef& node, int64 num_workers, int64 index,
                           FunctionLibraryDefinition* flib,
                           MutableGraphView* graph,
                           absl::flat_hash_set<string>* nodes_to_delete) {
  // GetInputNode returns nullptr when an input slot does not resolve to a
  // node. Report that as NotFound instead of dereferencing it.
  auto unresolved_input = [&node](int i) {
    return errors::NotFound("Did not find a shardable source: input ", i,
                            " of node ", node.name(), " (", node.op(),
                            ") could not be resolved.");
  };

  if (node.op() == kAssertCardinalityDatasetOpName) {
    LOG(WARNING) << "The `assert_cardinality` transformation is currently not "
                    "handled by the auto-shard rewrite and will be removed.";
    nodes_to_delete->insert(node.name());
    TF_RETURN_IF_ERROR(graph->UpdateFanouts(node.name(), node.input(0)));
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    if (input_node == nullptr) return unresolved_input(0);
    return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                               nodes_to_delete);
  }

  if (IsDatasetNodeOfType(node, kUnshardableSourceDatasetOps)) {
    return errors::NotFound("Found an unshardable source dataset: ",
                            node.DebugString());
  }

  if (IsDatasetNodeOfType(node, kMultipleInputsDatasetOps)) {
    for (int i = 0; i < node.input_size(); ++i) {
      // Control dependencies ("^name") carry no data and have no regular
      // fanin to follow.
      if (absl::StartsWith(node.input(i), "^")) continue;
      const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, i);
      if (input_node == nullptr) return unresolved_input(i);
      TF_RETURN_IF_ERROR(RecursivelyHandleOp(*input_node, num_workers, index,
                                             flib, graph, nodes_to_delete));
    }
    return Status::OK();
  }

  const bool is_func_dataset = IsDatasetNodeOfType(node, kFuncDatasetOps);

  // This handles the case for the following subgraph:
  //   Placeholder -> TensorSliceDataset -> FlatMapDataset -x->
  //   (other preprocessing datasets) -> InterleaveDataset
  // and then inserting the shard node immediately after the FlatMapDataset.
  //
  // This is used for some training pipelines where a dataset is created with
  // the following code:
  //
  // def make_dataset_pipeline():
  //   file_globs = [...]
  //   datasets = []
  //   for file_glob in file_globs:
  //     datasets.append(Dataset.list_files(file_glob).map(TFRecordReader))
  //   dataset = Dataset.from_tensor_slices(datasets)
  //   dataset = dataset.flat_map(lambda x: x)
  //   dataset = ...  # additional preprocessing
  //   dataset = dataset.interleave(lambda x: x, cycle_length=...)
  //   return dataset
  if (is_func_dataset) {
    const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
    const NodeDef* flat_map_node =
        input_node == nullptr
            ? nullptr
            : FindFuncAndTensorSliceDataset(input_node, num_workers, index,
                                            flib, graph, nodes_to_delete);
    if (flat_map_node != nullptr) {
      auto fanouts = graph->GetFanouts(*flat_map_node, false);
      // FlatMapDataset should only be the input to one other dataset.
      if (fanouts.size() == 1) {
        return ProcessDatasetSourceNode(graph, *fanouts.begin()->node,
                                        nodes_to_delete, num_workers, index);
      }
    }
  }

  // This handles the case where a reader Dataset is contained within a
  // FuncDataset (e.g. FlatMap, ParallelInterleave, etc...). For example:
  //
  //   dataset = Dataset.list_files("/path/to/data")
  //   dataset = dataset.flat_map(core_readers.TFRecordDataset)
  //
  // where the list of files is passed in one-by-one as an argument to the
  // function in flat_map.
  if (is_func_dataset && ReaderOpInFunction(node, *flib)) {
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (IsDatasetNodeOfType(node, kReaderDatasetOps)) {
    // We reached a reader dataset directly and we try to shard input 0.
    return ProcessDatasetSourceNode(graph, node, nodes_to_delete, num_workers,
                                    index);
  }

  if (!IsDatasetNodeOfType(node, kPassThroughOps)) {
    return errors::NotFound(
        "Did not find a shardable source, walked to ",
        "a node which is not a dataset: ", node.DebugString(),
        ". Consider either turning off auto-sharding or switching the "
        "auto_shard_policy to DATA to shard this dataset. You can do this by "
        "creating a new `tf.data.Options()` object then setting "
        "`options.experimental_distribute.auto_shard_policy = "
        "AutoShardPolicy.DATA` before applying the options object to the "
        "dataset via `dataset.with_options(options)`.");
  }

  const NodeDef* input_node = graph_utils::GetInputNode(node, *graph, 0);
  if (input_node == nullptr) return unresolved_input(0);
  return RecursivelyHandleOp(*input_node, num_workers, index, flib, graph,
                             nodes_to_delete);
}
// FILE sharding. Starting at `sink_node`, walks the dataset graph upstream and
// inserts a ShardDataset(num_workers, index) in front of the input of each
// reader dataset it reaches (see kReaderDatasetOps, e.g. TFRecordDataset,
// TextLineDataset), so that each worker reads only a subset of the files.
// ShuffleDatasets upstream of an inserted shard are removed, and a single
// shuffle with the recorded configuration is re-inserted after the shard so
// that sharding is not applied to a shuffled file order.
// AssertCardinalityDatasets on the path are removed as well. Nodes scheduled
// for removal during the walk are deleted only if the whole walk succeeds.
Status ShardByFile(const NodeDef& sink_node, int64 num_workers, int64 index,
                   FunctionLibraryDefinition* flib, MutableGraphView* graph) {
  absl::flat_hash_set<string> pending_deletions;
  TF_RETURN_IF_ERROR(RecursivelyHandleOp(sink_node, num_workers, index, flib,
                                         graph, &pending_deletions));
  return graph->DeleteNodes(pending_deletions);
}
// If the node feeding `sink_node` is a RebatchDatasetV2, rewrites it in place
// into a legacy RebatchDataset:
//   * inputs become (input_dataset, num_replicas);
//   * `use_fallback` is set to true;
//   * the leading (batch) dimension of every known-rank, non-scalar output
//     shape is set to -1 (unknown).
// Any other input node is left untouched.
//
// All preconditions are checked before the node is modified, so an error
// leaves the graph unchanged.
Status RewriteRebatchV2ToV1(const NodeDef& sink_node, int64 num_replicas,
                            MutableGraphView* graph) {
  // The final node before AutoShardDataset is RebatchDataset.
  // This is always the case as RebatchDataset and AutoShardDataset are internal
  // APIs used directly by tf.distribute's input_lib. As such, instead of
  // walking the entire dataset graph, we can walk up directly from the
  // sink_node to get the RebatchDataset.
  NodeDef* rebatch_node = graph_utils::GetInputNode(sink_node, *graph);
  if (rebatch_node == nullptr ||
      rebatch_node->op() != kRebatchDatasetV2OpName) {
    return Status::OK();
  }

  // --- Validate everything first. ---
  if (num_replicas < 1) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 to legacy RebatchDataset with invalid "
        "num_replicas argument. `num_replicas` is ",
        num_replicas, ", but expected to be >= 1.");
  }
  // Inputs 1 and 2 (`batch_sizes`, `drop_remainder`) are removed below.
  if (rebatch_node->input_size() < 3) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 with ", rebatch_node->input_size(),
        " inputs; expected at least 3 (input_dataset, batch_sizes, "
        "drop_remainder).");
  }
  if (rebatch_node->attr().count("output_shapes") == 0) {
    return errors::InvalidArgument(
        "Cannot rewrite RebatchDatasetV2 with missing `output_shapes` attr.");
  }

  // --- Rewrite in place. Since Rebatch is an internal API, no other nodes
  // should have it as an input. ---
  rebatch_node->set_op(kRebatchDatasetOpName);
  // Delete the `batch_sizes` and `drop_remainder` input.
  rebatch_node->mutable_input()->DeleteSubrange(/*start=*/1, /*num=*/2);
  // Add the `num_replicas` input.
  auto num_replicas_node = graph_utils::AddScalarConstNode(num_replicas, graph);
  rebatch_node->add_input(num_replicas_node->name());
  // Set `use_fallback` attr. This attr is not used anywhere, so its value
  // does not matter.
  (*rebatch_node->mutable_attr())["use_fallback"].set_b(true);

  // Update the output_shapes attr to set all its batch dimensions to -1
  // (unknown). Looked up after the insertion above; presence was verified.
  auto* shapes_attr =
      gtl::FindOrNull(*rebatch_node->mutable_attr(), "output_shapes");
  for (int i = 0; i < shapes_attr->list().shape_size(); ++i) {
    auto* shape = shapes_attr->mutable_list()->mutable_shape(i);
    // Scalars have no leading dimension to relax.
    if (shape->unknown_rank() || shape->dim_size() == 0) continue;
    shape->mutable_dim(0)->set_size(-1);
  }
  return Status::OK();
}
// DATA sharding. Inserts a ShardDataset(num_workers, index) near the end of
// the pipeline. Trailing PrefetchDataset / OptionsDataset / FinalizeDataset
// nodes are skipped so the shard sits before them, and a RebatchDatasetV2
// directly feeding the shard point is first rewritten to the legacy
// RebatchDataset.
//
// Returns InvalidArgument if the chain ends at a node with no resolvable
// input, because there is then nothing to shard.
Status ShardByData(const NodeDef& sink_node, int64 num_workers, int64 index,
                   int64 num_replicas, MutableGraphView* graph) {
  const NodeDef* shard_before = &sink_node;
  // We sometimes insert a PrefetchDataset, OptionsDataset, and FinalizeDataset
  // at the end of the input pipeline before autosharding. When sharding by
  // data, we should insert the shard before these datasets so that the right
  // number of elements is prefetched.
  NodeDef* input_node = graph_utils::GetInputNode(sink_node, *graph);
  while (input_node != nullptr &&
         (input_node->op() == kPrefetchDatasetOpName ||
          input_node->op() == kOptionsDatasetOpName ||
          input_node->op() == kFinalizeDatasetOpName)) {
    shard_before = input_node;
    input_node = graph_utils::GetInputNode(*input_node, *graph);
  }
  if (input_node == nullptr) {
    return errors::InvalidArgument("Unable to shard by data: node ",
                                   shard_before->name(),
                                   " has no input dataset to shard.");
  }

  // Sharding by data only works with legacy RebatchDataset. As such, we rewrite
  // all instances of RebatchDatasetV2 to RebatchDataset.
  TF_RETURN_IF_ERROR(RewriteRebatchV2ToV1(*shard_before, num_replicas, graph));
  return AddShardNode(graph, *shard_before, num_workers, index);
}
// HINT sharding. Finds every ShardDataset whose `num_shards` input is a Const
// holding tensorflow::data::kShardHint and rewires its `num_shards` / `index`
// inputs to new Const nodes holding `num_workers` / `index`. Other
// ShardDatasets are left untouched, and if no hint shard exists the graph is
// not modified at all.
//
// `sink_node` and `num_replicas` are unused. They keep the signature parallel
// to the other Shard* functions.
Status ShardByHint(const NodeDef& sink_node, int64 num_workers, int64 index,
                   int64 num_replicas, MutableGraphView* graph) {
  // True iff `node` is ShardDataset(<Const kShardHint>, ...). Every lookup is
  // checked so that unexpected inputs simply don't match instead of crashing.
  auto is_hint_shard = [graph](const NodeDef& node) {
    if (node.op() != kShardDatasetOpName || node.input_size() < 3) {
      return false;
    }
    const NodeDef* num_shards_node = graph->GetNode(node.input(1));
    if (num_shards_node == nullptr || num_shards_node->op() != kConstOpName) {
      return false;
    }
    const auto value_it = num_shards_node->attr().find("value");
    if (value_it == num_shards_node->attr().end()) return false;
    const auto& tensor = value_it->second.tensor();
    return tensor.int64_val_size() > 0 &&
           tensor.int64_val(0) == tensorflow::data::kShardHint;
  };

  // Collect matches first. Adding the replacement Const nodes appends to the
  // graph's node list, which must not happen while iterating over it.
  std::vector<NodeDef*> hint_shards;
  for (NodeDef& node : *graph->graph()->mutable_node()) {
    if (is_hint_shard(node)) hint_shards.push_back(&node);
  }
  if (hint_shards.empty()) return Status::OK();

  NodeDef* num_workers_node =
      graph_utils::AddScalarConstNode<int64>(num_workers, graph);
  NodeDef* worker_index_node =
      graph_utils::AddScalarConstNode<int64>(index, graph);
  for (NodeDef* shard : hint_shards) {
    *shard->mutable_input(1) = num_workers_node->name();
    *shard->mutable_input(2) = worker_index_node->name();
  }
  return Status::OK();
}
// Applies `policy` to a copy of `item.graph` written into `output`.
// OFF, and FILE with a single worker at index 0, are no-ops that leave
// `output` untouched. AUTO first tries FILE sharding and, if that fails with
// NotFound, applies DATA sharding to a fresh copy of the original graph.
Status OptimizeGraph(const GrapplerItem& item, int64 num_workers, int64 index,
                     AutoShardPolicy policy, int64 num_replicas,
                     GraphDef* output) {
  if (policy == AutoShardPolicy::OFF ||
      (policy == AutoShardPolicy::FILE && num_workers == 1 && index == 0)) {
    return Status::OK();
  }

  *output = item.graph;
  MutableGraphView graph(output);
  FunctionLibraryDefinition flib(OpRegistry::Global(), item.graph.library());

  NodeDef* sink_node;
  TF_RETURN_IF_ERROR(graph_utils::GetFetchNode(graph, item, &sink_node));

  switch (policy) {
    case AutoShardPolicy::OFF:
      return Status::OK();
    case AutoShardPolicy::FILE:
      return ShardByFile(*sink_node, num_workers, index, &flib, &graph);
    case AutoShardPolicy::DATA:
      return ShardByData(*sink_node, num_workers, index, num_replicas, &graph);
    case AutoShardPolicy::HINT:
      return ShardByHint(*sink_node, num_workers, index, num_replicas, &graph);
    case AutoShardPolicy::AUTO:
    default: {
      Status s = ShardByFile(*sink_node, num_workers, index, &flib, &graph);
      if (!errors::IsNotFound(s)) return s;

      LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
                      "as it failed to apply FILE sharding policy because of "
                      "the following reason: "
                   << s.error_message();

      // FILE sharding may have rewritten part of the graph before failing.
      // Examples: a ShardDataset inserted into one input of a
      // Concatenate/Zip before an unshardable source was hit in another
      // input, or a bypassed AssertCardinalityDataset. Shard by data from the
      // original graph so that no input ends up sharded twice.
      *output = item.graph;
      MutableGraphView fresh_graph(output);
      NodeDef* fresh_sink_node;
      TF_RETURN_IF_ERROR(
          graph_utils::GetFetchNode(fresh_graph, item, &fresh_sink_node));
      return ShardByData(*fresh_sink_node, num_workers, index, num_replicas,
                         &fresh_graph);
    }
  }
}
} // anonymous namespace
// Reads and validates the optimizer parameters:
//   * num_workers >= 1;
//   * 0 <= index < num_workers;
//   * auto_shard_policy is one of OFF/AUTO/DATA/FILE/HINT;
//   * num_replicas >= 0.
// Returns InvalidArgument for a missing config, a missing parameter, or an
// out-of-range value.
Status AutoShard::Init(
    const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
  if (!config) return errors::InvalidArgument("RewriterConfig not found.");
  const auto& params = config->parameter_map();

  // Every parameter below is read with Map::at(), which CHECK-fails on a
  // missing key. Report any missing one as InvalidArgument instead.
  const char* const kRequiredParams[] = {kNumWorkersAttrName, kIndexAttrName,
                                         kAutoShardPolicyAttrName,
                                         kNumReplicasAttrName};
  for (const char* name : kRequiredParams) {
    if (params.find(name) == params.end()) {
      return errors::InvalidArgument(name, " parameter missing.");
    }
  }

  num_workers_ = params.at(kNumWorkersAttrName).i();
  index_ = params.at(kIndexAttrName).i();
  auto_shard_policy_ =
      AutoShardPolicy(params.at(kAutoShardPolicyAttrName).i());
  num_replicas_ = params.at(kNumReplicasAttrName).i();

  if (auto_shard_policy_ != AutoShardPolicy::OFF &&
      auto_shard_policy_ != AutoShardPolicy::AUTO &&
      auto_shard_policy_ != AutoShardPolicy::DATA &&
      auto_shard_policy_ != AutoShardPolicy::FILE &&
      auto_shard_policy_ != AutoShardPolicy::HINT) {
    return errors::InvalidArgument(kAutoShardPolicyAttrName, " is invalid.");
  }
  if (num_workers_ < 1) {
    return errors::InvalidArgument(kNumWorkersAttrName,
                                   " should be >= 1, currently ", num_workers_);
  }
  if (index_ < 0 || index_ >= num_workers_) {
    return errors::InvalidArgument(kIndexAttrName, " should be >= 0 and < ",
                                   num_workers_, ", currently ", index_);
  }
  if (num_replicas_ < 0) {
    return errors::InvalidArgument(kNumReplicasAttrName, " should be >= 0");
  }
  return Status::OK();
}
// Grappler entry point. Writes the (possibly) sharded graph to `output` and
// bumps `stats->num_changes` once on success. `output` is first seeded with
// the input graph because OptimizeGraph returns early without writing it for
// no-op policies.
Status AutoShard::OptimizeAndCollectStats(Cluster* /* cluster */,
                                          const GrapplerItem& item,
                                          GraphDef* output,
                                          OptimizationStats* stats) {
  *output = item.graph;
  TF_RETURN_IF_ERROR(OptimizeGraph(item, num_workers_, index_,
                                   auto_shard_policy_, num_replicas_,
                                   output));
  ++stats->num_changes;
  return Status::OK();
}
// Registers AutoShard as a custom grappler optimizer named "tf_auto_shard".
REGISTER_GRAPH_OPTIMIZER_AS(AutoShard, "tf_auto_shard");
} // namespace grappler
} // namespace tensorflow