Skip to content

Commit

Permalink
Transpose weights for hybrid quantized convolution
Browse files Browse the repository at this point in the history
Factored out a function for matching the input and kernel, and a function for transposing the weight values, so that the implementation is shared between SRQ and weight-only convolution.

PiperOrigin-RevId: 620808162
  • Loading branch information
doyeonkim0 authored and tensorflower-gardener committed Apr 5, 2024
1 parent 44aef7e commit 3861c03
Show file tree
Hide file tree
Showing 2 changed files with 327 additions and 149 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1563,20 +1563,60 @@ func.func @dot_general_hybrid(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x5xf32>

// -----

// Tests that a hybrid quantized convolution is splitted into dequantize and
// float convolution.
// Tests that a hybrid per-channel quantized convolution for tfl.conv_2d is
// split into dequantize and float stablehlo.convolution.

// CHECK-LABEL: func @convolution_hybrid
// CHECK-LABEL: func @convolution_hybrid_per_channel
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
func.func @convolution_hybrid(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
func.func @convolution_hybrid_per_channel(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
// Weight: constant i8 values, per-channel quantized on dimension 3 (the 'o'
// kernel dimension below) with scales 2e+2 and 3e+3.
%0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform<i8:f32:3, {2.000000e+2, 3.000000e+3}>>
// Hybrid op: float activation (%arg0) convolved with the quantized weight;
// kernel layout is [0, 1, i, o]. The pass is expected to rewrite this into
// tfl.dequantize + float stablehlo.convolution (see CHECK lines below).
%1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform<i8:f32:3, {2.000000e+2, 3.000000e+3}>>) -> tensor<1x3x3x2xf32>
return %1 : tensor<1x3x3x2xf32>
}

// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<3x3x4x2x!quant.uniform<i8:f32:3, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<3x3x4x2xi8>}
// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<3x3x4x2x!quant.uniform<i8:f32:3, {2.000000e+02,3.000000e+03}>>) -> tensor<3x3x4x2xf32>
// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform<i8<-127:127>:f32:0, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<2x3x3x4xi8>}
// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<2x3x3x4x!quant.uniform<i8<-127:127>:f32:0, {2.000000e+02,3.000000e+03}>>) -> tensor<2x3x3x4xf32>
// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<3x3x4x2xf32>) -> tensor<1x3x3x2xf32>
// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<2x3x3x4xf32>) -> tensor<1x3x3x2xf32>
// CHECK: return %[[CONV]]

// -----

// Tests that a hybrid per-tensor quantized convolution for tfl.conv_2d is
// split into dequantize and float stablehlo.convolution.

// CHECK-LABEL: func @convolution_hybrid_per_tensor
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
func.func @convolution_hybrid_per_tensor(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x2xf32> {
// Weight: constant i8 values, per-tensor quantized with scale 3e-1 and
// zero point -5 (single scale, unlike the per-channel case above).
%0 = stablehlo.constant() {value = dense<3> : tensor<3x3x4x2xi8>} : () -> tensor<3x3x4x2x!quant.uniform<i8:f32, 3.000000e-01:-5>>
// Hybrid op: float activation (%arg0) convolved with the quantized weight;
// kernel layout is [0, 1, i, o]. Expected to be rewritten into
// tfl.dequantize + float stablehlo.convolution (see CHECK lines below).
%1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x2x!quant.uniform<i8:f32, 3.000000e-01:-5>>) -> tensor<1x3x3x2xf32>
return %1 : tensor<1x3x3x2xf32>
}

// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<2x3x3x4x!quant.uniform<i8:f32, 3.000000e-01:-5>>, value = dense<3> : tensor<2x3x3x4xi8>}
// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<2x3x3x4x!quant.uniform<i8:f32, 3.000000e-01:-5>>) -> tensor<2x3x3x4xf32>
// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[o, 0, 1, i]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64}
// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<2x3x3x4xf32>) -> tensor<1x3x3x2xf32>
// CHECK: return %[[CONV]]

// -----

// Tests that a hybrid per-channel quantized convolution for tfl.depthwise_conv
// is split into dequantize and float stablehlo.convolution.

// CHECK-LABEL: func @depthwise_convolution_hybrid_per_channel
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x3x4xf32>
func.func @depthwise_convolution_hybrid_per_channel(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> {
// Weight: constant i8 values, per-channel quantized on dimension 3 with
// scales 2e+2 and 3e+3; kernel 'i' dimension is 1 (depthwise-shaped kernel).
%0 = stablehlo.constant() {value = dense<3> : tensor<3x3x1x4xi8>} : () -> tensor<3x3x1x4x!quant.uniform<i8:f32:3, {2.000000e+2, 3.000000e+3}>>
// Hybrid op with feature_group_count = 4, matching the 4 input features
// (the depthwise case). Expected to be rewritten into tfl.dequantize +
// float stablehlo.convolution (see CHECK lines below).
%1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x1x4x!quant.uniform<i8:f32:3, {2.000000e+2, 3.000000e+3}>>) -> tensor<1x3x3x4xf32>
return %1 : tensor<1x3x3x4xf32>
}

// CHECK: %[[WEIGHT:.+]] = "tfl.pseudo_qconst"() {qtype = tensor<1x3x3x4x!quant.uniform<i8<-127:127>:f32:3, {2.000000e+02,3.000000e+03}>>, value = dense<3> : tensor<1x3x3x4xi8>}
// CHECK: %[[DQ:.+]] = "tfl.dequantize"(%[[WEIGHT]]) : (tensor<1x3x3x4x!quant.uniform<i8<-127:127>:f32:3, {2.000000e+02,3.000000e+03}>>) -> tensor<1x3x3x4xf32>
// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG0]], %[[DQ]])
// CHECK{LITERAL}: dim_numbers = [b, 0, 1, f]x[i, 0, 1, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64}
// CHECK-SAME: (tensor<1x3x3x4xf32>, tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32>
// CHECK: return %[[CONV]]

0 comments on commit 3861c03

Please sign in to comment.