Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA][LHLO] Added support for SideEffects interfaces to LHLO operations. #38267

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
154 changes: 77 additions & 77 deletions tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,20 @@ def LHLO_BufferOrTuple : AnyTypeOf<[LHLO_Buffer, LHLO_TupleBuffer]>;
// XLA nullary op definitions.
//===----------------------------------------------------------------------===//

class LHLO_Op<string mnemonic, list<OpTrait> traits> : Op<LHLO_Dialect,
mnemonic, traits>;
class LHLO_Op<string mnemonic, list<OpTrait> traits> :
Op<LHLO_Dialect, mnemonic,
!listconcat([MemoryEffects<[MemRead, MemWrite]>], traits)>;

def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp {
let arguments = (ins
ElementsAttr:$value,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
let arguments = (ins I64Attr:$iota_dimension,
LHLO_Buffer:$output);
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

//===----------------------------------------------------------------------===//
Expand All @@ -75,17 +76,17 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {

class LHLO_UnaryElementwiseOp<string mnemonic> :
LHLO_Op<mnemonic, [SameTypeOperands]> {
let arguments = (ins LHLO_Buffer:$input,
LHLO_Buffer:$output);
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;

def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;

def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>, BASE_HLO_ConvertOp {
let arguments = (ins LHLO_Buffer:$input,
LHLO_Buffer:$output);
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}

def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cos">, BASE_HLO_CosOp;
Expand All @@ -111,9 +112,9 @@ def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
LHLO_Op<mnemonic, traits> {
let arguments = (ins
LHLO_Buffer:$lhs,
LHLO_Buffer:$rhs,
LHLO_Buffer:$out,
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
Expand Down Expand Up @@ -147,24 +148,23 @@ def LHLO_ReduceOp: LHLO_Op<"reduce", [
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_ReduceOp {
let arguments = (ins
Variadic<LHLO_BufferOrTuple>:$operands,
Variadic<LHLO_BufferOrTuple>:$init_values,
Variadic<LHLO_BufferOrTuple>:$out,
Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$operands,
Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$init_values,
Arg<Variadic<LHLO_BufferOrTuple>, "", [MemWrite]>:$out,
I64ElementsAttr:$dimensions
);

let regions = (region SizedRegion<1>:$body);
}

def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
NoSideEffect,
SingleBlockImplicitTerminator<"TerminatorOp">
]>, BASE_HLO_ReduceWindowOp {

let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$init_value,
LHLO_Buffer:$out,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
I64ElementsAttr:$window_dimensions,
// If strides or dilations attributes are missing then the default value is
// one for each of the input dimensions. Similarly, padding values are zero
Expand All @@ -184,23 +184,23 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [

def LHLO_GetTupleElementOp: LHLO_Op<"get_tuple_element", []>, BASE_HLO_GetTupleElementOp {
let arguments = (ins
LHLO_TupleBuffer:$input,
LHLO_BufferOrTuple:$out,
Arg<LHLO_TupleBuffer, "", [MemRead]>:$input,
Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$out,
I32Attr:$index
);
}

def LHLO_TupleOp : LHLO_Op<"tuple", []>, BASE_HLO_TupleOp {
let arguments = (ins
Variadic<LHLO_BufferOrTuple>:$val,
LHLO_TupleBuffer:$out);
Arg<Variadic<LHLO_BufferOrTuple>, "", [MemRead]>:$val,
Arg<LHLO_TupleBuffer, "", [MemWrite]>:$out);
}

def LHLO_CompareOp: LHLO_Op<"compare", []>, BASE_HLO_CompareOp {
let arguments = (ins
LHLO_Buffer:$lhs,
LHLO_Buffer:$rhs,
LHLO_PredBuffer:$out,
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions,
HLO_ComparisonDirectionAttr:$comparison_direction
);
Expand All @@ -214,8 +214,8 @@ def LHLO_SliceOp: LHLO_Op<
"slice",
[AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$output,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$start_indices,
I64ElementsAttr:$limit_indices,
I64ElementsAttr:$strides
Expand All @@ -224,10 +224,10 @@ def LHLO_SliceOp: LHLO_Op<

def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$update,
LHLO_Buffer:$output,
Variadic<LHLO_Buffer>:$start_indices
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$update,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$start_indices
);
}

Expand All @@ -239,12 +239,12 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp {

let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$scale,
LHLO_Buffer:$offset,
LHLO_Buffer:$mean,
LHLO_Buffer:$variance,
LHLO_Buffer:$output,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
Expand All @@ -253,99 +253,99 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$output,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$broadcast_sizes
);
}

def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
[]>, BASE_HLO_BroadcastInDimOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$output,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
BroadcastDimAttr:$broadcast_dimensions
);
}

def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
let arguments = (ins
LHLO_Buffer:$min,
LHLO_Buffer:$operand,
LHLO_Buffer:$max,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemRead]>:$min,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$max,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_ConcatenateOp : LHLO_Op<"concatenate", []>, BASE_HLO_ConcatenateOp {
let arguments = (ins
Variadic<LHLO_Buffer>:$val,
LHLO_Buffer:$output,
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$val,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64Attr:$dimension
);
}

def LHLO_ConvOp : LHLO_Op<"conv", []>, BASE_HLO_ConvOp {
let arguments = (ins
LHLO_Buffer:$lhs,
LHLO_Buffer:$rhs,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_CopyOp: LHLO_Op<"copy", []>, BASE_HLO_CopyOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_DotOp: LHLO_Op<"dot", []>, BASE_HLO_DotOp {
let arguments = (ins
LHLO_Buffer:$lhs,
LHLO_Buffer:$rhs,
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
HLO_PrecisionConfigAttr:$precision_config,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_GatherOp: LHLO_Op<"gather", []>, BASE_HLO_GatherOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_IntBuffer:$start_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_IntBuffer, "", [MemRead]>:$start_indices,
I64Attr:$index_vector_dim,
I64ElementsAttr:$offset_dims,
I64ElementsAttr:$slice_sizes,
I64ElementsAttr:$collapsed_slice_dims,
I64ElementsAttr:$start_index_map,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}


def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
let arguments = (ins
LHLO_PredBuffer:$pred,
LHLO_Buffer:$on_true,
LHLO_Buffer:$on_false,
LHLO_Buffer:$output
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter",
[NoSideEffect]>, BASE_HLO_SelectAndScatterOp {
def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter", []>,
BASE_HLO_SelectAndScatterOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$source,
LHLO_Buffer:$init_value,
LHLO_Buffer:$out,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$source,
Arg<LHLO_Buffer, "", [MemRead]>:$init_value,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
OptionalAttr<I64ElementsAttr>:$window_dimensions,
OptionalAttr<I64ElementsAttr>:$window_strides,
OptionalAttr<I64ElementsAttr>:$padding
Expand All @@ -356,28 +356,28 @@ def LHLO_SelectAndScatterOp: LHLO_Op<"select_and_scatter",

def LHLO_ReverseOp: LHLO_Op<"reverse", []>, BASE_HLO_ReverseOp {
let arguments = (ins
LHLO_Buffer:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$dimensions,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_PadOp: LHLO_Op<"pad", []>, BASE_HLO_PadOp {
let arguments = (ins
LHLO_Buffer:$operand,
LHLO_Buffer:$padding_value,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$padding_value,
I64ElementsAttr:$edge_padding_low,
I64ElementsAttr:$edge_padding_high,
I64ElementsAttr:$interior_padding,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
let arguments = (ins
LHLO_Buffer:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
I64ElementsAttr:$permutation,
LHLO_Buffer:$output
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}

Expand Down