-
Notifications
You must be signed in to change notification settings - Fork 74k
/
xla_op_kernel.cc
833 lines (734 loc) · 30.8 KB
/
xla_op_kernel.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include <numeric>
#include "absl/memory/memory.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "xla/client/value_inference.h"
#include "xla/client/xla_builder.h"
#include "xla/client/xla_computation.h"
#include "xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/overflow.h"
namespace tensorflow {
// Wraps `context`. Initializes dynamic_dimension_is_minus_one_ to false and
// constructs value_inference_ from the builder returned by xla_context().
XlaOpKernelContext::XlaOpKernelContext(OpKernelContext* context)
: context_(context),
dynamic_dimension_is_minus_one_(false),
value_inference_(xla_context()->builder()) {}
// Forwards to the wrapped OpKernelContext's ValidateInputsAreSameShape for
// `op` and returns its result unchanged.
bool XlaOpKernelContext::ValidateInputsAreSameShape(OpKernel* op) {
  const bool inputs_match = context_->ValidateInputsAreSameShape(op);
  return inputs_match;
}
// Returns the XlaContext obtained from the wrapped OpKernelContext via
// XlaContext::Get.
XlaContext* XlaOpKernelContext::xla_context() const {
  auto& owning_context = XlaContext::Get(context_);
  return &owning_context;
}
// Returns the XlaBuilder held by this kernel's XlaContext.
xla::XlaBuilder* XlaOpKernelContext::builder() const {
  XlaContext* const owning_context = xla_context();
  return owning_context->builder();
}
// Returns a mutable reference to the ValueInference member owned by this
// context.
xla::ValueInference& XlaOpKernelContext::value_inference() {
  return this->value_inference_;
}
// Returns the XlaCompiler held by this kernel's XlaContext.
XlaCompiler* XlaOpKernelContext::compiler() const {
  XlaContext* const owning_context = xla_context();
  return owning_context->compiler();
}
// Returns the XlaExpression stored in the tensor of input `index`.
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
  const Tensor& input_tensor = context_->input(index);
  const auto* stored_expression =
      XlaExpression::CastExpressionFromTensor(input_tensor);
  return *stored_expression;
}
// Returns the XlaExpression stored in the tensor of the input called `name`.
// The tensor lookup goes through GetInputTensorByName.
const XlaExpression& XlaOpKernelContext::InputExpression(
    absl::string_view name) {
  const Tensor& input_tensor = GetInputTensorByName(name);
  const auto* stored_expression =
      XlaExpression::CastExpressionFromTensor(input_tensor);
  return *stored_expression;
}
// Returns input `index` as an XlaOp, produced by the input's expression
// against this context's builder.
xla::XlaOp XlaOpKernelContext::Input(int index) {
  const XlaExpression& input_expression = InputExpression(index);
  return input_expression.AsXlaOp(builder());
}
// Returns the input called `name` as an XlaOp, produced by the input's
// expression against this context's builder.
xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
  const XlaExpression& input_expression = InputExpression(name);
  return input_expression.AsXlaOp(builder());
}
// Returns a copy of the TensorShape of the tensor at input `index`.
TensorShape XlaOpKernelContext::InputShape(int index) {
  const Tensor& input_tensor = context_->input(index);
  return input_tensor.shape();
}
// Returns a copy of the TensorShape of the tensor for the input called `name`.
TensorShape XlaOpKernelContext::InputShape(absl::string_view name) {
  const Tensor& input_tensor = GetInputTensorByName(name);
  return input_tensor.shape();
}
// Returns the result of GetXlaShape() on the expression of input `index`;
// any error from that call is returned to the caller.
absl::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(int index) {
  const XlaExpression& input_expression = InputExpression(index);
  return input_expression.GetXlaShape();
}
// Returns the result of GetXlaShape() on the expression of the input called
// `name`; any error from that call is returned to the caller.
absl::StatusOr<xla::Shape> XlaOpKernelContext::InputXlaShape(
    absl::string_view name) {
  const XlaExpression& input_expression = InputExpression(name);
  return input_expression.GetXlaShape();
}
// Returns the data type of input `index`. When the tensor's declared dtype is
// DT_UINT8, the dtype recorded in the stored XlaExpression is returned
// instead.
DataType XlaOpKernelContext::input_type(int index) const {
  const DataType declared_type = context_->input_dtype(index);
  if (declared_type != DT_UINT8) {
    return declared_type;
  }
  // Masqueraded XlaExpression could have different type. See
  // XlaOpKernelContext::SetOutputExpression for details.
  const auto* stored_expression =
      XlaExpression::CastExpressionFromTensor(context_->input(index));
  return stored_expression->dtype();
}
// Returns the data type of the input called `name`. When the tensor's dtype
// is DT_UINT8, the dtype recorded in the stored XlaExpression is returned
// instead.
DataType XlaOpKernelContext::InputType(absl::string_view name) {
  const Tensor& input_tensor = GetInputTensorByName(name);
  const DataType declared_type = input_tensor.dtype();
  if (declared_type != DT_UINT8) {
    return declared_type;
  }
  // Masqueraded XlaExpression could have different type. See
  // XlaOpKernelContext::SetOutputExpression for details.
  const auto* stored_expression =
      XlaExpression::CastExpressionFromTensor(input_tensor);
  return stored_expression->dtype();
}
// Converts the data type of input `index` to an xla::PrimitiveType. If the
// conversion fails, the failure is recorded with SetStatus and
// xla::PRIMITIVE_TYPE_INVALID is returned.
xla::PrimitiveType XlaOpKernelContext::input_xla_type(int index) {
  xla::PrimitiveType primitive_type;
  const Status conversion =
      DataTypeToPrimitiveType(input_type(index), &primitive_type);
  if (conversion.ok()) {
    return primitive_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}
// Converts the data type of the input called `name` to an xla::PrimitiveType.
// If the conversion fails, the failure is recorded with SetStatus and
// xla::PRIMITIVE_TYPE_INVALID is returned.
xla::PrimitiveType XlaOpKernelContext::InputXlaType(absl::string_view name) {
  xla::PrimitiveType primitive_type;
  const Status conversion =
      DataTypeToPrimitiveType(InputType(name), &primitive_type);
  if (conversion.ok()) {
    return primitive_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}
// Evaluates input `index` as a compile-time constant and stores the result in
// `*constant_literal`, keeping the input tensor's own dimensions.
//
// Returns InvalidArgument if the input's XLA shape is dynamic. An error from
// computing the XLA shape is returned to the caller rather than being
// dereferenced unchecked.
Status XlaOpKernelContext::ConstantInput(int index,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  // Resolve the shape once and check the StatusOr before using it.
  TF_ASSIGN_OR_RETURN(xla::Shape input_xla_shape, InputXlaShape(index));
  if (input_xla_shape.is_dynamic()) {
    return errors::InvalidArgument(
        "Reading input as constant from a dynamic tensor is not yet supported. "
        "Xla shape: ",
        input_xla_shape.ToString());
  }
  return ConstantInputReshaped(index,
                               context_->input(index).shape().dim_sizes(),
                               constant_literal, mode);
}
// Maps the input called `name` to its single input index using the kernel's
// InputRange. Returns InvalidArgument when the range does not cover exactly
// one input.
static absl::StatusOr<int> InputIndex(XlaOpKernelContext* context,
                                      absl::string_view name) {
  int first = 0;
  int limit = 0;
  TF_RETURN_IF_ERROR(context->op_kernel().InputRange(name, &first, &limit));
  const bool is_single_valued = (limit == first + 1);
  if (is_single_valued) {
    return first;
  }
  return errors::InvalidArgument(
      "OpKernel used list-valued input name '", name,
      "' when single-valued input was expected");
}
// Resolves the dynamism of input `index` into `*dynamism_literal`, using the
// input tensor's own dimension sizes as the target shape.
Status XlaOpKernelContext::ResolveInputDynamism(
    int index, xla::Literal* dynamism_literal) {
  const auto input_dims = context_->input(index).shape().dim_sizes();
  return ResolveInputDynamismReshaped(index, input_dims, dynamism_literal);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ResolveInputDynamism.
Status XlaOpKernelContext::ResolveInputDynamism(
    absl::string_view name, xla::Literal* dynamism_literal) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamism(input_index, dynamism_literal);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ConstantInput.
Status XlaOpKernelContext::ConstantInput(absl::string_view name,
                                         xla::Literal* constant_literal,
                                         xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInput(input_index, constant_literal, mode);
}
// Evaluates input `index` as a compile-time constant, reshapes it to
// `new_dims` and writes the result to `*constant_literal`. Returns
// InvalidArgument when the constant cannot be copied into the requested shape.
Status XlaOpKernelContext::ConstantInputReshaped(
    int index, absl::Span<const int64_t> new_dims,
    xla::Literal* constant_literal, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(Tensor constant, ConstantInputTensor(index, mode));
  const TensorShape target_shape(new_dims);
  Tensor reshaped(constant.dtype());
  const bool shapes_compatible = reshaped.CopyFrom(constant, target_shape);
  if (!shapes_compatible) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        constant.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        target_shape.DebugString());
  }
  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(reshaped));
  return absl::OkStatus();
}
// Converts an int16, int32 or int64 scalar literal to an int64 written to
// `*out`. Returns InvalidArgument if `literal` is not rank 0 or has any other
// element type.
static Status LiteralToInt64Scalar(const xla::LiteralSlice& literal,
                                   int64_t* out) {
  const xla::Shape& literal_shape = literal.shape();
  if (literal_shape.rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  switch (literal_shape.element_type()) {
    case xla::S16:
      *out = literal.Get<int16>({});
      break;
    case xla::S32:
      *out = literal.Get<int32>({});
      break;
    case xla::S64:
      *out = literal.Get<int64_t>({});
      break;
    default:
      return errors::InvalidArgument("value must be int16, int32, or int64");
  }
  return absl::OkStatus();
}
// Converts a float32 or float64 scalar literal to a float64 written to
// `*out`. Returns InvalidArgument if `literal` is not rank 0 or has any other
// element type.
static Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal,
                                     double* out) {
  const xla::Shape& literal_shape = literal.shape();
  if (literal_shape.rank() != 0) {
    return errors::InvalidArgument("value is not a scalar");
  }
  switch (literal_shape.element_type()) {
    case xla::F32:
      *out = literal.Get<float>({});
      break;
    case xla::F64:
      *out = literal.Get<double>({});
      break;
    default:
      return errors::InvalidArgument("value must be either float32 or float64");
  }
  return absl::OkStatus();
}
// Evaluates input `index` as a compile-time constant and converts it to an
// int64 scalar in `*out`.
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    int index, int64_t* out, xla::ValueInferenceMode mode) {
  xla::Literal constant;
  TF_RETURN_IF_ERROR(ConstantInput(index, &constant, mode));
  return LiteralToInt64Scalar(constant, out);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ConstantInputAsIntScalar.
Status XlaOpKernelContext::ConstantInputAsIntScalar(
    absl::string_view name, int64_t* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputAsIntScalar(input_index, out, mode);
}
// StatusOr-returning overload: evaluates the input called `name` as a
// compile-time int64 scalar and returns the value, or the error encountered.
absl::StatusOr<int64_t> XlaOpKernelContext::ConstantInputAsIntScalar(
    absl::string_view name, xla::ValueInferenceMode mode) {
  int64_t scalar_value;
  TF_RETURN_IF_ERROR(ConstantInputAsIntScalar(name, &scalar_value, mode));
  return scalar_value;
}
// Evaluates input `index` as a compile-time constant and converts it to a
// float64 scalar in `*out`.
Status XlaOpKernelContext::ConstantInputAsFloatScalar(
    int index, double* out, xla::ValueInferenceMode mode) {
  xla::Literal constant;
  TF_RETURN_IF_ERROR(ConstantInput(index, &constant, mode));
  return LiteralToFloat64Scalar(constant, out);
}
// Appends every element of a rank-1 PRED literal to `*out`. Returns
// InvalidArgument if `literal` is not rank 1 or its element type is not PRED.
static Status LiteralToPredVector(const xla::LiteralSlice& literal,
                                  std::vector<bool>* out) {
  const xla::Shape& literal_shape = literal.shape();
  if (literal_shape.rank() != 1) {
    return errors::InvalidArgument("output_shape must be rank 1, got shape ",
                                   literal_shape.DebugString());
  }
  const int64_t element_count = xla::ShapeUtil::ElementsIn(literal_shape);
  if (literal_shape.element_type() != xla::PRED) {
    return errors::InvalidArgument("value is not PRED");
  }
  int64_t position = 0;
  while (position < element_count) {
    out->push_back(literal.Get<bool>({position}));
    ++position;
  }
  return absl::OkStatus();
}
// Resolves whether scalar input `index` is dynamic and writes the answer to
// `*out`. If dynamism cannot be resolved, `*out` is set to true and OK is
// returned. Returns InvalidArgument if the resolved dynamism tensor cannot be
// copied into a scalar shape.
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(int index, bool* out) {
  XlaExpression expression = InputExpression(index);
  absl::StatusOr<Tensor> resolved = expression.ResolveDynamism();
  if (!resolved.ok()) {
    // When failed to resolve dynamism, conservatively consider the value
    // dynamic. This could happen if the input depends on some ops like
    // custom-call that is not supported generally for dynamism computation.
    //
    // TODO(b/176993339): Support resolving dynamism across computations so
    // resolving dynamism will not fail in those cases.
    *out = true;
    return absl::OkStatus();
  }
  Tensor dynamism = resolved.value();
  Tensor scalar(dynamism.dtype());
  TensorShape tensor_shape({});
  const bool is_scalar = scalar.CopyFrom(dynamism, tensor_shape);
  if (!is_scalar) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(), " which is not a R0 ", tensor_shape);
  }
  TF_ASSIGN_OR_RETURN(xla::Literal pred_literal, HostTensorToLiteral(scalar));
  *out = pred_literal.Get<bool>({});
  return absl::OkStatus();
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ResolveInputDynamismIntoPredVector.
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    absl::string_view name, std::vector<bool>* out) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamismIntoPredVector(input_index, out);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ResolveInputDynamismIntoPred.
Status XlaOpKernelContext::ResolveInputDynamismIntoPred(absl::string_view name,
                                                        bool* out) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ResolveInputDynamismIntoPred(input_index, out);
}
// Resolves the per-element dynamism of input `index`, reshaped to `new_dims`,
// and writes it to `*dynamism_literal`.
//
// If dynamism cannot be resolved, every element is reported as dynamic: a
// PRED literal of shape `new_dims` filled with `true` is produced. A failure
// to build that literal is returned as an error instead of aborting on an
// unchecked StatusOr. Returns InvalidArgument when the resolved dynamism
// tensor cannot be copied into `new_dims`.
Status XlaOpKernelContext::ResolveInputDynamismReshaped(
    int index, absl::Span<const int64_t> new_dims,
    xla::Literal* dynamism_literal) {
  XlaExpression e = InputExpression(index);
  absl::StatusOr<Tensor> dynamism_or_status = e.ResolveDynamism();
  if (!dynamism_or_status.ok()) {
    xla::Literal true_literal = xla::LiteralUtil::CreateR0<bool>(true);
    // When failed to resolve dynamism, conservatively consider the value
    // dynamic. This could happen if the input depends on some ops like
    // custom-call that is not supported generally for dynamism computation.
    TF_ASSIGN_OR_RETURN(
        *dynamism_literal,
        true_literal.Broadcast(xla::ShapeUtil::MakeShape(xla::PRED, new_dims),
                               {}));
    return absl::OkStatus();
  }
  Tensor dynamism = dynamism_or_status.value();
  Tensor temp(dynamism.dtype());
  if (!temp.CopyFrom(dynamism, TensorShape(new_dims))) {
    return errors::InvalidArgument(
        context_->op_kernel().name(), " input ", index, " has shape ",
        dynamism.shape().DebugString(),
        " but was asked to be reshaped to incompatible shape ",
        TensorShape(new_dims).DebugString());
  }
  TF_ASSIGN_OR_RETURN(*dynamism_literal, HostTensorToLiteral(temp));
  return absl::OkStatus();
}
// Resolves the dynamism of input `index` flattened to a rank-1 shape holding
// all of its elements, and appends the per-element results to `*out`.
Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector(
    int index, std::vector<bool>* out) {
  xla::Literal dynamism;
  const int64_t flat_size = InputShape(index).num_elements();
  TF_RETURN_IF_ERROR(
      ResolveInputDynamismReshaped(index, {flat_size}, &dynamism));
  return LiteralToPredVector(dynamism, out);
}
// Converts an int32 or int64 1D literal to an int64 vector by appending each
// element to `*out`. Returns InvalidArgument if `literal` is not rank 1 or has
// any other element type.
static Status LiteralToInt64Vector(const xla::LiteralSlice& literal,
                                   std::vector<int64_t>* out) {
  const xla::Shape& literal_shape = literal.shape();
  if (literal_shape.rank() != 1) {
    return errors::InvalidArgument("output_shape must be rank 1, got shape ",
                                   literal_shape.DebugString());
  }
  const int64_t element_count = xla::ShapeUtil::ElementsIn(literal_shape);
  switch (literal_shape.element_type()) {
    case xla::S32:
      for (int64_t pos = 0; pos < element_count; ++pos) {
        out->push_back(literal.Get<int32>({pos}));
      }
      break;
    case xla::S64:
      for (int64_t pos = 0; pos < element_count; ++pos) {
        out->push_back(literal.Get<int64_t>({pos}));
      }
      break;
    default:
      return errors::InvalidArgument("value must be either int32 or int64");
  }
  return absl::OkStatus();
}
// Evaluates input `index` as a compile-time constant and appends its elements
// to `*out` as int64 values.
Status XlaOpKernelContext::ConstantInputAsIntVector(
    int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
  xla::Literal constant;
  TF_RETURN_IF_ERROR(ConstantInput(index, &constant, mode));
  return LiteralToInt64Vector(constant, out);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ConstantInputAsIntVector.
Status XlaOpKernelContext::ConstantInputAsIntVector(
    absl::string_view name, std::vector<int64_t>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputAsIntVector(input_index, out, mode);
}
// Evaluates input `index` as a compile-time constant flattened to a rank-1
// shape holding all of its elements, and appends them to `*out` as int64.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    int index, std::vector<int64_t>* out, xla::ValueInferenceMode mode) {
  xla::Literal flattened;
  const int64_t flat_size = InputShape(index).num_elements();
  TF_RETURN_IF_ERROR(
      ConstantInputReshaped(index, {flat_size}, &flattened, mode));
  return LiteralToInt64Vector(flattened, out);
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ConstantInputReshapedToIntVector, which performs
// the same flatten-and-convert steps.
Status XlaOpKernelContext::ConstantInputReshapedToIntVector(
    absl::string_view name, std::vector<int64_t>* out,
    xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputReshapedToIntVector(input_index, out, mode);
}
// Evaluates input `index` as a compile-time constant and stores it in `*out`
// as an S64 literal. An S64 constant is moved through unchanged; an S32
// constant is widened element by element into a new literal of the same
// dimensions. Any other element type yields InvalidArgument.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    int index, xla::Literal* out, xla::ValueInferenceMode mode) {
  xla::Literal constant;
  TF_RETURN_IF_ERROR(ConstantInput(index, &constant, mode));
  const xla::PrimitiveType element_type = constant.shape().element_type();
  if (element_type == xla::S64) {
    *out = std::move(constant);
    return absl::OkStatus();
  }
  if (element_type != xla::S32) {
    return errors::InvalidArgument(
        "Invalid argument to ConstantInputAsInt64Literal: ",
        xla::ShapeUtil::HumanString(constant.shape()));
  }
  *out = xla::Literal(
      xla::ShapeUtil::ChangeElementType(constant.shape(), xla::S64));
  auto narrow_values = constant.data<int32>();
  auto wide_values = out->data<int64_t>();
  for (int64_t pos = 0; pos < narrow_values.size(); ++pos) {
    wide_values[pos] = narrow_values[pos];
  }
  return absl::OkStatus();
}
// Name-based overload: looks up the single input index for `name` and
// forwards to the index-based ConstantInputAsInt64Literal.
Status XlaOpKernelContext::ConstantInputAsInt64Literal(
    absl::string_view name, xla::Literal* out, xla::ValueInferenceMode mode) {
  TF_ASSIGN_OR_RETURN(int input_index, InputIndex(this, name));
  return ConstantInputAsInt64Literal(input_index, out, mode);
}
// TODO(phawkins): validate that the dimensions form a valid shape, fail
// gracefully if they do not.
//
// Evaluates input `index` as a compile-time constant rank-1 int32/int64
// literal and stores the corresponding TensorShape in `*shape`.
//
// The running product of the dimensions is accumulated with
// MultiplyWithoutOverflow; a negative running product is reported as
// InvalidArgument before the TensorShape is constructed.
Status XlaOpKernelContext::ConstantInputAsShape(int index, TensorShape* shape,
                                                xla::ValueInferenceMode mode) {
  xla::Literal literal;
  TF_RETURN_IF_ERROR(ConstantInput(index, &literal, mode));
  std::vector<int64_t> dims;
  TF_RETURN_IF_ERROR(LiteralToInt64Vector(literal, &dims));
  int64_t num_elements = 1;
  for (const int64_t dim : dims) {
    num_elements = MultiplyWithoutOverflow(num_elements, dim);
    if (num_elements < 0) {
      // Separators added so the pieces of the message no longer run together.
      return errors::InvalidArgument(
          "The total elements specified by orig_input_shape is too large. ",
          "Encountered overflow after multiplying ", dim,
          ", result: ", num_elements);
    }
  }
  *shape = TensorShape(dims);
  return absl::OkStatus();
}
// Evaluates input `index` as a compile-time constant and converts it to a
// PartialTensorShape in `*shape`. A scalar constant must equal -1 and yields a
// shape of unknown rank; otherwise the constant is read as a rank-1 int32 or
// int64 vector of dimensions.
Status XlaOpKernelContext::ConstantInputAsPartialShape(
    int index, PartialTensorShape* shape) {
  xla::Literal constant;
  TF_RETURN_IF_ERROR(ConstantInput(index, &constant));
  const bool is_scalar = (constant.shape().rank() == 0);
  if (!is_scalar) {
    std::vector<int64_t> dims;
    TF_RETURN_IF_ERROR(LiteralToInt64Vector(constant, &dims));
    *shape = PartialTensorShape(dims);
    return absl::OkStatus();
  }
  // If `constant` is a scalar its value must be -1.
  int64_t shape_val;
  TF_RETURN_IF_ERROR(LiteralToInt64Scalar(constant, &shape_val));
  if (shape_val != -1) {
    return errors::InvalidArgument(
        "Cannot convert value to PartialTensorShape: ", shape_val);
  }
  *shape = PartialTensorShape();  // Shape with unknown rank.
  return absl::OkStatus();
}
// Fetches the list-valued input called `name`. `handles` and `shapes` are
// cleared and then filled, in order, with each element's XlaOp and
// TensorShape.
Status XlaOpKernelContext::InputList(absl::string_view name,
                                     std::vector<xla::XlaOp>* handles,
                                     std::vector<TensorShape>* shapes) {
  OpInputList input_list;
  TF_RETURN_IF_ERROR(context_->input_list(name, &input_list));
  handles->clear();
  shapes->clear();
  for (const Tensor& element : input_list) {
    const auto* element_expression =
        XlaExpression::CastExpressionFromTensor(element);
    handles->push_back(element_expression->AsXlaOp(builder()));
    shapes->push_back(element.shape());
  }
  return absl::OkStatus();
}
// Evaluates every element of the list-valued input called `name` as a
// compile-time constant. `outputs` is resized to the list length and
// `(*outputs)[k]` receives the k-th element of the list.
//
// Returns the first error produced by InputRange or ConstantInput.
Status XlaOpKernelContext::ConstantInputList(absl::string_view name,
                                             std::vector<xla::Literal>* outputs,
                                             xla::ValueInferenceMode mode) {
  int start, stop;
  TF_RETURN_IF_ERROR(op_kernel().InputRange(name, &start, &stop));
  outputs->resize(stop - start);
  for (int i = start; i < stop; ++i) {
    // ConstantInput takes the absolute input index, but `outputs` holds only
    // `stop - start` entries, so its slot is relative to `start`.
    TF_RETURN_IF_ERROR(ConstantInput(i, &(*outputs)[i - start], mode));
  }
  return absl::OkStatus();
}
// Evaluates input `index` as a compile-time constant Tensor.
//
// An error from ResolveConstant is returned with context about the input and
// operator appended. If resolution succeeds but yields no value,
// InvalidArgument is returned explaining that the input must be a compile-time
// constant.
absl::StatusOr<Tensor> XlaOpKernelContext::ConstantInputTensor(
    int index, xla::ValueInferenceMode mode) {
  XlaExpression expression = InputExpression(index);
  // The client is optional: nullptr is passed when there is no compiler.
  auto* client = compiler() ? compiler()->client() : nullptr;
  absl::StatusOr<std::optional<Tensor>> resolved =
      expression.ResolveConstant(client, dynamic_dimension_is_minus_one_, mode);
  if (!resolved.ok()) {
    Status annotated = resolved.status();
    errors::AppendToMessage(&annotated, "while evaluating input ", index,
                            " of ", context_->op_kernel().type_string(),
                            " operator as a compile-time constant.");
    return annotated;
  }
  std::optional<Tensor> constant = resolved.value();
  if (constant.has_value()) {
    return *constant;
  }
  return errors::InvalidArgument(
      "Input ", index, " to node `", context_->op_kernel().name(),
      "` with op ", context_->op_kernel().type_string(),
      " must be a compile-time constant.\n\n"
      "XLA compilation requires that operator arguments that represent "
      "shapes or dimensions be evaluated to concrete values at compile time. "
      "This error means that a shape or dimension argument could not be "
      "evaluated at compile time, usually because the value of the argument "
      "depends on a parameter to the computation, on a variable, or on a "
      "stateful operation such as a random number generator.");
}
namespace {
// Reads the variable resource stored in `tensor`.
//
// On success `*value` holds the variable's value and, if `shape` is non-null,
// `*shape` holds the variable's shape. Fails with FailedPrecondition when the
// variable is uninitialized and with InvalidArgument when its dtype differs
// from `type`. If the variable has not been overwritten and the expression
// carries a constant value, that constant is emitted directly. Otherwise the
// stored value is returned, reshaped to the variable's dimensions when its
// representation shape is not compatible with the plain XLA shape.
Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                               const XlaOpKernelContext* ctx,
                               TensorShape* shape, xla::XlaOp* value) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (!variable->initialized()) {
    return errors::FailedPrecondition(
        "Read variable failure ", variable->name(),
        ". It could mean the variable is uninitialized or the variable is on "
        "another device ");
  }
  const bool dtype_matches = (variable->type() == type);
  if (!dtype_matches) {
    return errors::InvalidArgument(
        "Trying to read variable with wrong dtype. Expected ",
        DataTypeString(type), " got ", DataTypeString(variable->type()));
  }
  if (shape != nullptr) {
    *shape = variable->shape();
  }

  // Fast path: an untouched variable with a known constant value.
  if (!variable->IsOverwritten() && expression->constant_value()) {
    TF_ASSIGN_OR_RETURN(xla::Literal constant_literal,
                        HostTensorToLiteral(*expression->constant_value()));
    *value = xla::ConstantLiteral(ctx->builder(), constant_literal);
    return absl::OkStatus();
  }

  auto shape_fns = ctx->compiler()->options().shape_determination_fns;
  XlaLayoutPreference layout_preference = shape_fns.layout_preference_fn(
      variable->shape(), variable->type(), std::nullopt);
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      shape_fns.shape_representation_fn(
                          variable->shape(), variable->type(),
                          /*use_fast_memory=*/false, layout_preference));
  xla::Shape plain_xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(variable->type(), variable->shape(),
                                           &plain_xla_shape));
  if (!xla::ShapeUtil::Compatible(plain_xla_shape, representation_shape)) {
    *value = xla::Reshape(variable->value(), variable->shape().dim_sizes());
  } else {
    *value = variable->value();
  }
  return absl::OkStatus();
}
} // namespace
// Reads the variable held by input `index`; see ReadVariableInputTensor for
// the meaning of `type`, `shape` and `value`.
Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
                                             TensorShape* shape,
                                             xla::XlaOp* value) {
  const Tensor& variable_tensor = context_->input(index);
  return ReadVariableInputTensor(variable_tensor, type, this, shape, value);
}
// Reads the variable held by the input called `name`; see
// ReadVariableInputTensor for the meaning of `type`, `shape` and `value`.
Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
                                             DataType type, TensorShape* shape,
                                             xla::XlaOp* value) {
  const Tensor& variable_tensor = GetInputTensorByName(name);
  return ReadVariableInputTensor(variable_tensor, type, this, shape, value);
}
// Writes the dtype and shape of the variable resource held by input `index`
// to `*type` and `*shape`. Returns InvalidArgument if the variable is not
// initialized.
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                   TensorShape* shape) const {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(context_->input(index));
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
  if (variable->initialized()) {
    *type = variable->type();
    *shape = variable->shape();
    return absl::OkStatus();
  }
  return errors::InvalidArgument(
      "Read variable failure ", variable->name(),
      ". It could mean the variable is uninitialized or the variable is on "
      "another device ");
}
void XlaOpKernelContext::SetOutputExpression(int index,
const XlaExpression& expression) {
Status status = [&] {
// The step's default allocator is the dummy XlaCompilationAllocator which
// simply allocates a metadata buffer to hold the expression to which it
// corresponds.
// Provides a special behavior for DT_VARIANT and other types that are not
// trivially copyable. In those cases, allocate a tensor of type DT_UINT8.
if (!DataTypeCanUseMemcpy(expression.dtype())) {
// tensor_data() is not supported for tensors that cannot be copied via
// memcpy, as the copy logic might try to inspect the stored data (e.g.
// a std::string). This is likely to fail, as the data is invalid given
// that it actually encodes an XlaExpression. Using a uint8 tensor is
// always safe, so simply do that.
// TODO(jpienaar): This should be refactored to stop masquerading
// XlaExpressions as Tensors.
Tensor output;
TensorShape tensor_shape;
TF_RETURN_IF_ERROR(
context_->allocate_temp(DT_UINT8, tensor_shape, &output));
context_->set_output(index, output);
} else {
Tensor* output = nullptr;
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
}
XlaExpression::AssignExpressionToTensor(expression,
context_->mutable_output(index));
return absl::OkStatus();
}();
if (!status.ok()) {
SetStatus(status);
}
}
// Converts the expected dtype of output `index` to an xla::PrimitiveType. If
// the conversion fails, the failure is recorded with SetStatus and
// xla::PRIMITIVE_TYPE_INVALID is returned.
xla::PrimitiveType XlaOpKernelContext::output_xla_type(int index) {
  xla::PrimitiveType primitive_type;
  const Status conversion =
      DataTypeToPrimitiveType(expected_output_dtype(index), &primitive_type);
  if (conversion.ok()) {
    return primitive_type;
  }
  SetStatus(conversion);
  return xla::PRIMITIVE_TYPE_INVALID;
}
// Sets output `index` to the XlaOp `handle`, tagged with the dtype the kernel
// expects for that output.
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
  const DataType expected_dtype = context_->expected_output_dtype(index);
  SetOutputExpression(index, XlaExpression::XlaOp(handle, expected_dtype));
}
void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
SetOutputExpression(index, XlaExpression::Constant(constant));
}
// Sets output `index` to a tensor-list expression built from `handle`.
void XlaOpKernelContext::SetTensorListOutput(int index,
                                             const xla::XlaOp& handle) {
  const XlaExpression list_expression = XlaExpression::TensorList(handle);
  SetOutputExpression(index, list_expression);
}
void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
SetOutputExpression(index, XlaExpression::Resource(resource));
}
// Writes the XlaResource held by input `index` to `*resource`.
// Fails via TF_RET_CHECK if the input's expression carries no resource.
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
const XlaExpression* expression =
XlaExpression::CastExpressionFromTensor(context_->input(index));
// The input must hold a resource expression.
TF_RET_CHECK(expression->resource() != nullptr);
*resource = expression->resource();
return absl::OkStatus();
}
namespace {
// Assigns `handle` as the new value of the variable resource stored in
// `tensor`.
//
// The variable's type and shape are set from `type` and the shape `builder`
// reports for `handle`. If the representation shape chosen by the compiler's
// shape determination functions is not compatible with the plain XLA shape,
// `handle` is reshaped to the representation shape's dimensions before being
// stored.
Status AssignVariableTensor(const Tensor& tensor, DataType type,
                            const XlaOpKernelContext* ctx, xla::XlaOp handle,
                            xla::XlaBuilder* builder) {
  const XlaExpression* expression =
      XlaExpression::CastExpressionFromTensor(tensor);
  XlaResource* variable = expression->resource();
  TF_RET_CHECK(variable != nullptr);
  TF_RET_CHECK(variable->kind() == XlaResource::kVariable);

  auto handle_shape = builder->GetShape(handle);
  if (!handle_shape.ok()) {
    return handle_shape.status();
  }
  TensorShape value_shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(handle_shape.value(), &value_shape));
  TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, value_shape));

  auto shape_fns = ctx->compiler()->options().shape_determination_fns;
  XlaLayoutPreference layout_preference =
      shape_fns.layout_preference_fn(value_shape, type, std::nullopt);
  TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                      shape_fns.shape_representation_fn(
                          value_shape, type,
                          /*use_fast_memory=*/false, layout_preference));
  xla::Shape plain_xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, value_shape, &plain_xla_shape));
  const bool needs_reshape =
      !xla::ShapeUtil::Compatible(plain_xla_shape, representation_shape);
  if (needs_reshape) {
    handle = xla::Reshape(handle, representation_shape.dimensions());
  }
  variable->SetRepresentationShape(representation_shape);
  return variable->SetValue(handle);
}
} // namespace
// Assigns `handle` (which must be valid) to the variable held by input
// `input_index`, recording `type` as the variable's dtype.
Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  const Tensor& variable_tensor = context_->input(input_index);
  return AssignVariableTensor(variable_tensor, type, this, handle, builder());
}
// Assigns `handle` (which must be valid) to the variable held by the input
// called `name`, recording `type` as the variable's dtype.
Status XlaOpKernelContext::AssignVariable(absl::string_view name, DataType type,
                                          xla::XlaOp handle) {
  TF_RET_CHECK(handle.valid());
  const Tensor& variable_tensor = GetInputTensorByName(name);
  return AssignVariableTensor(variable_tensor, type, this, handle, builder());
}
// Returns `s` unchanged unless its code is INVALID_ARGUMENT, in which case a
// new status with the same code is returned whose message is the original
// message followed by a newline and `ctx->StackTrace()`.
static Status GetStatusWithStackTrace(const Status& s,
                                      const XlaOpKernelContext* ctx) {
  if (s.code() != error::INVALID_ARGUMENT) {
    return s;
  }
  return Status{s.code(), absl::StrCat(s.message(), "\n", ctx->StackTrace())};
}
void XlaOpKernelContext::CtxFailure(const Status& s) {
context_->CtxFailure(GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) {
context_->CtxFailureWithWarning(GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailure(const char* file, int line,
const Status& s) {
context_->CtxFailure(file, line, GetStatusWithStackTrace(s, this));
}
void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line,
const Status& s) {
context_->CtxFailureWithWarning(file, line, GetStatusWithStackTrace(s, this));
}
// Forwards to XlaContext::GetOrCreateMax for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMax(
    const DataType type) {
  XlaContext* const owning_context = xla_context();
  return owning_context->GetOrCreateMax(type);
}
// Forwards to XlaContext::GetOrCreateMin for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMin(
    const DataType type) {
  XlaContext* const owning_context = xla_context();
  return owning_context->GetOrCreateMin(type);
}
// Forwards to XlaContext::GetOrCreateAdd for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateAdd(
    const DataType type) {
  XlaContext* const owning_context = xla_context();
  return owning_context->GetOrCreateAdd(type);
}
// Forwards to XlaContext::GetOrCreateLogAddExp for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateLogAddExp(
    const DataType type) {
  XlaContext* const owning_context = xla_context();
  return owning_context->GetOrCreateLogAddExp(type);
}
// Forwards to XlaContext::GetOrCreateMul for `type`.
const xla::XlaComputation* XlaOpKernelContext::GetOrCreateMul(
    const DataType type) {
  XlaContext* const owning_context = xla_context();
  return owning_context->GetOrCreateMul(type);
}
// Returns the tensor for the input called `name`, looked up on the wrapped
// OpKernelContext. The lookup is wrapped in CHECK, so a failed lookup is not
// reported to the caller as a Status.
const Tensor& XlaOpKernelContext::GetInputTensorByName(absl::string_view name) {
const Tensor* tensor;
// CHECK (not a returned Status) guards the lookup result.
CHECK(context_->input(name, &tensor).ok());
return *tensor;
}
// Constructs the kernel by forwarding `context` to the OpKernel base class.
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
// Wraps `context` in an XlaOpKernelContext and hands it to Compile().
void XlaOpKernel::Compute(OpKernelContext* context) {
  XlaOpKernelContext compile_context(context);
  Compile(&compile_context);
}
std::string XlaOpKernelContext::StackTrace() const {
if (const AbstractStackTrace* stack_trace =
xla_context()->StackTraceForNodeName(op_kernel().name())) {
AbstractStackTrace::TracePrintingOptions opts;
opts.show_line_contents = true;
opts.filter_common_prefix = true;
opts.drop_internal_frames = true;
return absl::StrCat("\nStack trace for op definition: \n",
stack_trace->ToString(opts), "\n");
} else {
return "";
}
}
} // namespace tensorflow