Migrate StableHLO to use Properties
PiperOrigin-RevId: 626126640
GleasonK authored and tensorflower-gardener committed Apr 18, 2024
1 parent 77c4dca commit 5420c31
Showing 15 changed files with 2,482 additions and 246 deletions.
@@ -27,7 +27,7 @@ module {
     %20 = call @uniform_dequantize_0(%19, %5, %6) : (tensor<1x3x3x4xi8>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xi8>) -> tensor<1x3x3x4xf32>
     return %20 : tensor<1x3x3x4xf32>
   }
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<1> : tensor<3x3x4x4xi8>} : () -> tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{{.*}}}>>
+// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<1> : tensor<3x3x4x4xi8>}> : () -> tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{{.*}}}>>
 // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x3x3x4xf32>) -> tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>
 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG]], %[[FILTER]]) {{.*}} : (tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>, tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{.*}}>>) -> tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>
 // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>) -> tensor<1x3x3x4xf32>
@@ -87,7 +87,7 @@ module {
     %18 = call @uniform_dequantize_0(%17, %5, %6) : (tensor<1x3x3x4xi8>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xi8>) -> tensor<1x3x3x4xf32>
     return %18 : tensor<1x3x3x4xf32>
   }
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<20> : tensor<3x3x4x4xi8>} : () -> tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{{.*}}}>>
+// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<20> : tensor<3x3x4x4xi8>}> : () -> tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{{.*}}}>>
 // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x3x3x4xf32>) -> tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>
 // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG]], %[[FILTER]]) {{.*}} : (tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>, tensor<3x3x4x4x!quant.uniform<i8:f32:3, {{.*}}>>) -> tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>
 // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x3x4x!quant.uniform<i8:f32, {{.*}}>>) -> tensor<1x3x3x4xf32>
@@ -182,7 +182,7 @@ module {
     return %17 : tensor<1x4x3xf32>
   }
 // Quantization dimension == 1 because it is the output feature dimension.
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
+// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
 // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x4x2xf32>) -> tensor<1x4x2x!quant.uniform<i8:f32, {{.*}}:1>>
 // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [2] x [0] : (tensor<1x4x2x!quant.uniform<i8:f32, {{.*}}>>, tensor<2x3x!quant.uniform<i8:f32:1, {{.*}}>>) -> tensor<1x4x3x!quant.uniform<i8:f32, {{.*}}:2>>
 // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x4x3x!quant.uniform<i8:f32, {{.*}}>>) -> tensor<1x4x3xf32>
@@ -238,7 +238,7 @@ module {
   }
 // Quantization dimension == 1 because it is the output feature dimension.
 // Quantized filter values (from f32 constant) are cast to i8.
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
+// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
 // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x4x2xf32>) -> tensor<1x4x2x!quant.uniform<i8:f32, {{.*}}:1>>
 // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [2] x [0] : (tensor<1x4x2x!quant.uniform<i8:f32, {{.*}}>>, tensor<2x3x!quant.uniform<i8:f32:1, {{.*}}>>) -> tensor<1x4x3x!quant.uniform<i8:f32, {{.*}}:2>>
 // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x4x3x!quant.uniform<i8:f32, {{.*}}>>) -> tensor<1x4x3xf32>
@@ -292,7 +292,7 @@ module {
     return %15 : tensor<1x3xf32>
   }
 // Quantization dimension == 1 because it is the output feature dimension.
-// CHECK: %[[FILTER:.*]] = stablehlo.constant() {value = dense<5> : tensor<2x3xi8>} : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
+// CHECK: %[[FILTER:.*]] = stablehlo.constant() <{value = dense<5> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform<i8:f32:1, {{{.*}}}>>
 // CHECK: %[[QUANT_ARG:.*]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform<i8:f32, {{.*}}:1>>
 // CHECK: %[[CONV:.*]] = stablehlo.dot_general %[[QUANT_ARG]], %[[FILTER]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform<i8:f32, {{.*}}>>, tensor<2x3x!quant.uniform<i8:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}:2>>
 // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] : (tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>) -> tensor<1x3xf32>
@@ -1335,12 +1335,13 @@ class RewriteQuantizedConvolutionOp
 
   // Returns the stride amount for the height and width, respectively.
   std::pair<int64_t, int64_t> GetStrides(stablehlo::ConvolutionOp op) const {
-    DenseI64ArrayAttr window_strides_attr = op.getWindowStridesAttr();
-    if (!window_strides_attr) {
+    std::optional<ArrayRef<int64_t>> window_strides_attr =
+        op.getWindowStrides();
+    if (!window_strides_attr.has_value()) {
       return {1, 1};  // Default values.
     }
 
-    auto window_strides_attr_value = window_strides_attr.asArrayRef();
+    auto window_strides_attr_value = window_strides_attr.value();
     // It is guaranteed from the spec that it has two values:
     // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution.
     return {window_strides_attr_value[0], window_strides_attr_value[1]};
@@ -1349,12 +1350,12 @@ class RewriteQuantizedConvolutionOp
   // Returns the dilation amount for the height and width, respectively.
   std::pair<int64_t, int64_t> GetDilationFactors(
       stablehlo::ConvolutionOp op) const {
-    DenseI64ArrayAttr lhs_dilation_attr = op.getLhsDilationAttr();
-    if (!lhs_dilation_attr) {
+    std::optional<ArrayRef<int64_t>> lhs_dilation_attr = op.getLhsDilation();
+    if (!lhs_dilation_attr.has_value()) {
       return {1, 1};  // Default values.
     }
 
-    auto lhs_dilation_attr_value = lhs_dilation_attr.asArrayRef();
+    auto lhs_dilation_attr_value = lhs_dilation_attr.value();
     // It is guaranteed from the spec that it has two values:
     // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#convolution.
     return {lhs_dilation_attr_value[0], lhs_dilation_attr_value[1]};
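
For context on the two hunks above: with MLIR Properties, the generated accessors for optional array attributes return std::optional<ArrayRef<int64_t>> instead of a possibly-null DenseI64ArrayAttr, so callers test has_value() rather than comparing an attribute handle against null. Below is a minimal, self-contained C++ sketch of that calling pattern. The stand-in GetWindowStrides and its hard-coded values are hypothetical and exist only to make the sketch runnable; the real accessor is generated from the StableHLO op definition.

    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for op.getWindowStrides(): returns nullopt when
    // the optional attribute is absent.
    std::optional<std::vector<int64_t>> GetWindowStrides(bool present) {
      if (!present) return std::nullopt;
      return std::vector<int64_t>{2, 2};
    }

    // Mirrors the GetStrides logic in the hunk above: an absent attribute
    // means the spec default of stride 1 in both spatial dimensions.
    std::pair<int64_t, int64_t> GetStrides(bool present) {
      std::optional<std::vector<int64_t>> strides = GetWindowStrides(present);
      if (!strides.has_value()) {
        return {1, 1};  // Default values.
      }
      // The StableHLO convolution spec guarantees two values here.
      return {(*strides)[0], (*strides)[1]};
    }

    int main() {
      auto [h, w] = GetStrides(/*present=*/true);
      std::cout << h << "x" << w << "\n";  // Prints "2x2".
      auto [dh, dw] = GetStrides(/*present=*/false);
      std::cout << dh << "x" << dw << "\n";  // Prints "1x1".
    }
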
@@ -169,16 +169,16 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp
         op.getLoc(), new_result_type, transpose_op.getOperand(),
         /*init_value=*/op.getOperand(1),
         /*window_dimensions=*/
-        PermuteI64ArrayAttr(rewriter, op.getWindowDimensionsAttr(),
+        PermuteI64ArrayAttr(rewriter, op.getWindowDimensions(),
                             kNchwToNhwcPermutation),
         /*window_strides=*/
-        PermuteI64ArrayAttr(rewriter, op.getWindowStridesAttr(),
+        PermuteI64ArrayAttr(rewriter, op.getWindowStrides(),
                             kNchwToNhwcPermutation),
         /*base_dilations=*/
-        PermuteI64ArrayAttr(rewriter, op.getBaseDilationsAttr(),
+        PermuteI64ArrayAttr(rewriter, op.getBaseDilations(),
                             kNchwToNhwcPermutation),
         /*window_dilations=*/
-        PermuteI64ArrayAttr(rewriter, op.getWindowDilationsAttr(),
+        PermuteI64ArrayAttr(rewriter, op.getWindowDilations(),
                             kNchwToNhwcPermutation),
         /*padding=*/DenseIntElementsAttr(nullptr));
 
@@ -199,12 +199,13 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp
   // `array_attr` and `permutation` must be equal. Returns a null attribute
   // if `array_attr` is null.
   DenseI64ArrayAttr PermuteI64ArrayAttr(
-      PatternRewriter& rewriter, const DenseI64ArrayAttr array_attr,
+      PatternRewriter& rewriter,
+      const std::optional<ArrayRef<int64_t>> array_attr,
       const ArrayRef<int64_t> permutation) const {
-    if (array_attr == nullptr) return DenseI64ArrayAttr(nullptr);
+    if (!array_attr.has_value()) return DenseI64ArrayAttr(nullptr);
 
     return rewriter.getDenseI64ArrayAttr(
-        Permute<int64_t>(array_attr, permutation));
+        Permute<int64_t>(array_attr.value(), permutation));
   }
 
   LogicalResult MatchMaxPoolReduceWindowOp(
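
The PermuteI64ArrayAttr change above follows the same optional-accessor pattern, with the extra wrinkle that the permutation must be skipped entirely when the attribute is absent. A self-contained sketch follows; it assumes gather semantics for the project's Permute<int64_t> helper (out[i] = values[permutation[i]]) and uses {0, 2, 3, 1} as the NCHW-to-NHWC permutation — both are assumptions made for illustration, not the project's actual code.

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Assumed gather semantics: out[i] = values[permutation[i]].
    std::vector<int64_t> Permute(const std::vector<int64_t>& values,
                                 const std::vector<int64_t>& permutation) {
      std::vector<int64_t> out(permutation.size());
      for (size_t i = 0; i < permutation.size(); ++i) {
        out[i] = values[permutation[i]];
      }
      return out;
    }

    // Optional-aware wrapper mirroring PermuteI64ArrayAttr: an absent
    // attribute stays absent instead of being dereferenced.
    std::optional<std::vector<int64_t>> PermuteIfPresent(
        const std::optional<std::vector<int64_t>>& values,
        const std::vector<int64_t>& permutation) {
      if (!values.has_value()) return std::nullopt;
      return Permute(values.value(), permutation);
    }

    int main() {
      const std::vector<int64_t> kNchwToNhwc = {0, 2, 3, 1};
      // NCHW window dimensions {1, 1, 2, 2} map to NHWC {1, 2, 2, 1},
      // matching the array<i64: 1, 2, 2, 1> values in the tests below.
      auto permuted =
          PermuteIfPresent(std::vector<int64_t>{1, 1, 2, 2}, kNchwToNhwc);
      for (int64_t d : permuted.value()) std::cout << d << ' ';
      std::cout << '\n';  // Prints "1 2 2 1".
    }
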
@@ -47,7 +47,7 @@ func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1:
 }
 // CHECK-NO-UNPACK-LABEL: func.func @main_no_unpack
 // CHECK-NO-UNPACK-SAME: (%[[ARG_0:.+]]: tensor<1x1024xf32>) -> tensor<1x3xf32>
-// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() {value = dense<{{.*}}> : tensor<1024x3xi8>} : () -> tensor<1024x3x!quant.uniform<i8<-127:127>:f32:1, {{.*}}>>
+// CHECK-NO-UNPACK-DAG: %[[CONST:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1024x3xi8>}> : () -> tensor<1024x3x!quant.uniform<i8<-127:127>:f32:1, {{.*}}>>
 // CHECK-NO-UNPACK: %[[QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x1024xf32>) -> tensor<1x1024x!quant.uniform<i8:f32, {{.*}}>>
 // CHECK-NO-UNPACK: %[[DOT:.+]] = stablehlo.dot_general %[[QUANTIZE_0]], %[[CONST]]
 // CHECK-NO-UNPACK: %[[QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT]] : (tensor<1x3x!quant.uniform<i32:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform<i8:f32, {{.*}}>>
@@ -118,14 +118,14 @@ func.func @reduce_window_max_activation_transpose(%arg0: tensor<1x16x16x4xf32>)
 
 // Check that the body is not modified.
 // CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]])
+// CHECK: <{window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 2, 2, 1>}>
 // CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor<f32>, %[[REDUCE_ARG_1:.+]]: tensor<f32>):
 // CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]]
 // CHECK: stablehlo.return %[[MAX]]
 
 // Check that the attributes window_dimensions & window_strides are also
 // permutated to match the new input shape.
-// CHECK: {window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 2, 2, 1>}
-// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor<f32>) -> tensor<1x8x8x4xf32>
+// CHECK: (tensor<1x16x16x4xf32>, tensor<f32>) -> tensor<1x8x8x4xf32>
 
 // Check that a `stablehlo.transpose` is added to the result to match the shape
 // of the users.
@@ -162,15 +162,15 @@ func.func @reduce_window_max_activation_transpose_explicit_optional_attrs(
 
 // Check that the body is not modified.
 // CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]])
+// CHECK: <{base_dilations = array<i64: 1, 2, 2, 1>, window_dilations = array<i64: 1, 2, 2, 1>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 2, 2, 1>}>
 // CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor<f32>, %[[REDUCE_ARG_1:.+]]: tensor<f32>):
 // CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]]
 // CHECK: stablehlo.return %[[MAX]]
 
 // Check that the attributes window_dimensions & window_strides along with
 // optional attributes base_dilations and window_dilations are also permutated
 // to match the new input shape.
-// CHECK: {base_dilations = array<i64: 1, 2, 2, 1>, window_dilations = array<i64: 1, 2, 2, 1>, window_dimensions = array<i64: 1, 2, 2, 1>, window_strides = array<i64: 1, 2, 2, 1>}
-// CHECK-SAME: (tensor<1x16x16x4xf32>, tensor<f32>) -> tensor<1x15x15x4xf32>
+// CHECK: (tensor<1x16x16x4xf32>, tensor<f32>) -> tensor<1x15x15x4xf32>
 
 // Check that a `stablehlo.transpose` is added to the result to match the shape
 // of the users.
@@ -37,7 +37,7 @@ func.func @remove_volatile_qdq_with_requantization(%arg0: tensor<3x2xf32>) -> te
 // CHECK-LABEL: @quantize_constant
 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32>
 func.func @quantize_constant(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> {
-  // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() {value = dense<-78> : tensor<3x2xi8>} : () -> tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
+  // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() <{value = dense<-78> : tensor<3x2xi8>}> : () -> tensor<3x2x!quant.uniform<i8<-127:127>:f32, 5.000000e-03>>
   // CHECK-DAG: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]]
   // CHECK-NOT: "quantfork.qcast"
   // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[QCST]]
@@ -14,12 +14,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
 // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]])
 
 // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[CALL]], %[[Q0]])
+// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
+// CHECK-SAME: window_dimensions = array<i64: 1, 3, 3, 1>
 // CHECK: %[[ARG1:.*]]: tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>, %[[ARG2:.*]]: tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
 // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
 // CHECK: stablehlo.return %[[MAX]] : tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
-// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
-// CHECK-SAME: window_dimensions = array<i64: 1, 3, 3, 1>
-// CHECK-SAME: (tensor<2x3x1x3x!quant.uniform<i8:f32, 3.000000e-01:1>>, tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>) -> tensor<2x3x1x3x!quant.uniform<i8:f32, 3.000000e-01:1>>
+// CHECK: (tensor<2x3x1x3x!quant.uniform<i8:f32, 3.000000e-01:1>>, tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>) -> tensor<2x3x1x3x!quant.uniform<i8:f32, 3.000000e-01:1>>
 
 // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]])
 // CHECK: return %[[DQ]]
@@ -70,12 +70,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
 // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]])
 
 // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]])
+// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
+// CHECK-SAME: window_dimensions = array<i64: 1, 3, 3, 1>
 // CHECK: %[[ARG1:.*]]: tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>, %[[ARG2:.*]]: tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
 // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
 // CHECK: stablehlo.return %[[MAX]] : tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
-// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>
-// CHECK-SAME: window_dimensions = array<i64: 1, 3, 3, 1>
-// CHECK-SAME: (tensor<2x3x1x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>, tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>) -> tensor<2x3x1x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>
+// CHECK: (tensor<2x3x1x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>, tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>) -> tensor<2x3x1x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>
 
 // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[CST1]])
 // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[REDUCE]], %[[Q2]])
@@ -132,12 +132,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
 // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]]
 
 // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[RESHAPE]], %[[Q0]])
+// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
+// CHECK-SAME: window_dimensions = array<i64: 1, 3, 1>
 // CHECK: %[[ARG1:.*]]: tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>, %[[ARG2:.*]]: tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
 // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
 // CHECK: stablehlo.return %[[MAX]] : tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>
-// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
-// CHECK-SAME: window_dimensions = array<i64: 1, 3, 1>
-// CHECK-SAME: (tensor<2x3x3x!quant.uniform<i8:f32, 3.000000e-01:1>>, tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>) -> tensor<2x3x3x!quant.uniform<i8:f32, 3.000000e-01:1>>
+// CHECK: (tensor<2x3x3x!quant.uniform<i8:f32, 3.000000e-01:1>>, tensor<!quant.uniform<i8:f32, 3.000000e-01:1>>) -> tensor<2x3x3x!quant.uniform<i8:f32, 3.000000e-01:1>>
 
 // CHECK: %[[DQ:.*]] = "quantfork.dcast"(%[[REDUCE]])
 // CHECK: return %[[DQ]]
@@ -191,12 +191,12 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
 // CHECK: %[[Q1:.*]] = "quantfork.qcast"(%[[ARG0]])
 
 // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]])
+// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
+// CHECK-SAME: window_dimensions = array<i64: 1, 3, 1>
 // CHECK: %[[ARG1:.*]]: tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>, %[[ARG2:.*]]: tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
 // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
 // CHECK: stablehlo.return %[[MAX]] : tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>
-// CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>
-// CHECK-SAME: window_dimensions = array<i64: 1, 3, 1>
-// CHECK-SAME: (tensor<2x3x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>, tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>) -> tensor<2x3x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>
+// CHECK: (tensor<2x3x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>, tensor<!quant.uniform<i8:f32, 5.000000e-01:2>>) -> tensor<2x3x1024x!quant.uniform<i8:f32, 5.000000e-01:2>>
 
 // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[REDUCE]]
 // CHECK: %[[Q2:.*]] = "quantfork.qcast"(%[[CST1]])
