/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;

option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent, and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. Tuples are used for things like returning multiple values
  // from a computation; e.g. a computation that returns weights and biases
  // may have a signature that results in a tuple like
  // (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation.
  OPAQUE = 14;

  // Next = 17
}

// Describes the value held inside padding elements.
enum PaddingValue {
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for a Pad operation. The padding amount
// on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0).
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index).
    int64 edge_padding_high = 2;

    // Padding amount between the elements.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}

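// Illustrative example (not from the original file): padding a 1-D shape
// f32[3] with edge_padding_low = 1, edge_padding_high = 2 and
// interior_padding = 1 yields f32[8], since the padded size is
//   edge_padding_low + size + (size - 1) * interior_padding + edge_padding_high
//   = 1 + 3 + 2 * 1 + 2 = 8.
// As a textproto:
//   dimensions { edge_padding_low: 1 edge_padding_high: 2 interior_padding: 1 }
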
// A format specifies the method used by a layout to store an array in memory.
enum Format {
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element
  // (ignoring padding).
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of
  // non-zero elements.
  SPARSE = 2;
}

// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape, as well
// as any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the computation.
// Layouts specified in interior operations which take Shapes (for example,
// Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
message Layout {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to
  // major (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // The width to which each dimension is padded. If present, the size of
  // padded_dimensions must equal the rank of the shape. The padding appears
  // at the end of a dimension, not at the beginning. This kind of padding,
  // unlike padding in e.g. convolution, is not part of the shape. This field
  // must be unset unless the format is DENSE.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by padded_dimensions. This
  // field must be unset unless the format is DENSE.
  PaddingValue padding_value = 3;

  // The maximum number of elements that can be stored for SPARSE formats.
  // This can be used to determine the maximum size in bytes of arrays stored
  // in memory. This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // appropriately to account for the new field.
}

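// Illustrative example (not from the original file): a dense, row-major
// layout for a rank-2 array stores the last dimension contiguously, so
// dimension 1 is minor and dimension 0 is major:
//   format: DENSE
//   minor_to_major: [1, 0]
// A column-major layout for the same array would use minor_to_major: [0, 1].
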
// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
message Shape {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension. In XLA, dimensions are
  // numbered from 0 to N-1 for an N-dimensional array. The first element of
  // 'dimensions' is the size of dimension 0, the second element is the size
  // of dimension 1, and so forth. An empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements in the tuple
  // sequence.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // and ShapeUtil::Compatible() appropriately to account for the new field.
}

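// Illustrative example (values assumed): the shape f32[2, 3] with a
// row-major layout, as a textproto:
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }
// A scalar f32 has element_type: F32 and an empty dimensions list; a tuple
// such as (f32[2], s32[]) has element_type: TUPLE and one entry in
// tuple_shapes per element.
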
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  repeated Shape parameters = 1;
  Shape result = 2;
  repeated string parameter_names = 3;
}

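// Illustrative sketch (assumed, not from the original file): a computation
// with signature (f32[784, 2000], f32[2000]) -> f32[784] would carry two
// entries in 'parameters' and a rank-1 'result' shape, with
// 'parameter_names' running parallel to 'parameters'.
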
// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;
  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line that this op is associated with in the user's
  // program.
  //
  // E.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This
  // is a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;
}

// Handle given to a user that represents a computation that the user builds
// up before execution.
message ComputationHandle {
  int64 handle = 1;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a data result in a computation.
// This is used to pass to subsequent computations that depend upon the data
// as an operand.
message ComputationDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution, where N is
// the number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations via
// a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}

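// Illustrative example (device ids assumed): with computation_count = 2 and
// replica_count = 2, there are two ComputationDevice entries, each listing
// the physical devices that run that computation's replicas:
//   computation_devices { replica_device_ids: [0, 1] }
//   computation_devices { replica_device_ids: [2, 3] }
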
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  Shape shape = 1;
  repeated bool preds = 2;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little-endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  repeated int64 sparse_indices = 14;
  // Next = 15
}

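// Illustrative example (values assumed): the constant f32[2, 2]
// {{1, 2}, {3, 4}} with a row-major layout could be encoded as:
//   shape {
//     element_type: F32
//     dimensions: [2, 2]
//     layout { format: DENSE minor_to_major: [1, 0] }
//   }
//   f32s: [1, 2, 3, 4]
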
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of zero padding to add to the base area at the
  // low end of this dimension; if negative, its absolute value is the number
  // of elements removed from the low end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // zeroes to pad on the left, given that indices increase when going right.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in
  // the horizontal dimension of a rectangle, this would be the number of
  // zeroes to pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation
  // factor of 1 means no dilation. window_dilation - 1 no-op entries
  // ("holes") are implicitly placed between each kernel element. See the
  // documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of
  // 1 means no dilation. base_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each base area element. See the documentation
  // for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before
  // the operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the window
// a computation is performed. The field below describes the window and the
// movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}

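// Illustrative sketch of the standard windowed-output-size arithmetic (not
// stated in the original file): for one dimension with base size B, the
// padded, dilated base size is
//   P = padding_low + ((B - 1) * base_dilation + 1) + padding_high
// the effective window size is
//   K = (size - 1) * window_dilation + 1
// and the number of window positions is floor((P - K) / stride) + 1.
// E.g. B = 5, size = 3, stride = 2, with no padding or dilation, gives
// floor((5 - 3) / 2) + 1 = 2 positions.
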
// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather
// for more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices
  // for which were computed from output_gather_dims (see the operation
  // semantics for how this is defined) and the gather_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in elided_window_dims
  //      then 0
  //      else Out[output_window_dims[i++]]
  repeated int64 output_window_dims = 1;
  repeated int64 elided_window_dims = 2;

  // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
  // transforms the gather index looked up from the gather_indices tensor into
  // the starting index in the input space.
  repeated int64 gather_dims_to_operand_dims = 3;

  // The dimension in the gather_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}

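// Illustrative example (assumed, following the gather semantics linked
// above): gathering whole rows of an input f32[5, 10] using gather_indices
// s64[3] and window_bounds [1, 10] could be expressed as
//   output_window_dims: [1]
//   elided_window_dims: [0]
//   gather_dims_to_operand_dims: [0]
//   index_vector_dim: 1
// producing an f32[3, 10] output whose k-th row is
// input[gather_indices[k], :].
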
// Operation requests that are all collected as a tagged union with a oneof
// field in OpRequest.

message ConstantRequest {
  LiteralProto literal = 2;
}

message GetTupleElementRequest {
  ComputationDataHandle operand = 2;
  int64 index = 3;
}

message SliceRequest {
  ComputationDataHandle operand = 2;
  repeated int64 start_indices = 3;
  repeated int64 limit_indices = 4;
  repeated int64 strides = 5;
}

message DynamicSliceRequest {
  // Operand from which to slice at dynamic 'start_indices'.
  ComputationDataHandle operand = 2;
  // Dynamically computed 'start_indices' for the slice operation.
  ComputationDataHandle start_indices = 3;
  // Slice sizes for each dimension (note that index calculations are computed
  // modulo dimension sizes to avoid out-of-bound array accesses).
  repeated int64 slice_sizes = 4;
}

message DynamicUpdateSliceRequest {
  // Operand on which slice 'update' is to be applied.
  ComputationDataHandle operand = 2;
  // The slice update to apply to 'operand'.
  ComputationDataHandle update = 3;
  // Dynamically computed start indices for the update slice operation.
  ComputationDataHandle start_indices = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window moves
  // through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in the
  // convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window moves
  // through in the kernel (rhs). window.strides(0) is the stride in the
  // kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window moves
  // through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}

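// Illustrative example (assumed, not from the original file): a 2-D
// convolution with an NHWC input, an HWIO kernel, and an NHWC output would
// use
//   input_batch_dimension: 0      input_feature_dimension: 3
//   input_spatial_dimensions: [1, 2]
//   kernel_spatial_dimensions: [0, 1]
//   kernel_input_feature_dimension: 2
//   kernel_output_feature_dimension: 3
//   output_batch_dimension: 0     output_feature_dimension: 3
//   output_spatial_dimensions: [1, 2]
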
message ConvolveRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;  // This is the filter/kernel.
  Window window = 4;              // Describes the filter/kernel window.
  ConvolutionDimensionNumbers dimension_numbers = 5;
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              // fft_length real out.
}

message FftRequest {
  FftType fft_type = 1;
  repeated int64 fft_length = 2;  // Multivalent for higher-order FFT.
  ComputationDataHandle operand = 3;
}

message InfeedRequest {
  // The shape of the data returned by reading the device's infeed buffer.
  Shape shape = 2;
  // Additional infeed configuration for the backend.
  bytes config = 3;
}

message OutfeedRequest {
  // The shape of the data returned by reading the device's outfeed buffer.
  Shape shape = 1;
  // Operand to the Outfeed. Supports tuples.
  ComputationDataHandle operand = 2;
  // Backend-specific information for how to perform the outfeed.
  bytes outfeed_config = 3;
}

message CallRequest {
  ComputationHandle to_apply = 2;
  repeated ComputationDataHandle operands = 3;
}

message CustomCallRequest {
  string call_target_name = 2;
  repeated ComputationDataHandle operands = 3;
  Shape shape = 4;
}

message HostComputeRequest {
  // Operands to the HostCompute. Supports tuples.
  repeated ComputationDataHandle operands = 1;
  // Name used to identify HostSend/Recv channels.
  string channel_name = 2;
  // Cost estimate in nanoseconds.
  int64 cost_estimate_ns = 3;
  // The shape of any data returned by the host.
  Shape shape = 4;
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}

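// Illustrative example: a batched matrix multiply of lhs f32[B, M, K] with
// rhs f32[B, K, N] contracts the K dimensions and batches over B:
//   lhs_contracting_dimensions: [2]
//   rhs_contracting_dimensions: [1]
//   lhs_batch_dimensions: [0]
//   rhs_batch_dimensions: [0]
// yielding an f32[B, M, N] result.
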
message DotRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;
  DotDimensionNumbers dimension_numbers = 4;
}

message MapRequest {
  repeated ComputationDataHandle operands = 2;
  ComputationHandle to_apply = 3;
  repeated ComputationDataHandle static_operands = 4;
  // The dimensions over which to map.
  // Example of mapping a Dot operation along batch dimension 0:
  //   operand0.shape = [2, 2, 2], operand1.shape = [2, 2, 3]
  //   Map({operand0, operand1}, Dot, {0})
  repeated int64 dimensions = 5;
}

message ReduceRequest {
  // Operand to the reduction.
  ComputationDataHandle operand = 2;

  // Initial value for the reduction. This must be consistent with the result
  // shape of to_apply.
  ComputationDataHandle init_value = 3;

  // The dimensions to reduce over.
  repeated int64 dimensions = 4;

  // The computation to apply in the reduction.
  ComputationHandle to_apply = 5;
}

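// Illustrative example (assumed): reducing an f32[10, 20] operand over
// dimensions [1], with init_value 0 and an addition computation as to_apply,
// yields an f32[10] of row sums; reducing over [0, 1] would yield a scalar.
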
message ReduceWindowRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle init_value = 3;
  Window window = 4;
  ComputationHandle to_apply = 5;
}

message BatchNormTrainingRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  float epsilon = 4;
  int64 feature_index = 5;
}

message BatchNormInferenceRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  ComputationDataHandle mean = 4;
  ComputationDataHandle variance = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message BatchNormGradRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle mean = 3;
  ComputationDataHandle variance = 4;
  ComputationDataHandle grad_output = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message CrossReplicaSumRequest {
  ComputationDataHandle operand = 2;
}

message SelectAndScatterRequest {
  // Operand array on which the windows slide.
  ComputationDataHandle operand = 2;

  // Source array for the data to scatter.
  ComputationDataHandle source = 3;

  // Initial scalar value for each element in the output.
  ComputationDataHandle init_value = 4;

  // Window configuration.
  Window window = 5;

  // Binary function used to select an element from each window.
  ComputationHandle select = 6;

  // Binary function used to combine each scattered value from source with the
  // current output value at the selected location.
  ComputationHandle scatter = 7;
}

message ReverseRequest {
  ComputationDataHandle operand = 2;
  repeated int64 dimensions = 3;
}

message BroadcastRequest {
  ComputationDataHandle operand = 2;
  repeated int64 broadcast_sizes = 3;
}

message PadRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle padding_value = 3;
  PaddingConfig padding_config = 4;
}

message ReshapeRequest {
  ComputationDataHandle operand = 2;
  // The dimension order for collapse (from fastest-changing to slowest).
  repeated int64 dimensions = 3;
  // The new dimension sizes (from dimension 0 to n-1).
  repeated int64 new_sizes = 4;
}

message TransposeRequest {
  ComputationDataHandle operand = 2;
  // The permutation of the operand's dimensions (in the range 0 to n-1).
  repeated int64 dimensions = 3;
}

message ParameterRequest {
  Shape shape = 2;
  int64 parameter = 3;
  string name = 4;
}

message GetLocalShapeRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
}

message GetLocalShapeResponse {
  Shape shape = 1;
}

message TraceRequest {
  string tag = 2;
  ComputationDataHandle operand = 3;
}

message ConvertRequest {
  ComputationDataHandle operand = 2;
  PrimitiveType new_element_type = 3;
}

message ConcatenateRequest {
  repeated ComputationDataHandle operands = 2;
  // The dimension in which we concatenate; e.g. if you had dimension arrays
  // of [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a
  // [9, 1]. Attempting to concatenate those in dimension 1 would produce an
  // error, as 4 != 5 (and there is no ragged array support).
  int64 dimension = 3;
}

message ConditionalRequest {
  ComputationDataHandle predicate = 2;
  ComputationDataHandle true_operand = 3;
  ComputationHandle true_computation = 4;
  ComputationDataHandle false_operand = 5;
  ComputationHandle false_computation = 6;
}

message WhileRequest {
  ComputationHandle condition = 2;
  ComputationHandle body = 3;
  ComputationDataHandle init = 4;
}

enum UnaryOperation {
  UNOP_INVALID = 0;

  // Elementwise, logical negation on booleans and bitwise negation on ints.
  UNOP_NOT = 1;

  // Elementwise, computes e^x.
  UNOP_EXP = 2;

  // Elementwise, computes -x.
  UNOP_NEGATE = 3;

  // Puts the elements in the operand into sorted order.
  UNOP_SORT = 4;

  // Elementwise, computes tanh(x).
  UNOP_TANH = 5;

  // Elementwise, computes the natural logarithm of x.
  UNOP_LOG = 6;

  // Elementwise, computes the floor of x.
  UNOP_FLOOR = 7;

  // Elementwise, computes the ceil of x.
  UNOP_CEIL = 8;

  // Elementwise, computes the abs of x.
  UNOP_ABS = 9;

  // Elementwise, computes the sign of x.
  UNOP_SIGN = 10;

  // Elementwise, tests if values are finite (not NaN or inf).
  UNOP_IS_FINITE = 11;

  // Elementwise, computes the cosine of x.
  UNOP_COS = 12;

  // Elementwise, computes the sine of x.
  UNOP_SIN = 13;

  // Elementwise, rounds x to the nearest integral value, rounding half-way
  // cases away from zero.
  UNOP_ROUND_NEAREST_AFZ = 14;

  // Elementwise, extracts the real component of complex x.
  UNOP_REAL = 15;

  // Elementwise, extracts the imaginary component of complex x.
  UNOP_IMAG = 16;
}

message UnaryOpRequest {
  UnaryOperation unop = 2;
  ComputationDataHandle operand = 3;
}

enum BinaryOperation {
  BINOP_INVALID = 0;

  // Arithmetic operations.
  BINOP_ADD = 1;
  BINOP_DIV = 2;
  BINOP_MUL = 3;
  BINOP_SUB = 4;

  // Comparison operators.
  BINOP_EQ = 5;
  BINOP_GE = 6;
  BINOP_GT = 7;
  BINOP_LE = 8;
  BINOP_LT = 9;
  BINOP_NE = 10;

  // Element-wise maximum.
  BINOP_MAX = 14;

  // Element-wise minimum.
  BINOP_MIN = 15;

  // Raises the left-hand-side to the right-hand-side power.
  BINOP_POW = 16;

  // Remainder operation.
  BINOP_REM = 17;

  // Element-wise, logical operators on booleans and bitwise operators on
  // ints.
  BINOP_AND = 18;
  BINOP_OR = 19;
  BINOP_SHIFT_LEFT = 20;
  BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
  BINOP_SHIFT_RIGHT_LOGICAL = 22;

  // Complex from real, imag.
  BINOP_COMPLEX = 23;

  // Computes the 4-quadrant arctangent of the y, x input arguments.
  BINOP_ATAN2 = 24;
}

message BinaryOpRequest {
  BinaryOperation binop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  repeated int64 broadcast_dimensions = 5;
}

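// Illustrative example (assumed): adding an f32[3] vector to an f32[2, 3]
// matrix uses broadcast_dimensions = [1], mapping the vector's single
// dimension to dimension 1 of the matrix so the vector is added to every
// row. For operands of equal rank, broadcast_dimensions is left empty.
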
enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message RngRequest {
  RandomDistribution distribution = 2;
  repeated ComputationDataHandle parameter = 3;
  Shape shape = 4;
}

enum TernaryOperation {
  TRIOP_INVALID = 0;

  // Given a predicate and two operands, selects operand0 if the predicate is
  // true and operand1 if the predicate is false.
  TRIOP_SELECT = 1;

  // Given min, max, and an operand, returns the operand clamped to
  // [min, max]: min if the operand is less than min, max if the operand is
  // greater than max, and otherwise the operand itself.
  TRIOP_CLAMP = 3;
}

message TernaryOpRequest {
  TernaryOperation triop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  ComputationDataHandle ehs = 5;
}

enum VariadicOperation {
  VAROP_INVALID = 0;

  // Creates a tuple from its operands.
  VAROP_TUPLE = 1;
}

message VariadicOpRequest {
  VariadicOperation varop = 2;
  repeated ComputationDataHandle operands = 3;
}

message ReducePrecisionRequest {
  ComputationDataHandle operand = 1;
  int32 exponent_bits = 2;
  int32 mantissa_bits = 3;
}

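// Illustrative example (assumed): exponent_bits = 8 and mantissa_bits = 7
// rounds an F32 operand to the precision of BF16 while keeping F32 as the
// storage type, a common way to experiment with reduced-precision training.
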
message SendRequest {
  ComputationDataHandle operand = 1;
  ChannelHandle channel_handle = 2;
}

message RecvRequest {
  Shape shape = 1;
  ChannelHandle channel_handle = 2;
}

message GatherRequest {
  ComputationDataHandle input = 1;
  ComputationDataHandle gather_indices = 2;
  GatherDimensionNumbers dimension_numbers = 3;
  repeated int64 window_bounds = 4;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal; all
    // other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;

  // The shape of the sharded tile.
  Shape tile_shape = 2;

  // The shape of the tile assignment tensor - this must have the same rank as
  // tile_shape, and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;

  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;

  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple
  // shape, in pre-order. The tuple shape could be nested; here we store just
  // a flattened list of all leaves in the tuple shape. Note that the tuple
  // shape is not stored here; shardings do not store the shapes to which they
  // are applied, this is inferred from the instruction this sharding gets
  // attached to.
  repeated OpSharding tuple_shardings = 5;
}

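// Illustrative example (values assumed): tiling an f32[4, 8] operand across
// two devices by splitting dimension 1 in half could be expressed as
//   type: OTHER
//   tile_shape { element_type: F32 dimensions: [4, 4] }
//   tile_assignment_dimensions: [1, 2]
//   tile_assignment_devices: [0, 1]
// i.e. a 1x2 grid of f32[4, 4] tiles assigned to devices 0 and 1.
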
message OpRequest {
  ComputationHandle computation = 1;
  OpMetadata metadata = 33;
  OpSharding sharding = 40;

  oneof op {
    BinaryOpRequest binary_op_request = 2;
    BroadcastRequest broadcast_request = 3;
    CallRequest call_request = 4;
    ConcatenateRequest concatenate_request = 5;
    ConstantRequest constant_request = 6;
    ConvertRequest convert_request = 7;
    ConvolveRequest convolve_request = 8;
    CrossReplicaSumRequest cross_replica_sum_request = 9;
    CustomCallRequest custom_call_request = 10;
    DotRequest dot_request = 43;
    DynamicSliceRequest dynamic_slice_request = 11;
    DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
    GetTupleElementRequest get_tuple_element_request = 13;
    InfeedRequest infeed_request = 14;
    MapRequest map_request = 15;
    PadRequest pad_request = 16;
    ParameterRequest parameter_request = 17;
    ReducePrecisionRequest reduce_precision_request = 36;
    ReduceRequest reduce_request = 18;
    ReduceWindowRequest reduce_window_request = 19;
    ReshapeRequest reshape_request = 20;
    ReverseRequest reverse_request = 21;
    RngRequest rng_request = 22;
    SelectAndScatterRequest select_and_scatter_request = 23;
    SliceRequest slice_request = 24;
    TernaryOpRequest ternary_op_request = 25;
    TraceRequest trace_request = 26;
    TransposeRequest transpose_request = 34;
    UnaryOpRequest unary_op_request = 27;
    VariadicOpRequest variadic_op_request = 28;
    WhileRequest while_request = 29;
    SendRequest send_request = 30;
    RecvRequest recv_request = 31;
    OutfeedRequest outfeed_request = 32;
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    FftRequest fft_request = 41;
    ConvertRequest bitcast_convert_request = 42;
    ConditionalRequest conditional_request = 44;
    HostComputeRequest host_compute_request = 45;
    GatherRequest gather_request = 46;
    // Next: 47
  }
}

message OpResponse {
  ComputationDataHandle output = 1;
}