Add xla_lhlo.dynamic_broadcast_in_dim operation.
Also change the type of the dynamic dimensions operand to a vector of integers, as the index type is not supported in vectors.

PiperOrigin-RevId: 295141631
Change-Id: Ie8b6d5adec65d70243a3b132ffc807cafd212b42
Stephan Herhut authored and tensorflower-gardener committed Feb 14, 2020
1 parent 33c5c0b commit 77deb92
Showing 5 changed files with 56 additions and 11 deletions.
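
As an orienting example, the sketch below shows how the new buffer-level op is exercised. It is adapted from the test this commit adds to lhlo_ops.mlir; the function name is illustrative, and the target shape is passed at runtime as a 1-D integer tensor rather than an index tensor.

// Sketch adapted from the new lhlo_ops.mlir test: broadcast a dynamically
// shaped 2-D buffer into a 3-D buffer; the output shape arrives at runtime
// as a tensor<3xi64> (integer rather than index elements).
func @dynamic_broadcast_in_dim_example(%operand: memref<?x?xi32>,
                                       %out: memref<?x?x?xi32>,
                                       %shape: tensor<3xi64>) -> () {
  "xla_lhlo.dynamic_broadcast_in_dim"(%operand, %shape, %out)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
  return
}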
6 changes: 1 addition & 5 deletions tensorflow/compiler/mlir/xla/ir/hlo_ops.td
@@ -60,10 +60,6 @@ def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;
 
 def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
 
-def HLO_DimensionTensor : ShapedContainerType<
-[Index], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
-"a 1D tensor of dimensions">;
-
 // In general, static shaped tensor constraints should be avoided unless
 // it is for a legacy op which is only correct with static shapes.
 def HLO_StaticShapeTensor : StaticShapeTensorOf<[
@@ -778,7 +774,7 @@ def HLO_DynamicBroadcastInDimOp : HLO_Op<"dynamic_broadcast_in_dim",
 [NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
 let arguments = (ins
 HLO_Tensor:$operand,
-HLO_DimensionTensor:$output_dimensions,
+HLO_BASE_DimensionTensor:$output_dimensions,
 BroadcastDimAttr:$broadcast_dimensions
 );
 
19 changes: 13 additions & 6 deletions tensorflow/compiler/mlir/xla/ir/hlo_ops_base.td
@@ -21,6 +21,19 @@ include "mlir/IR/OpBase.td"
 def HLO_Int : IntOfWidths<[8, 16, 32, 64]>;
 def HLO_Pred : TypeAlias<I1, "pred (AKA boolean or 1-bit integer)">;
 
+// The broadcasting dimensions correspond to a tuple that describes how a
+// smaller rank shape is broadcast into a larger rank shape. For example,
+// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
+// matching the matrix to dimensions 1 and 2 of the cuboid.
+def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
+
+// Dynamic representation of a shape vector as a tensor. Ideally this would be
+// an index type (as it stores indices) but that is currently disallowed in
+// MLIR.
+def HLO_BASE_DimensionTensor : ShapedContainerType<
+[AnyInteger], And<[IsTensorTypePred, HasAnyRankOfPred<[1]>]>,
+"a 1D tensor of dimensions">;
+
 //===----------------------------------------------------------------------===//
 // XLA nullary op definitions.
 //===----------------------------------------------------------------------===//
@@ -318,12 +331,6 @@ class BASE_HLO_TanhOp {
 // XLA binary elementwise op definitions.
 //===----------------------------------------------------------------------===//
 
-// The broadcasting dimensions correspond to a tuple that describes how a
-// smaller rank shape is broadcast into a larger rank shape. For example,
-// given a 2x3x4 cuboid and a 3x4 matrix, a broadcasting tuple (1,2) means
-// matching the matrix to dimensions 1 and 2 of the cuboid.
-def BroadcastDimAttr : OptionalAttr<I64ElementsAttr>;
-
 class BASE_HLO_AddOp {
 string summary = "Addition operator";
 
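To make the broadcast_dimensions comment above concrete, here is an illustrative sketch (hypothetical shapes and function name, using the attribute syntax from the tests further below): with broadcast_dimensions = [1, 2], a 3x4 matrix is matched to dimensions 1 and 2 of a 2x3x4 result.

// Illustrative only: the 3x4 operand fills dimensions 1 and 2 of the 2x3x4
// result, and the values are replicated along dimension 0.
func @broadcast_matrix_into_cuboid(%matrix: tensor<3x4xf32>) -> tensor<2x3x4xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%matrix)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (tensor<3x4xf32>) -> tensor<2x3x4xf32>
  return %0 : tensor<2x3x4xf32>
}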
10 changes: 10 additions & 0 deletions tensorflow/compiler/mlir/xla/ir/lhlo_ops.td
@@ -242,6 +242,16 @@ def LHLO_BroadcastInDimOp : LHLO_Op<"broadcast_in_dim",
 );
 }
 
+def HLO_DynamicBroadcastInDimOp : LHLO_Op<"dynamic_broadcast_in_dim",
+[NoSideEffect]>, BASE_HLO_DynamicBroadcastInDimOp {
+let arguments = (ins
+LHLO_Buffer:$operand,
+HLO_BASE_DimensionTensor:$output_dimensions,
+LHLO_Buffer:$output,
+BroadcastDimAttr:$broadcast_dimensions
+);
+}
+
 def LHLO_ClampOp : LHLO_Op<"clamp", []>, BASE_HLO_ClampOp {
 let arguments = (ins
 LHLO_Buffer:$min,
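For comparison with the buffer-level op defined above, the tensor-level xla_hlo form (whose operand type this commit also switches to HLO_BASE_DimensionTensor) returns its result as a new tensor instead of writing into an output buffer; the sketch below is adapted from the test added to ops.mlir further down.

// Tensor-level counterpart (adapted from the new ops.mlir test): the result
// is returned rather than written into an output buffer operand.
func @dynamic_broadcast_in_dim_hlo(%arg0: tensor<?x?xi32>,
                                   %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
  %0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape)
      {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>}
      : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
  return %0 : tensor<?x?x?xi32>
}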
24 changes: 24 additions & 0 deletions tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir
@@ -136,6 +136,30 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> ()
 
 // -----
 
+// CHECK-LABEL: func @broadcast_in_dim_memref
+func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
+"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
+return
+}
+
+// -----
+
+// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
+func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
+"xla_lhlo.broadcast_in_dim"(%arg0, %out) : (memref<i32>, memref<1x2x3xi32>) -> ()
+return
+}
+
+// -----
+
+// CHECK-LABEL: func @dynamic_broadcast_in_dim_memref
+func @dynamic_broadcast_in_dim_memref(%arg0: memref<?x?xi32>, %out: memref<?x?x?xi32>, %shape: tensor<3xi64>) -> () {
+"xla_lhlo.dynamic_broadcast_in_dim"(%arg0, %shape, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<?x?xi32>, tensor<3xi64>, memref<?x?x?xi32>) -> ()
+return
+}
+
+// -----
+
 // CHECK-LABEL: func @reduce_memref
 func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
 "xla_lhlo.reduce"(%input, %init, %out) ( {
8 changes: 8 additions & 0 deletions tensorflow/compiler/mlir/xla/tests/ops.mlir
@@ -108,6 +108,14 @@ func @broadcast_in_dim_zero_rank(%arg0: tensor<i32>) -> tensor<1x2x3xi32> {
 
 // -----
 
+// CHECK-LABEL: func @dynamic_broadcast_in_dim
+func @dynamic_broadcast_in_dim(%arg0: tensor<?x?xi32>, %shape: tensor<3xi64>) -> tensor<?x?x?xi32> {
+%0 = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %shape) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<?x?xi32>, tensor<3xi64>) -> tensor<?x?x?xi32>
+return %0 : tensor<?x?x?xi32>
+}
+
+// -----
+
 func @broadcast_in_dim_bad_dimension_rank(%arg0: tensor<1x2xi32>) -> tensor<1x2x3xi32> {
 // expected-error@+1 {{broadcast_dimensions has rank 2 instead of rank 1}}
 %0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[[1,1],[1,1]]> : tensor<2x2xi64>} : (tensor<1x2xi32>) -> tensor<1x2x3xi32>
