-
Notifications
You must be signed in to change notification settings - Fork 74k
/
xla_compiler_test.cc
2100 lines (1820 loc) · 83.9 KB
/
xla_compiler_test.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_compiler.h"

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/functional_ops.h"
#include "tensorflow/cc/ops/list_ops.h"
#include "tensorflow/cc/ops/math_ops.h"
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "xla/client/client_library.h"
#include "xla/client/local_client.h"
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_proto_util.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
#include "tsl/platform/statusor.h"
namespace tensorflow {
// Fixture shared by the XlaCompiler tests in this file. SetUp() obtains the
// local XLA client and builds an empty function library; DefaultOptions()
// returns compiler options that target DEVICE_CPU_XLA_JIT using both of them.
class XlaCompilerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    client_ = xla::ClientLibrary::LocalClientOrDie();
    XlaOpRegistry::RegisterCompilationKernels();
    // Start every test from an empty function library.
    FunctionDefLibrary flib;
    flib_def_ =
        std::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), flib);
  }

  // Returns options for an XlaCompiler targeting the CPU XLA JIT device, wired
  // to this fixture's client and function library. The options hold
  // non-owning pointers into the fixture, so they must not outlive it.
  XlaCompiler::Options DefaultOptions() {
    XlaCompiler::Options options;
    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
    options.client = client_;
    options.flib_def = flib_def_.get();
    return options;
  }

  // Exposes the compiler's private local function library to tests.
  // NOTE(review): relies on XlaCompiler granting this fixture access (e.g. a
  // friend declaration) — confirm in xla_compiler.h.
  FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
    return compiler->local_flib_def_.get();
  }

  // Assigned in SetUp(); the fixture never deletes it.
  xla::Client* client_ = nullptr;
  // Owns the function library that DefaultOptions() points at.
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
namespace {
// Helper resource used to test that resources registered with the compiler's
// ResourceMgr can be reached from XLA-compiled kernels. It only carries an
// integer counter that kernels can bump and tests can read back.
class DummyResourceForTest : public ResourceBase {
 public:
  string DebugString() const override { return "dummy"; }
  // Increments the counter by one.
  void Increment() { ++value_; }
  // Returns the current counter value without modifying the resource.
  int Get() const { return value_; }

 private:
  int value_ = 0;
};
class DummyReadResourceOp : public XlaOpKernel {
public:
explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
DummyResourceForTest* dummy;
OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
rm->default_container(), "dummy", &dummy));
dummy->Increment();
dummy->Unref();
ctx->SetOutput(0, ctx->Input(0));
ctx->SetOutput(1, ctx->Input(0));
}
};
class DummyReadResourceCC {
public:
DummyReadResourceCC(const Scope& scope, const Input& value) {
if (!scope.ok()) return;
auto _value = ops::AsNodeOut(scope, value);
if (!scope.ok()) return;
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
if (!scope.ok()) return;
scope.UpdateStatus(scope.DoShapeInference(ret));
if (!scope.ok()) return;
this->output1_ = Output(ret, 0);
this->output2_ = Output(ret, 1);
}
Output output1_;
Output output2_;
};
// Registers the "DummyReadResource" op: one int32 input and two int32 outputs.
// It is exercised by the ResourceManager test via DummyReadResourceCC.
REGISTER_OP("DummyReadResource")
.Input("input: int32")
.Output("output1: int32")
.Output("output2: int32")
// Output shapes are left unknown.
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
A dummy Op.
input: dummy input.
output1: dummy output.
output2: dummy output.
)doc");
// Binds DummyReadResourceOp as the XLA kernel for "DummyReadResource"; no
// device restriction is given in this registration.
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
// DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
// on the same Op name below.
class DummyDuplicateOp : public XlaOpKernel {
public:
explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, ctx->Input(0));
}
};
// Registers the "DummyDuplicateOp" op: one int32 input, one int32 output.
REGISTER_OP("DummyDuplicateOp")
.Input("input: int32")
.Output("output: int32")
.Doc(R"doc(
A dummy Op.
input: dummy input.
output: dummy output.
)doc");
// Registers the same kernel for the same op name twice: once restricted to the
// CPU XLA JIT device and once to the GPU XLA JIT device.
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
DummyDuplicateOp);
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
DummyDuplicateOp);
// Compiles an empty graph (no arguments, no return values) and checks that the
// resulting computation executes successfully.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
  auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult compiled;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(empty_graph),
                                     /*args=*/{}, &compiled));
  TF_ASSERT_OK(client_->Execute(*compiled.computation, {}).status());
}
// Compiles a graph computing C = A + B over two int32[2] parameters, executes
// it on the local client, and checks the sum, which comes back wrapped in a
// one-element tuple.
TEST_F(XlaCompilerTest, Simple) {
  // Graph: _Arg A and _Arg B feed an Add whose output is _Retval 0.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto lhs = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto rhs = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
  auto sum = ops::Add(root.WithOpName("C"), lhs, rhs);
  auto retval = ops::_Retval(root.WithOpName("D"), sum, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Both arguments are plain int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Run the computation: {7, 42} + {-3, 101} == {4, 143}.
  xla::Literal lhs_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal rhs_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> lhs_data =
      client_->TransferToServer(lhs_literal).value();
  std::unique_ptr<xla::GlobalData> rhs_data =
      client_->TransferToServer(rhs_literal).value();
  std::unique_ptr<xla::GlobalData> output_data =
      client_->Execute(*result.computation, {lhs_data.get(), rhs_data.get()})
          .value();
  xla::Literal output_literal = client_->Transfer(*output_data).value();

  xla::Literal expected_sum = xla::LiteralUtil::CreateR1<int32>({4, 143});
  xla::Literal expected_tuple = xla::LiteralUtil::MakeTuple({&expected_sum});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_tuple, output_literal));
}
// Rebuilds an xla::HloModule from `module_proto`. The module config is derived
// from the proto together with the debug options currently set via flags.
absl::StatusOr<std::unique_ptr<xla::HloModule>> LoadModuleFromHloProto(
    const xla::HloModuleProto& module_proto) {
  const auto debug_options = xla::GetDebugOptionsFromFlags();
  TF_ASSIGN_OR_RETURN(
      auto config,
      xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
  return xla::CreateModuleFromProto(module_proto, config);
}
// Compiles C = A + B where argument 0 also carries a value bound and a
// value-dynamism tensor, then checks that parameter 0 of the single HLO
// computation has a dynamic shape.
TEST_F(XlaCompilerTest, SimpleDynamicShapeParameter) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto lhs = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto rhs = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
  auto sum = ops::Add(root.WithOpName("C"), lhs, rhs);
  auto retval = ops::_Retval(root.WithOpName("D"), sum, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Both arguments are int32[2] parameters ...
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }
  // ... and argument 0 additionally gets a value bound and dynamism info.
  args[0].value_bound = Tensor(DT_INT32, std::get<0>(args[0].shape));
  Tensor dynamism(DT_BOOL);
  TF_ASSERT_OK(LiteralToHostTensor(xla::LiteralUtil::CreateR1<bool>({true}),
                                   DT_BOOL, &dynamism));
  args[0].value_dynamism = dynamism;

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  const auto hlo_proto = result.computation->proto();
  TF_ASSERT_OK_AND_ASSIGN(auto module, LoadModuleFromHloProto(hlo_proto));
  EXPECT_EQ(module->computation_count(), 1);
  const xla::Shape& param_shape =
      module->mutable_computation(0)->parameter_instruction(0)->shape();
  EXPECT_TRUE(param_shape.is_dynamic());
}
// Regression test for a bug where the wrong value was returned. The graph's
// _Retval node is created before another node, so it is not last in
// construction order, and always_return_tuple is disabled. The computation must
// still return argument A.
TEST_F(XlaCompilerTest, OutOfOrderGraph) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto arg_a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto arg_b = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
  // Deliberately create the _Retval before the Add node.
  auto retval = ops::_Retval(root.WithOpName("D"), arg_a, 0);
  auto unused_sum = ops::Add(root.WithOpName("C"), arg_a, arg_b);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Both arguments are plain int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  // Without tuple wrapping, the output must be exactly argument A.
  xla::Literal a_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal b_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> a_data =
      client_->TransferToServer(a_literal).value();
  std::unique_ptr<xla::GlobalData> b_data =
      client_->TransferToServer(b_literal).value();
  std::unique_ptr<xla::GlobalData> output_data =
      client_->Execute(*result.computation, {a_data.get(), b_data.get()})
          .value();
  xla::Literal output_literal = client_->Transfer(*output_data).value();
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(a_literal, output_literal));
}
// Checks that the layout chosen by shape_representation_fn ({0, 1}) is applied
// to a resource variable that is returned without ever being written.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForUnwrittenResource) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 0);
  auto retval = ops::_Retval(root.WithOpName("D"), var, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Single argument: an initialized int32[2,3] resource variable.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& resource_arg = args[0];
  resource_arg.kind = XlaCompiler::Argument::kResource;
  resource_arg.resource_kind = XlaResource::kVariable;
  resource_arg.initialized = true;
  resource_arg.type = DT_INT32;
  resource_arg.shape = TensorShape({2, 3});

  // Representation fn that forces a {0, 1} layout on every shape.
  XlaShapeLayoutHelpers::ShapeDeterminationFns shape_fns;
  shape_fns.shape_representation_fn =
      [](const TensorShape& shape, DataType dt, bool use_fast_memory,
         XlaLayoutPreference layout_preference) -> absl::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };
  XlaCompiler::Options options = DefaultOptions();
  options.shape_determination_fns = shape_fns;

  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = true;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));

  // The returned resource must carry the transposed {0, 1} layout.
  const xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::S32, {2, 3}, {0, 1});
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed}));
}
// Checks that the fast_mem flag on a resource argument reaches
// shape_representation_fn as use_fast_memory. The fn should see it once for
// the argument and once for the returned resource value.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 0);
  auto retval = ops::_Retval(root.WithOpName("D"), var, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Single argument: an initialized int32[2,3] resource variable that requests
  // fast memory.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& resource_arg = args[0];
  resource_arg.kind = XlaCompiler::Argument::kResource;
  resource_arg.resource_kind = XlaResource::kVariable;
  resource_arg.initialized = true;
  resource_arg.type = DT_INT32;
  resource_arg.shape = TensorShape({2, 3});
  resource_arg.fast_mem = true;

  // Representation fn that forces a {0, 1} layout and counts how many shapes
  // were requested with use_fast_memory set.
  int fast_mem_arg_count = 0;
  XlaShapeLayoutHelpers::ShapeDeterminationFns shape_fns;
  shape_fns.shape_representation_fn =
      [&fast_mem_arg_count](
          const TensorShape& shape, DataType dt, bool use_fast_memory,
          XlaLayoutPreference layout_preference) -> absl::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    fast_mem_arg_count += use_fast_memory ? 1 : 0;
    return xla_shape;
  };
  XlaCompiler::Options options = DefaultOptions();
  options.shape_determination_fns = shape_fns;

  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  XlaCompiler::CompileOptions compile_options;
  compile_options.return_updated_values_for_all_resources = true;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, &result));
  // One fast-memory request for the argument, one for the return value.
  EXPECT_EQ(fast_mem_arg_count, 2);
}
// Checks that the {0, 1} layout chosen by shape_representation_fn is applied
// to both entries of the output tuple. Those are the _Retval value and,
// presumably, the written-back resource variable. It is also checked on the
// computation's program shape.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Route the resource through an Identity to check that Identity ops
  // propagate resources correctly.
  auto var_identity = ops::Identity(root.WithOpName("VIdentity"), var);
  auto assign_add = ops::AssignAddVariableOp(root, var_identity, a);
  auto read = ops::ReadVariableOp(
      root.WithControlDependencies(std::vector<Operation>{assign_add}), var,
      DT_INT32);
  auto read_plus_one = ops::Add(root, read, ops::Const<int32>(root, 1));
  auto retval = ops::_Retval(root.WithOpName("D"), read_plus_one, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2,3] parameter. Argument 1: initialized int32[2,3]
  // resource variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& param_arg = args[0];
  param_arg.kind = XlaCompiler::Argument::kParameter;
  param_arg.type = DT_INT32;
  param_arg.shape = TensorShape({2, 3});
  XlaCompiler::Argument& resource_arg = args[1];
  resource_arg.kind = XlaCompiler::Argument::kResource;
  resource_arg.resource_kind = XlaResource::kVariable;
  resource_arg.initialized = true;
  resource_arg.type = DT_INT32;
  resource_arg.shape = TensorShape({2, 3});

  // Representation fn that forces a {0, 1} layout on every shape.
  XlaShapeLayoutHelpers::ShapeDeterminationFns shape_fns;
  shape_fns.shape_representation_fn =
      [](const TensorShape& shape, DataType dt, bool use_fast_memory,
         XlaLayoutPreference layout_preference) -> absl::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };
  XlaCompiler::Options options = DefaultOptions();
  options.shape_determination_fns = shape_fns;

  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Both outputs must carry the transposed {0, 1} layout.
  const xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::S32, {2, 3}, {0, 1});
  const xla::Shape expected_outputs =
      xla::ShapeUtil::MakeTupleShape({transposed, transposed});
  EXPECT_EQ(result.xla_output_shape, expected_outputs);
  EXPECT_EQ(result.computation->GetProgramShape().value().result(),
            expected_outputs);
}
// Checks that transposing a value read from a resource variable does not
// change the resource's layout. Both outputs keep the {1, 0} layout.
TEST_F(XlaCompilerTest, TransposeVariables) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Route the resource through an Identity to check that Identity ops
  // propagate resources correctly.
  auto var_identity = ops::Identity(root.WithOpName("VIdentity"), var);
  auto assign_add = ops::AssignAddVariableOp(root, var_identity, a);
  auto read = ops::ReadVariableOp(
      root.WithControlDependencies(std::vector<Operation>{assign_add}), var,
      DT_INT32);
  // Transpose the value that was read, then reshape it back to [2, 3].
  auto transposed_read = ops::Transpose(root, read, {1, 0});
  auto reshaped = ops::Reshape(root, transposed_read, {2, 3});
  auto retval = ops::_Retval(root.WithOpName("D"), reshaped, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2,3] parameter. Argument 1: initialized int32[2,3]
  // resource variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& param_arg = args[0];
  param_arg.kind = XlaCompiler::Argument::kParameter;
  param_arg.type = DT_INT32;
  param_arg.shape = TensorShape({2, 3});
  XlaCompiler::Argument& resource_arg = args[1];
  resource_arg.kind = XlaCompiler::Argument::kResource;
  resource_arg.resource_kind = XlaResource::kVariable;
  resource_arg.initialized = true;
  resource_arg.type = DT_INT32;
  resource_arg.shape = TensorShape({2, 3});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
                                     std::move(graph), args, &result));

  // Both outputs keep the {1, 0} layout.
  const xla::Shape unchanged_layout =
      xla::ShapeUtil::MakeShapeWithDenseLayout(xla::S32, {2, 3}, {1, 0});
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({unchanged_layout, unchanged_layout}));
}
// Compiles a graph whose only output is a FakeParam with an unranked shape and
// checks that it is emitted as an empty s32[0] tensor.
TEST_F(XlaCompilerTest, UnrankedFakeParam) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  // A default-constructed PartialTensorShape has unknown rank.
  PartialTensorShape unknown_rank_shape;
  auto fake_param = ops::FakeParam(scope, DT_INT32, unknown_rank_shape);
  auto retval = ops::_Retval(scope.WithOpName("D"), fake_param, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  // Compiles the graph; it takes no arguments.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "compile",
                                     std::move(graph), {}, &result));
  // The unranked fake parameter must come back as a zero-element s32 vector
  // inside the result tuple.
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape(
                {xla::ShapeUtil::MakeShape(xla::S32, {0})}));
}
// Checks that the compiler keeps parameters in the caller's order. Whether the
// resource variable is argument 1 or argument 0, input_mapping is {0, 1}.
TEST_F(XlaCompilerTest, MixedOrderArguments) {
  for (bool swap_order : {false, true}) {
    // Argument positions of the int32 parameter and the resource variable.
    const int param_index = swap_order ? 1 : 0;
    const int resource_index = swap_order ? 0 : 1;

    Scope root = Scope::NewRootScope().ExitOnError();
    auto var = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, resource_index);
    auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, param_index);
    // Route the resource through an Identity to check that Identity ops
    // propagate resources correctly.
    auto var_identity = ops::Identity(root.WithOpName("VIdentity"), var);
    auto assign_add = ops::AssignAddVariableOp(root, var_identity, a);
    auto read = ops::ReadVariableOp(
        root.WithControlDependencies(std::vector<Operation>{assign_add}), var,
        DT_INT32);
    auto read_plus_one = ops::Add(root, read, ops::Const<int32>(root, 1));
    auto retval = ops::_Retval(root.WithOpName("D"), read_plus_one, 0);
    auto graph = std::make_unique<Graph>(OpRegistry::Global());
    TF_ASSERT_OK(root.ToGraph(graph.get()));

    // Describe each argument at the same position used for its _Arg node.
    std::vector<XlaCompiler::Argument> args(2);
    XlaCompiler::Argument& param_arg = args[param_index];
    param_arg.kind = XlaCompiler::Argument::kParameter;
    param_arg.type = DT_INT32;
    param_arg.shape = TensorShape({2});
    XlaCompiler::Argument& resource_arg = args[resource_index];
    resource_arg.kind = XlaCompiler::Argument::kResource;
    resource_arg.resource_kind = XlaResource::kVariable;
    resource_arg.initialized = true;
    resource_arg.type = DT_INT32;
    resource_arg.shape = TensorShape({2});

    XlaCompiler compiler(DefaultOptions());
    XlaCompiler::CompileOptions compile_options;
    compile_options.always_return_tuple = false;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                       args, &result));
    EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
  }
}
// Checks the error reported when Reshape's shape operand comes from a runtime
// parameter instead of a compile-time constant. Compilation must fail, and the
// message must name node C and explain the compile-time-constant requirement.
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Graph: C = Reshape(A, B), where the target shape B is an _Arg and so is not
  // statically known.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto tensor = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto new_shape = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
  auto reshaped = ops::Reshape(root.WithOpName("C"), tensor, new_shape);
  auto retval = ops::_Retval(root.WithOpName("D"), reshaped, 0);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Both arguments are plain int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  Status status =
      compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape",
                            std::move(graph), args, &result);
  EXPECT_FALSE(status.ok());
  for (const char* expected_fragment :
       {"depends on a parameter", "{{node C}}",
        "must be a compile-time constant"}) {
    EXPECT_TRUE(absl::StrContains(status.message(), expected_fragment))
        << status.message();
  }
}
// Compiles func(a) { b = 7; c = -a; return b, c; }. Checks that both the
// constant output b and the data-dependent output c are reported as
// non-constant outputs, and that executing the computation yields (7, -a).
TEST_F(XlaCompilerTest, ConstantOutputs) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto seven = ops::Const<int32>(root.WithOpName("B"), 7);
  auto neg_a = ops::Neg(root.WithOpName("C"), a);
  auto const_retval = ops::_Retval(root.WithOpName("D"), seven, 0);
  auto neg_retval = ops::_Retval(root.WithOpName("E"), neg_a, 1);
  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Single argument: an int32[2] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  {
    // Compile a copy of the graph; `graph` itself is left untouched.
    auto graph_copy = std::make_unique<Graph>(OpRegistry::Global());
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                       "constants", std::move(graph_copy),
                                       args, &result));
    ASSERT_EQ(2, result.outputs.size());
    for (const auto& output : result.outputs) {
      EXPECT_FALSE(output.is_constant);
    }

    // Run with a = {7, 42}; expect the tuple (7, {-7, -42}).
    xla::Literal a_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
    std::unique_ptr<xla::GlobalData> a_data =
        client_->TransferToServer(a_literal).value();
    std::unique_ptr<xla::GlobalData> output_data =
        client_->Execute(*result.computation, {a_data.get()}).value();
    xla::Literal output_literal = client_->Transfer(*output_data).value();
    xla::Literal expected_const = xla::LiteralUtil::CreateR0<int32>(7);
    xla::Literal expected_neg = xla::LiteralUtil::CreateR1<int32>({-7, -42});
    xla::Literal expected =
        xla::LiteralUtil::MakeTuple({&expected_const, &expected_neg});
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, output_literal));
  }
}
// Compiles a graph that calls the no-inline function
//   foo(a) { b = 7; return b, a; }
// and checks that its data-dependent output (index 1) is not reported as a
// compile-time constant.
TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
  const Tensor seven = test::AsScalar<int>(7);
  FunctionDef fdef = FunctionDefHelper::Create(
      "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
      {
          {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
      },
      {{"a", "a_0"}, {"const", "Const:output:0"}});
  (*fdef.mutable_attr())["_noinline"].set_b(true);
  FunctionDefLibrary fdef_lib;
  *(fdef_lib.add_function()) = fdef;

  auto graph = std::make_unique<Graph>(OpRegistry::Global());
  {
    Scope root = Scope::NewRootScope().ExitOnError();
    TF_EXPECT_OK(root.graph()->AddFunctionLibrary(fdef_lib));
    auto input_arg = ops::_Arg(root.WithOpName("input_arg"), DT_INT32, 0);

    // Adds `node_def` to the scope's graph and returns the resulting status.
    auto add_node = [&root](const NodeDef& node_def) {
      Status add_status;
      root.graph()->AddNode(node_def, &add_status);
      return add_status;
    };

    // Call node: foo(input_arg).
    NodeDef call;
    call.set_name("foo");
    call.set_op("foo");
    *call.add_input() = "input_arg";
    TF_ASSERT_OK(add_node(call));

    // One _Retval per function output: index 0 <- foo:0, index 1 <- foo:1.
    struct RetvalSpec {
      const char* name;
      const char* input;
    };
    const RetvalSpec retvals[] = {{"retval_0", "foo"}, {"retval_1", "foo:1"}};
    for (int index = 0; index < 2; ++index) {
      NodeDef retval;
      retval.set_name(retvals[index].name);
      retval.set_op(FunctionLibraryDefinition::kRetOp);
      *retval.add_input() = retvals[index].input;
      (*retval.mutable_attr())["T"].set_type(DT_INT32);
      (*retval.mutable_attr())["index"].set_i(index);
      TF_ASSERT_OK(add_node(retval));
    }
    TF_ASSERT_OK(root.ToGraph(graph.get()));
  }

  // Single argument: an int32[1] parameter.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({1});

  // The compiler needs `foo`, so it gets a library built from fdef_lib.
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  XlaCompiler::Options options = DefaultOptions();
  options.flib_def = &flib_def;
  XlaCompiler compiler(options);
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                     "constants", std::move(graph), args,
                                     &result));
  ASSERT_EQ(2, result.outputs.size());
  EXPECT_FALSE(result.outputs[1].is_constant);
}
// Tests that a resource registered through
// XlaCompiler::Options::populate_resource_manager is visible while compiling a
// graph that contains the DummyReadResource op: the resource's counter
// (DummyResourceForTest::Get) must be 0 before CompileGraph and 1 after it.
TEST_F(XlaCompilerTest, ResourceManager) {
  // Builds a graph that calls the dummy resource Op.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
  auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  // Builds a description of the argument.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  DummyResourceForTest* resource = new DummyResourceForTest();
  // Drops the test's own reference on every exit path. This includes the
  // early return taken by a failing TF_ASSERT_OK below, which previously
  // leaked the resource.
  core::ScopedUnref resource_unref(resource);
  // Compiles the graph.
  auto options = DefaultOptions();
  std::function<Status(ResourceMgr*)> populate_function =
      [resource](ResourceMgr* rm) {
        // The ResourceMgr takes ownership of one reference, so add one to
        // keep the test's own reference valid.
        resource->Ref();
        return rm->Create(rm->default_container(), "dummy", resource);
      };
  options.populate_resource_manager = &populate_function;
  XlaCompiler compiler(options);
  EXPECT_EQ(0, resource->Get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                     std::move(graph), args, &result));
  EXPECT_EQ(1, resource->Get());
}
// Tests that compiling the same graph repeatedly yields HLO modules whose
// computations and instructions match field-for-field. Only names, unique
// ids and operand ids are ignored.
TEST_F(XlaCompilerTest, DeterministicCompilation) {
  // Builds a graph that contains a node with two output edges. The compiler
  // should always traverse them in the same order.
  const int64_t test_count = 2;
  std::vector<XlaCompiler::CompilationResult> results(test_count);
  for (int64_t i = 0; i < test_count; ++i) {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Neg(scope.WithOpName("B"), a);
    auto c = ops::Neg(scope.WithOpName("C"), a);
    auto d = ops::Add(scope.WithOpName("D"), b, c);
    auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(scope.ToGraph(graph.get()));
    // Builds a description of the argument.
    std::vector<XlaCompiler::Argument> args(1);
    args[0].kind = XlaCompiler::Argument::kParameter;
    args[0].type = DT_INT32;
    args[0].shape = TensorShape({2});
    // Compiles the graph.
    auto options = DefaultOptions();
    XlaCompiler compiler(options);
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                       std::move(graph), args, &results[i]));
  }
  for (int64_t i = 1; i < test_count; ++i) {
    const auto& m1 = results[i - 1].computation->proto();
    const auto& m2 = results[i].computation->proto();
    ASSERT_EQ(m1.computations_size(), m2.computations_size());
    // Check if every hlo computation is the same.
    for (int k = 0; k < m1.computations_size(); k++) {
      const auto& c1 = m1.computations(k);
      const auto& c2 = m2.computations(k);
      ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
      for (int j = 0; j < c1.instructions_size(); j++) {
        // The names of instructions were uniquified by the XlaBuilder and the
        // unique ids may be different, so those fields (and the operand ids
        // that refer to them) are cleared. The rest of the fields should be
        // identical.
        auto instr1 = c1.instructions(j);
        auto instr2 = c2.instructions(j);
        instr1.clear_name();
        instr1.clear_id();
        instr1.clear_operand_ids();
        instr2.clear_name();
        instr2.clear_id();
        instr2.clear_operand_ids();
        string str1, str2;
        instr1.AppendPartialToString(&str1);
        instr2.AppendPartialToString(&str2);
        // Debug strings are emitted only on mismatch instead of logging
        // every instruction on every run.
        EXPECT_EQ(str1, str2) << "instr1 = " << instr1.DebugString()
                              << "\ninstr2 = " << instr2.DebugString();
      }
    }
  }
}
// Verifies that a TensorArray resource can flow into a computation and come
// back out updated.
//
// The argument carries an existing "grad2" gradient. The graph also creates
// "grad1", writes the value 1 at index 1 of "grad1", and reads element 1 of
// the base array.
TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto ta_handle = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto initial_flow = ops::Const<float>(root, {});
  auto first_grad =
      ops::TensorArrayGrad(root, ta_handle, initial_flow, "grad1");
  auto second_grad =
      ops::TensorArrayGrad(root, ta_handle, first_grad.flow_out, "grad2");
  auto one = ops::Const<int32>(root, 1);
  auto write_op = ops::TensorArrayWrite(root, first_grad.grad_handle, one, one,
                                        second_grad.flow_out);
  auto read_op =
      ops::TensorArrayRead(root, ta_handle, one, write_op.flow_out, DT_INT32);
  auto ret = ops::_Retval(root.WithOpName("retval"), read_op, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Single argument: an initialized int32 TensorArray of two scalars that
  // already owns a "grad2" gradient.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& ta_arg = args[0];
  ta_arg.kind = XlaCompiler::Argument::kResource;
  ta_arg.resource_kind = XlaResource::kTensorArray;
  ta_arg.initialized = true;
  ta_arg.type = DT_INT32;
  ta_arg.shape = TensorShape({});
  ta_arg.max_array_size = 2;
  ta_arg.tensor_array_gradients = {"grad2"};

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));

  // Exactly one resource update is expected. It must touch both gradients.
  ASSERT_EQ(1, result.resource_updates.size());
  const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
  EXPECT_EQ(0, update.input_index);
  EXPECT_EQ(DT_INT32, update.type);
  EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
            update.tensor_array_gradients_accessed);

  // Runs the computation on the input tuple (base, grad2).
  xla::Literal base_in = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal grad2_in = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal input_tuple = xla::LiteralUtil::MakeTuple({&base_in, &grad2_in});
  std::unique_ptr<xla::GlobalData> input_handle =
      client_->TransferToServer(input_tuple).value();
  std::unique_ptr<xla::GlobalData> output_handle =
      client_->Execute(*result.computation, {input_handle.get()}).value();
  xla::Literal actual = client_->Transfer(*output_handle).value();

  // Expected output: (value read, updated resource (base, grad1, grad2)).
  xla::Literal read_out = xla::LiteralUtil::CreateR0<int32>(42);
  xla::Literal base_out = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal grad1_out = xla::LiteralUtil::CreateR1<int32>({0, 1});
  xla::Literal grad2_out = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal resource_out =
      xla::LiteralUtil::MakeTuple({&base_out, &grad1_out, &grad2_out});
  xla::Literal expected =
      xla::LiteralUtil::MakeTuple({&read_out, &resource_out});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual));
}
// Tests that looking up a TensorArray gradient the argument already carries
// ("grad1") produces no resource updates when nothing is written to the
// TensorArray or its gradient. Such a gradient is therefore not a computation
// output.
TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
  auto index = ops::Const<int32>(scope, 1);
  // Only a read follows; the gradient's flow_out merely sequences it.
  auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  // Builds a description of the arguments. "grad1" already exists on input.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};
  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  EXPECT_EQ(0, result.resource_updates.size());
}
// Tests that creating a TensorArray gradient the argument did not already
// carry produces one resource update, even though nothing is written to it.
// The argument only has "grad1"; the graph creates "grad2". The new gradient
// therefore makes the TensorArray a computation output.
TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // Named after the gradient it creates ("grad2"), not the one passed in.
  auto grad2 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad2.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));
  // Builds a description of the arguments. Only "grad1" exists on input.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};
  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args, &result));
  EXPECT_EQ(1, result.resource_updates.size());
}
// Tests that CompileFunction on a function name absent from the function
// library fails with an "is not defined." error.
TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("Function_NotDefined_");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);
  EXPECT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.message(), "is not defined."))
      << status.message();
}