Added per channel kernels for depthwise conv. #37621

Closed
wants to merge 25 commits

Conversation

kimishpatel
Contributor

@kimishpatel commented Apr 30, 2020

Stack from ghstack:

Summary:
Due to potential perf issues with using the same depthwise conv kernels for per-channel depthwise conv, we opt for replicating the kernels and adding per-channel support to them.
Note that the large kernel files are largely duplicates of the original kernels. The main difference is that each iteration of the loop (over a group of output channels) must obtain the corresponding zero point and requantization scale; the rest of the compute is the same. The assembly kernels need a few more modifications than the intrinsics ones.

Test Plan:
qnnpack tests.
q8dwconv-test
convolution-test

Differential Revision: D21339042
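
To make that difference concrete, here is a minimal scalar sketch (my illustration, not the actual ukernel; the parameter names mirror the fields this PR touches) of the per-group loads the per-channel kernels add:

```c
#include <math.h>
#include <stddef.h>
#include <stdint.h>

/* Sketch: requantize accumulated outputs in groups of 8 channels.
 * The per-tensor kernels broadcast requantization_scales[0] once before
 * this loop; the per-channel kernels load 8 fresh scales per iteration at
 * index [channels - c], as in the diffs posted further down. */
static void per_channel_requant_sketch(
    size_t channels,                     /* assumed a multiple of 8 here */
    const int32_t* acc,                  /* [channels] accumulators */
    const float* requantization_scales,  /* [channels] */
    int32_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint8_t* out) {
  for (size_t c = channels; c >= 8; c -= 8) {
    const float* scales = &requantization_scales[channels - c];
    const int32_t* a = &acc[channels - c];
    uint8_t* o = &out[channels - c];
    for (size_t k = 0; k < 8; k++) {
      int32_t q =
          (int32_t)lrintf((float)a[k] * scales[k]) + output_zero_point;
      if (q < output_min) q = output_min;
      if (q > output_max) q = output_max;
      o[k] = (uint8_t)q;
    }
  }
}
```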

@dr-ci

dr-ci bot commented Apr 30, 2020

💊 CI failures summary and remediations

As of commit ec089d6 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed



@kimishpatel kimishpatel mentioned this pull request May 5, 2020
xuezhou1998 pushed a commit to xuezhou1998/new_pytorch that referenced this pull request May 9, 2020
ghstack-source-id: 99c70d9d5186a3e81189602948792205087b241d
Pull Request resolved: pytorch/pytorch#37621
@@ -226,6 +226,12 @@ struct conv_param_t {
ukernel_type = pytorch_qnnp_ukernel_type_conv;
}
}

if (per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
Contributor

The rules of operator precedence are unambiguous, but I personally err on the side of grouping with parentheses explicitly. This is a personal preference of course. Please feel free to ignore.
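
For concreteness, the explicitly grouped form of the condition above would read as below; behavior is unchanged, since `==` binds tighter than `&&`:

```c
if (per_channel && (ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm)) {
  /* same semantics as the ungrouped condition in the diff above */
}
```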

@@ -240,7 +242,7 @@ enum pytorch_qnnp_status pytorch_qnnp_create_convolution2d_nhwc_q8(
cr,
#if !PYTORCH_QNNPACK_RUNTIME_QUANTIZATION
input_zero_point,
kernel_zero_points[0],
kernel_zero_points.data(),
Contributor

Since the two are not equivalent, this is fixing a typo I assume, right?

Contributor Author

Yes. We never caught this because we don't build this configuration. All of this needs cleanup, which is on my todo list.

@@ -0,0 +1,939 @@
/*
Contributor

I assume these files are mostly the same as their per-tensor counterparts but with some changes? Could you mark those places to make the review easier? Also I'm tempted to say I wish the coding style was left uncompliant like the original since that was easier to follow. :)

Contributor Author

Also I'm tempted to say I wish the coding style was left uncompliant like the original since that was easier to follow. :)

Not sure if I follow what you mean.

I will simply paste the diff between the files in the comments section.

Contributor Author

By the way, as an aside: the current per-channel depthwise conv method requires us to maintain two pointers, one for the zero points and one for the requantization scales. This requires some stack manipulation to save those pointers. I did experiment with another approach where the zero points and requantization scales are packed with the weights. It did not make much difference for the gemm/conv kernels, but for per-channel depthwise I think it makes about a 12% difference for the aarch32 assembly kernel. Doing this only for depthwise requires some changes, though, which I think I will do in a later PR.
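
A rough sketch of what that packed layout could look like (purely illustrative; the field names and the 8-channel group size are my assumptions, not the actual experiment):

```c
#include <stdint.h>

/* Hypothetical layout: the quantization parameters travel with each group
 * of 8 channels inside the weight blob, so the kernel can walk a single
 * pointer instead of maintaining separate zero-point and requant-scale
 * pointers (and doing the stack saves/restores the aarch32 kernel needs
 * today). */
struct packed_dwconv_group_sketch {
  int32_t bias[8];                  /* per-channel bias */
  uint8_t kernel_zero_points[8];    /* per-channel zero points */
  float requantization_scales[8];   /* per-channel requantization scales */
  uint8_t weights[9][8];            /* 9 taps (3x3), 8 channels per group */
};
```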

Contributor

Very nice.

@AshkanAliabadi
Contributor

@supriyar Considering the importance of unit tests in catching issues in lower-level code like this, can you please review the reference implementation logic in the unit tests? That's an area you are more familiar with. Thanks!
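
For orientation, a per-channel reference for the quantized depthwise conv boils down to the scalar sketch below (my illustration under assumed names, not the actual test code):

```c
#include <math.h>
#include <stddef.h>
#include <stdint.h>

/* Scalar reference sketch for per-channel quantized depthwise conv: every
 * channel c uses its own kernel zero point and requantization scale, which
 * is what the new ukernels must reproduce. */
static void dwconv_per_channel_ref(
    size_t channels,
    size_t kernel_size,                  /* e.g. 9 for 3x3, 25 for 5x5 */
    const uint8_t** input,               /* input[p] = row for kernel tap p */
    const uint8_t* weights,              /* [kernel_size][channels] */
    const int32_t* bias,                 /* [channels] */
    uint8_t input_zero_point,
    const uint8_t* kernel_zero_points,   /* [channels] */
    const float* requantization_scales,  /* [channels] */
    int32_t output_zero_point,
    uint8_t output_min,
    uint8_t output_max,
    uint8_t* output) {
  for (size_t c = 0; c < channels; c++) {
    int32_t acc = bias[c];
    for (size_t p = 0; p < kernel_size; p++) {
      acc += ((int32_t)input[p][c] - (int32_t)input_zero_point) *
             ((int32_t)weights[p * channels + c] -
              (int32_t)kernel_zero_points[c]);
    }
    int32_t q = (int32_t)lrintf((float)acc * requantization_scales[c]) +
                output_zero_point;
    if (q < output_min) q = output_min;
    if (q > output_max) q = output_max;
    output[c] = (uint8_t)q;
  }
}
```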

@kimishpatel
Contributor Author

So here is the list of diff contents.

@kimishpatel
Contributor Author

For mp8x25-neon-per-channel kernels:

--- src/q8dwconv/mp8x25-neon.c	2020-05-15 10:38:21.149858349 -0700
+++ src/q8dwconv/mp8x25-neon-per-channel.c	2020-05-15 11:03:26.415501529 -0700
@@ -10,7 +10,7 @@
 
 #include <qnnpack/q8dwconv.h>
 
-void pytorch_q8dwconv_ukernel_mp8x25__neon(
+void pytorch_q8dwconv_ukernel_mp8x25_per_channel__neon(
     size_t channels,
     size_t output_width,
     const uint8_t** input,
@@ -23,10 +23,6 @@
         quantization_params[restrict static 1]) {
   const uint8x8_t vinput_zero_point =
       vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
-  const uint8x8_t vkernel_zero_point =
-      vdup_n_u8(quantization_params->neon.kernel_zero_points[0]);
-  const float32x4_t requantization_scale_v =
-      vdupq_n_f32(quantization_params->neon.requantization_scales[0]);
   const int16x8_t voutput_zero_point =
       vld1q_dup_s16(&quantization_params->neon.output_zero_point);
   const uint8x8_t voutput_min = vld1_dup_u8(&quantization_params->neon.output_min);
@@ -54,6 +50,8 @@
 
       size_t c = channels;
       for (; c >= 8; c -= 8) {
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         int32x4_t vaccX1_lo = vld1q_s32(w);
         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
         int32x4_t vaccX1_hi = vld1q_s32(w);
@@ -210,6 +208,8 @@
         i8 -= c_predecrement;
         i9 -= c_predecrement;
 
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         int32x4_t vaccX1_lo = vld1q_s32(w);
         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
         int32x4_t vaccX1_hi = vld1q_s32(w);
@@ -369,6 +369,8 @@
 
       size_t c = channels;
       for (; c >= 8; c -= 8) {
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         const uint8x8_t vk0 = vld1_u8(w);
         w += 8;
         const uint8x8_t vi0 = vld1_u8(i0);
@@ -523,6 +525,8 @@
         i8 -= c_predecrement;
         i9 -= c_predecrement;
 
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         const uint8x8_t vk0 = vld1_u8(w);
         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
         const uint8x8_t vi0 = vreinterpret_u8_u64(
@@ -677,6 +681,8 @@
 
       size_t c = channels;
       for (; c >= 8; c -= 8) {
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         const uint8x8_t vk0 = vld1_u8(w);
         w += 8;
         const uint8x8_t vi0 = vld1_u8(i0);
@@ -750,10 +756,15 @@
         vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
         vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
 
+        const float32x4_t requantization_scale_v_lo =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+        const float32x4_t requantization_scale_v_hi =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
         const float32x4_t vacc_lo_f =
-          vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
+          vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
         const float32x4_t vacc_hi_f =
-          vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
+          vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
 
 #ifdef __aarch64__
         vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
@@ -792,6 +803,8 @@
         i3 -= c_predecrement;
         i4 -= c_predecrement;
 
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         const uint8x8_t vk0 = vld1_u8(w);
         w = (void*)((uintptr_t)w + sizeof(uint8x8_t));
         const uint8x8_t vi0 = vreinterpret_u8_u64(
@@ -863,10 +876,15 @@
         vacc_lo = vaddq_s32(vacc_lo, vacc_lo_old);
         vacc_hi = vaddq_s32(vacc_hi, vacc_hi_old);
 
+        const float32x4_t requantization_scale_v_lo =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+        const float32x4_t requantization_scale_v_hi =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
         const float32x4_t vacc_lo_f =
-          vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
+          vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
         const float32x4_t vacc_hi_f =
-          vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
+          vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
 
 #ifdef __aarch64__
         vacc_lo = vcvtnq_s32_f32(vacc_lo_f);

@kimishpatel
Contributor Author

For mp8x25-sse2-per-channel kernel:

--- mp8x25-sse2.c	2020-05-15 10:38:21.149858349 -0700
+++ mp8x25-sse2-per-channel.c	2020-05-15 11:03:26.416501529 -0700
@@ -10,7 +10,7 @@
 
 #include <qnnpack/q8dwconv.h>
 
-void pytorch_q8dwconv_ukernel_mp8x25__sse2(
+void pytorch_q8dwconv_ukernel_mp8x25_per_channel__sse2(
     size_t channels,
     size_t output_width,
     const uint8_t** input,
@@ -23,8 +23,6 @@
         quantization_params[RESTRICT_STATIC 1]) {
   const __m128i vinput_zero_point = _mm_load_si128(
       (const __m128i*)quantization_params->sse2.input_zero_point);
-  const __m128i vkernel_zero_point = _mm_set1_epi16(
-      quantization_params->sse2.kernel_zero_points[0]);
   const __m128i vzero = _mm_setzero_si128();
 
   do {
@@ -46,6 +44,9 @@
       for (; c >= 8; c -= 8) {
         __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
         __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
 
         const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00);
         i00 += 8;
@@ -54,7 +55,9 @@
         const __m128i vk00 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         vacc_lo = _mm_add_epi32(
@@ -69,7 +72,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -84,7 +89,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -99,7 +106,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -114,7 +123,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -129,7 +140,9 @@
         const __m128i vk12 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
         const __m128i vxk12 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk12, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12);
         const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12);
         vacc_lo = _mm_add_epi32(
@@ -144,7 +157,9 @@
         const __m128i vk20 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
         const __m128i vxk20 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk20, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20);
         const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20);
         vacc_lo = _mm_add_epi32(
@@ -159,7 +174,9 @@
         const __m128i vk21 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
         const __m128i vxk21 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk21, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21);
         const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21);
         vacc_lo = _mm_add_epi32(
@@ -174,7 +191,9 @@
         const __m128i vk22 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
         const __m128i vxk22 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk22, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22);
         const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22);
         vacc_lo = _mm_add_epi32(
@@ -189,7 +208,9 @@
         const __m128i vk23 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 104));
         const __m128i vxk23 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk23, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23);
         const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23);
         vacc_lo = _mm_add_epi32(
@@ -206,6 +227,9 @@
       if (c != 0) {
         const size_t i_predecrement = 8 - c;
         const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
         i00 -= i_predecrement;
         i01 -= i_predecrement;
         i02 -= i_predecrement;
@@ -227,7 +251,9 @@
         const __m128i vk00 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         vacc_lo = _mm_add_epi32(
@@ -242,7 +268,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -257,7 +285,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -272,7 +302,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -287,7 +319,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -302,7 +336,9 @@
         const __m128i vk12 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
         const __m128i vxk12 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk12, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12);
         const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12);
         vacc_lo = _mm_add_epi32(
@@ -317,7 +353,9 @@
         const __m128i vk20 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
         const __m128i vxk20 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk20, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20);
         const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20);
         vacc_lo = _mm_add_epi32(
@@ -332,7 +370,9 @@
         const __m128i vk21 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
         const __m128i vxk21 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk21, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21);
         const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21);
         vacc_lo = _mm_add_epi32(
@@ -347,7 +387,9 @@
         const __m128i vk22 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
         const __m128i vxk22 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk22, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22);
         const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22);
         vacc_lo = _mm_add_epi32(
@@ -362,7 +404,9 @@
         const __m128i vk23 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 104));
         const __m128i vxk23 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk23, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23);
         const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23);
         vacc_lo = _mm_add_epi32(
@@ -393,12 +437,17 @@
       size_t c = channels;
       for (; c >= 8; c -= 8) {
         const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00);
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
         i00 += 8;
         const __m128i vxi00 =
             _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point);
         const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even);
@@ -411,7 +460,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -426,7 +477,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -441,7 +494,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -456,7 +511,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -471,7 +528,9 @@
         const __m128i vk12 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
         const __m128i vxk12 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk12, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12);
         const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12);
         vacc_lo = _mm_add_epi32(
@@ -486,7 +545,9 @@
         const __m128i vk20 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
         const __m128i vxk20 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk20, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20);
         const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20);
         vacc_lo = _mm_add_epi32(
@@ -501,7 +562,9 @@
         const __m128i vk21 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
         const __m128i vxk21 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk21, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21);
         const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21);
         vacc_lo = _mm_add_epi32(
@@ -516,7 +579,9 @@
         const __m128i vk22 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
         const __m128i vxk22 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk22, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22);
         const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22);
         vacc_lo = _mm_add_epi32(
@@ -531,7 +596,9 @@
         const __m128i vk23 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
         const __m128i vxk23 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk23, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23);
         const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23);
         vacc_lo = _mm_add_epi32(
@@ -551,6 +618,9 @@
       if (c != 0) {
         const size_t i_predecrement = 8 - c;
         const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
         i00 -= i_predecrement;
         i01 -= i_predecrement;
         i02 -= i_predecrement;
@@ -568,7 +638,9 @@
             _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point);
         const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even);
@@ -581,7 +653,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -596,7 +670,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -611,7 +687,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -626,7 +704,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -641,7 +721,9 @@
         const __m128i vk12 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
         const __m128i vxk12 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk12, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk12, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod12_odd = _mm_mullo_epi16(vxi12, vxk12);
         const __m128i vprod12_even = _mm_mulhi_epi16(vxi12, vxk12);
         vacc_lo = _mm_add_epi32(
@@ -656,7 +738,9 @@
         const __m128i vk20 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
         const __m128i vxk20 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk20, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk20, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod20_odd = _mm_mullo_epi16(vxi20, vxk20);
         const __m128i vprod20_even = _mm_mulhi_epi16(vxi20, vxk20);
         vacc_lo = _mm_add_epi32(
@@ -671,7 +755,9 @@
         const __m128i vk21 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
         const __m128i vxk21 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk21, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk21, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod21_odd = _mm_mullo_epi16(vxi21, vxk21);
         const __m128i vprod21_even = _mm_mulhi_epi16(vxi21, vxk21);
         vacc_lo = _mm_add_epi32(
@@ -686,7 +772,9 @@
         const __m128i vk22 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
         const __m128i vxk22 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk22, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk22, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod22_odd = _mm_mullo_epi16(vxi22, vxk22);
         const __m128i vprod22_even = _mm_mulhi_epi16(vxi22, vxk22);
         vacc_lo = _mm_add_epi32(
@@ -701,7 +789,9 @@
         const __m128i vk23 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
         const __m128i vxk23 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk23, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk23, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod23_odd = _mm_mullo_epi16(vxi23, vxk23);
         const __m128i vprod23_even = _mm_mulhi_epi16(vxi23, vxk23);
         vacc_lo = _mm_add_epi32(
@@ -730,12 +820,17 @@
       size_t c = channels;
       for (; c >= 8; c -= 8) {
         const __m128i vi00 = _mm_loadl_epi64((const __m128i*)i00);
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
         i00 += 8;
         const __m128i vxi00 =
             _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point);
         const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even);
@@ -748,7 +843,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -763,7 +860,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -778,7 +877,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -793,7 +894,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -808,19 +911,21 @@
             _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4)));
         outacc += 8;
 
-        const __m128 vmultiplier =
-            _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
+        const __m128 vmultiplier_lo =
+            _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c]);
+        const __m128 vmultiplier_hi =
+            _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c + 4]);
 
         vacc_lo = _mm_cvtps_epi32(
                       _mm_mul_ps(
                         _mm_cvtepi32_ps(vacc_lo),
-                        vmultiplier
+                        vmultiplier_lo
                         )
                       );
         vacc_hi = _mm_cvtps_epi32(
                       _mm_mul_ps(
                         _mm_cvtepi32_ps(vacc_hi),
-                        vmultiplier
+                        vmultiplier_hi
                         )
                       );
 
@@ -844,6 +949,9 @@
       if (c != 0) {
         const size_t i_predecrement = 8 - c;
         const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
+        const __m128i vkernel_zero_point = _mm_loadl_epi64(
+            (const __m128i*)
+            &quantization_params->sse2.kernel_zero_points[channels - c]);
         i00 -= i_predecrement;
         i01 -= i_predecrement;
         i02 -= i_predecrement;
@@ -856,7 +964,9 @@
             _mm_sub_epi16(_mm_unpacklo_epi8(vi00, vzero), vinput_zero_point);
         const __m128i vk00 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w));
         const __m128i vxk00 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk00, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk00, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod00_odd = _mm_mullo_epi16(vxi00, vxk00);
         const __m128i vprod00_even = _mm_mulhi_epi16(vxi00, vxk00);
         __m128i vacc_lo = _mm_unpacklo_epi16(vprod00_odd, vprod00_even);
@@ -869,7 +979,9 @@
         const __m128i vk01 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 8));
         const __m128i vxk01 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk01, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk01, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod01_odd = _mm_mullo_epi16(vxi01, vxk01);
         const __m128i vprod01_even = _mm_mulhi_epi16(vxi01, vxk01);
         vacc_lo = _mm_add_epi32(
@@ -884,7 +996,9 @@
         const __m128i vk02 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 16));
         const __m128i vxk02 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk02, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk02, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod02_odd = _mm_mullo_epi16(vxi02, vxk02);
         const __m128i vprod02_even = _mm_mulhi_epi16(vxi02, vxk02);
         vacc_lo = _mm_add_epi32(
@@ -899,7 +1013,9 @@
         const __m128i vk10 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 24));
         const __m128i vxk10 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk10, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk10, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod10_odd = _mm_mullo_epi16(vxi10, vxk10);
         const __m128i vprod10_even = _mm_mulhi_epi16(vxi10, vxk10);
         vacc_lo = _mm_add_epi32(
@@ -914,7 +1030,9 @@
         const __m128i vk11 =
             _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
         const __m128i vxk11 =
-            _mm_sub_epi16(_mm_unpacklo_epi8(vk11, vzero), vkernel_zero_point);
+            _mm_sub_epi16(
+                _mm_unpacklo_epi8(vk11, vzero),
+                _mm_unpacklo_epi8(vkernel_zero_point, vzero));
         const __m128i vprod11_odd = _mm_mullo_epi16(vxi11, vxk11);
         const __m128i vprod11_even = _mm_mulhi_epi16(vxi11, vxk11);
         vacc_lo = _mm_add_epi32(
@@ -927,19 +1045,21 @@
             _mm_add_epi32(vacc_hi, _mm_loadu_si128((__m128i*)(outacc + 4)));
         outacc += 8;
 
-        const __m128 vmultiplier =
-            _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
+        const __m128 vmultiplier_lo =
+            _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c]);
+        const __m128 vmultiplier_hi =
+            _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c + 4]);
 
         vacc_lo = _mm_cvtps_epi32(
                       _mm_mul_ps(
                         _mm_cvtepi32_ps(vacc_lo),
-                        vmultiplier
+                        vmultiplier_lo
                         )
                       );
         vacc_hi = _mm_cvtps_epi32(
                       _mm_mul_ps(
                         _mm_cvtepi32_ps(vacc_hi),
-                        vmultiplier
+                        vmultiplier_hi
                         )
                       );

@kimishpatel
Contributor Author

For up8x9-aarch32-neon-per-channel kernel:

--- up8x9-aarch32-neon.S	2020-05-15 10:38:21.150858349 -0700
+++ up8x9-aarch32-neon-per-channel.S	2020-05-15 11:03:26.416501529 -0700
@@ -11,7 +11,7 @@
 
 .syntax unified
 
-# void pytorch_q8dwconv_ukernel_up8x9__aarch32_neon(
+# void pytorch_q8dwconv_ukernel_up8x9_per_channel__aarch32_neon(
 #     size_t channels,
 #     size_t output_width,
 #     const uint8_t** input,
@@ -20,7 +20,7 @@
 #     size_t input_stride,
 #     size_t output_increment,
 #     const union pytorch_qnnp_conv_quantization_params quantization_params[restrict static 1])
-BEGIN_FUNCTION pytorch_q8dwconv_ukernel_up8x9__aarch32_neon
+BEGIN_FUNCTION pytorch_q8dwconv_ukernel_up8x9_per_channel__aarch32_neon
     .arm
 #ifndef __APPLE__
     .arch armv7-a
@@ -36,19 +36,24 @@
 
     STR r0, [sp, #-8]
     STR r3, [sp, #-4]
+    STR r1, [sp, #-12]
+    STR r2, [sp, #-16]
 
     # Load the address zero_point array.
-    # For depth wise kernels the array is of single element.
     LDR r5, [r12], 4
+    # Push the zero_point_array base pointer on stack
+    # We dont have enough registers to maintain
+    # base pointers. Thus we will have to do some pushes
+    # and pops.
+    # At sp #-20 we store updated/working copy pointers
+    # At sp #-28 we store orig pointers that can be reloaded
+    # for more output pixels
+    STR r5, [sp, #-28]
 
     # Load o:
     # - lr = o = output
     LDR lr, [sp, 100]
 
-    # Load kernel zero point:
-    # - d31 = vkernel_zero_point
-    VLD1.8 {d31[]}, [r5]
-
     # Load input zero point:
     # - d30 = vinput_zero_point
     VLD1.8 {d30[]}, [r12]
@@ -56,12 +61,14 @@
     # For depth wise kernels the array is of single element.
     # pre-index r12 = r12 + 4
     LDR r5, [r12, 4]!
+    # Push the requantization_scales base pointer on stack
+    # At sp #-24 we store updated/working copy pointers
+    # At sp #-32 we store orig pointers that can be reloaded
+    # for more output pixels
+    STR r5, [sp, #-32]
 
     # add 8 bytes to get to vfmax
     ADD r12, r12, 8
-    # Load requantization_scale:
-    # - q14 = d28:d29 = requantization_scale
-    VLD1.32 {d28[], d29[]}, [r5]
 
     # Load vfmax:
     # - q13 = d26:d27 = vfmax
@@ -83,15 +90,26 @@
     # on the stack and pop it back.
     VLD1.32 {d22[], d23[]}, [r12]
 
-    VSTR d22, [sp, #-16]
-    VSTR d23, [sp, #-24]
+    VSTR d22, [sp, #-40]
+    VSTR d23, [sp, #-48]
 
     .p2align 3
 0:
+    # Load original zero point base pointer
+    LDR r4, [sp, #-28]
+    # Load original requant scale base pointer
+    LDR r5, [sp, #-32]
+    # Load indirection pointer from stack
+    LDR r2, [sp, #-16]
     # Load input stride
     # - r3 = input_stride
     LDR r3, [sp, 104]
 
+    # Store original zero point to working copy
+    STR r4, [sp, #-20]
+    # Store original requant scale to working copy
+    STR r5, [sp, #-24]
+
     # Load c:
     # - r0 = c = channels
     LDR r0, [sp, #-8]
@@ -114,6 +132,7 @@
     # Increment input by input stride
     # - input = r2 := input + input_stride
     ADD r2, r2, r3
+    STR r2, [sp, #-16]
 
     # Load w:
     # - r3 = w = weights
@@ -128,9 +147,23 @@
     VLD1.8 {d4}, [r4]!
     VLD1.8 {d6}, [r3]!
 
+    # zero point array base address
+    LDR r1, [sp, #-20]
+    # requantization scale array base address
+    LDR r2, [sp, #-24]
+
     VLD1.8 {d8}, [r5]!
     VLD1.8 {d10}, [r3]!
 
+    # - d31 = vkernel_zero_point
+    VLD1.8 {d31}, [r1]!
+    # - q8 = d16:d17= requantization_scale_lo
+    VLD1.32 {d16, d17}, [r2]!
+    # - q14 = d28:d29 = requantization_scale_hi
+    VLD1.32 {d28, d29}, [r2]!
+    STR r1, [sp, #-20]
+    STR r2, [sp, #-24]
+
     SUB_ZERO_POINT q2, d4, d30
     VSUBL.U8 q3, d6, d31
 
@@ -209,7 +242,7 @@
     VCVT.F32.S32 q0, q0
     VCVT.F32.S32 q1, q1
 
-    VMUL.F32 q0, q0, q14
+    VMUL.F32 q0, q0, q8
     VMUL.F32 q1, q1, q14
 
     VMIN.F32 q0, q0, q13
@@ -236,6 +269,11 @@
     CMP r0, -8
     BEQ 5f
 
+    # zero point array base address
+    LDR r1, [sp, #-20]
+    # requantization scale array base address
+    LDR r2, [sp, #-24]
+
     ADD r4, r4, r0
     ADD r5, r5, r0
     ADD r6, r6, r0
@@ -246,6 +284,9 @@
     ADD r11, r11, r0
     ADD r12, r12, r0
 
+    # - d31 = vkernel_zero_point
+    VLD1.8 {d31}, [r1]
+
     LSL r0, r0, 3
     VDUP.32 d22, r0
 
@@ -346,16 +387,21 @@
     VMLAL.S16 q0, d16, d18
     VMLAL.S16 q1, d17, d19
 
+    # - q8 = d16:d17= requantization_scale_lo
+    VLD1.32 {d16, d17}, [r2]!
+    # - q14 = d28:d29 = requantization_scale_hi
+    VLD1.32 {d28, d29}, [r2]
+
     VMLAL.S16 q0, d4, d6
     VMLAL.S16 q1, d5, d7
 
-    VLDR.64 d22, [sp, #-16]
-    VLDR.64 d23, [sp, #-24]
+    VLDR.64 d22, [sp, #-40]
+    VLDR.64 d23, [sp, #-48]
 
     VCVT.F32.S32 q0, q0
     VCVT.F32.S32 q1, q1
 
-    VMUL.F32 q0, q0, q14
+    VMUL.F32 q0, q0, q8
     VMUL.F32 q1, q1, q14
 
     VMIN.F32 q0, q0, q13
@@ -392,12 +438,16 @@
     VST1.8 {d0[0]}, [lr]!
 
 5:
+    # Load output_width from stack
+    LDR r1, [sp, #-12]
     # Load output increment
     # - r3 = output_increment
     LDR r3, [sp, 108]
 
     # Decrement output width
     SUBS r1, r1, 1
+    # store output_width on stack
+    STR r1, [sp, #-12]
 
     # Increment output by output_increment
     ADD lr, lr, r3
@@ -407,7 +457,7 @@
 
     VPOP {d8-d15}
     POP {r4, r5, r6, r7, r8, r9, r10, r11, pc}
-END_FUNCTION pytorch_q8dwconv_ukernel_up8x9__aarch32_neon
+END_FUNCTION pytorch_q8dwconv_ukernel_up8x9_per_channel__aarch32_neon
 
 #ifdef __ELF__
 .section ".note.GNU-stack","",%progbits

@kimishpatel
Contributor Author

For up8x9-neon-per-channel kernel:

--- up8x9-neon.c	2020-05-15 10:38:21.150858349 -0700
+++ up8x9-neon-per-channel.c	2020-05-15 11:03:26.416501529 -0700
@@ -11,7 +11,7 @@
 #include <qnnpack/q8dwconv.h>
 #include <requantization/runtime-neon.h>
 
-void pytorch_q8dwconv_ukernel_up8x9__neon(
+void pytorch_q8dwconv_ukernel_up8x9_per_channel__neon(
     size_t channels,
     size_t output_width,
     const uint8_t** input,
@@ -23,10 +23,6 @@
         quantization_params[restrict static 1]) {
   const uint8x8_t va_zero_point =
       vld1_dup_u8((const uint8_t*)&quantization_params->neon.input_zero_point);
-  const uint8x8_t vkernel_zero_point =
-      vdup_n_u8(quantization_params->neon.kernel_zero_points[0]);
-  const float32x4_t requantization_scale_v =
-      vdupq_n_f32(quantization_params->neon.requantization_scales[0]);
   const int16x8_t voutput_zero_point =
       vld1q_dup_s16(&quantization_params->neon.output_zero_point);
   const uint8x8_t voutput_min =
@@ -43,6 +39,23 @@
    * pixels at a time */
   if (input_stride == 3 * sizeof(void*)) {
     for (; output_width >= 3; output_width -= 3) {
+      /*
+       * Following 15 values represent:
+       * -------------------------
+       *| 00 | 01 | 02 | 03 | 04 |
+       * -------------------------
+       *| 10 | 11 | 12 | 13 | 14 |
+       * -------------------------
+       *| 20 | 21 | 22 | 23 | 24 |
+       * -------------------------
+       *  Thus:
+       *  acc0 = 00 + 10 + 20 + 01 + 11 + 21 + 02 + 12 + 22
+       *  acc1 = 01 + 11 + 21 + 02 + 12 + 22 + 03 + 13 + 23
+       *  acc2 = 02 + 12 + 22 + 03 + 13 + 23 + 04 + 14 + 24
+       *
+       *  For channel wise:
+       *  We may have to do one fewer output pixel for per channel perhaps? Need to look at the perf.
+       */
       const uint8_t* i00 = input[0];
       const uint8_t* i10 = input[1];
       const uint8_t* i20 = input[2];
@@ -68,6 +81,8 @@
       size_t c = channels;
       const void* w = weights;
       for (; c >= 8; c -= 8) {
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         int32x4_t vacc0_lo = vld1q_s32(w);
         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
         int32x4_t vacc0_hi = vld1q_s32(w);
@@ -263,18 +278,23 @@
             vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
 
+        const float32x4_t requantization_scale_v_lo =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+        const float32x4_t requantization_scale_v_hi =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
         vacc0_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v_lo));
         vacc0_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v_hi));
         vacc1_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v_lo));
         vacc1_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v_hi));
         vacc2_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v_lo));
         vacc2_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v_hi));
 
         const int16x8_t vacc0 = vqaddq_s16(
             vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
@@ -321,6 +341,8 @@
         i14 -= c_predecrement;
         i24 -= c_predecrement;
 
+        const uint8x8_t vkernel_zero_point =
+            vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
         int32x4_t vacc0_lo = vld1q_s32(w);
         w = (void*)((uintptr_t)w + sizeof(int32x4_t));
         int32x4_t vacc0_hi = vld1q_s32(w);
@@ -516,18 +538,23 @@
             vmlal_s16(vacc2_lo, vget_low_s16(vxk22), vget_low_s16(vxi24));
         vacc2_hi = vmlal_high_s16(vacc2_hi, vxk22, vxi24);
 
+        const float32x4_t requantization_scale_v_lo =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+        const float32x4_t requantization_scale_v_hi =
+            vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
         vacc0_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc0_lo), requantization_scale_v_lo));
         vacc0_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc0_hi), requantization_scale_v_hi));
         vacc1_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc1_lo), requantization_scale_v_lo));
         vacc1_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc1_hi), requantization_scale_v_hi));
         vacc2_lo = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc2_lo), requantization_scale_v_lo));
         vacc2_hi = vcvtnq_s32_f32(
-            vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v));
+            vmulq_f32(vcvtq_f32_s32(vacc2_hi), requantization_scale_v_hi));
 
         const int16x8_t vacc0 = vqaddq_s16(
             vqmovn_high_s32(vqmovn_s32(vacc0_lo), vacc0_hi),
@@ -622,6 +649,8 @@
     size_t c = channels;
     const void* w = weights;
     for (; c >= 8; c -= 8) {
+      const uint8x8_t vkernel_zero_point =
+          vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
       int32x4_t vaccX1_lo = vld1q_s32(w);
       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
       int32x4_t vaccX1_hi = vld1q_s32(w);
@@ -737,10 +766,15 @@
       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
 
+      const float32x4_t requantization_scale_v_lo =
+          vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+      const float32x4_t requantization_scale_v_hi =
+          vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
       const float32x4_t vacc_lo_f =
-        vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
+        vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
       const float32x4_t vacc_hi_f =
-        vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
+        vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
 
 #ifdef __aarch64__
       vacc_lo = vcvtnq_s32_f32(vacc_lo_f);
@@ -783,6 +817,8 @@
       i7 -= c_predecrement;
       i8 -= c_predecrement;
 
+      const uint8x8_t vkernel_zero_point =
+          vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
       int32x4_t vaccX1_lo = vld1q_s32(w);
       w = (void*)((uintptr_t)w + sizeof(int32x4_t));
       int32x4_t vaccX1_hi = vld1q_s32(w);
@@ -897,10 +933,15 @@
       int32x4_t vacc_lo = vaddq_s32(vaccX0_lo, vaccX1_lo);
       int32x4_t vacc_hi = vaddq_s32(vaccX0_hi, vaccX1_hi);
 
+      const float32x4_t requantization_scale_v_lo =
+          vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
+      const float32x4_t requantization_scale_v_hi =
+          vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
+
       const float32x4_t vacc_lo_f =
-        vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v);
+        vmulq_f32(vcvtq_f32_s32(vacc_lo), requantization_scale_v_lo);
       const float32x4_t vacc_hi_f =
-        vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v);
+        vmulq_f32(vcvtq_f32_s32(vacc_hi), requantization_scale_v_hi);
 
 #ifdef __aarch64__
       vacc_lo = vcvtnq_s32_f32(vacc_lo_f);

@kimishpatel
Contributor Author

For up8x9-sse2-per-channel kernel:

--- up8x9-sse2.c	2020-05-15 10:38:21.150858349 -0700
+++ up8x9-sse2-per-channel.c	2020-05-15 11:03:26.416501529 -0700
@@ -11,7 +11,7 @@
 #include <qnnpack/q8dwconv.h>
 #include <requantization/runtime-sse2.h>
 
-void pytorch_q8dwconv_ukernel_up8x9__sse2(
+void pytorch_q8dwconv_ukernel_up8x9_per_channel__sse2(
     size_t channels,
     size_t output_width,
     const uint8_t** input,
@@ -23,8 +23,6 @@
         quantization_params[RESTRICT_STATIC 1]) {
   const __m128i va_zero_point = _mm_load_si128(
       (const __m128i*)quantization_params->sse2.input_zero_point);
-  const __m128i vkernel_zero_point = _mm_set1_epi16(
-      quantization_params->sse2.kernel_zero_points[0]);
   const __m128i vzero = _mm_setzero_si128();
 
   do {
@@ -45,6 +43,9 @@
     for (; c >= 8; c -= 8) {
       __m128i vacc_lo = _mm_loadu_si128((const __m128i*)w);
       __m128i vacc_hi = _mm_loadu_si128((const __m128i*)((uintptr_t)w + 16));
+      const __m128i vkernel_zero_point = _mm_loadl_epi64(
+          (const __m128i*)
+          &quantization_params->sse2.kernel_zero_points[channels - c]);
 
       const __m128i vi0 = _mm_loadl_epi64((const __m128i*)i0);
       i0 += 8;
@@ -52,7 +53,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
       const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
       const __m128i vxk0 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk0, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
       const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
       vacc_lo =
@@ -66,7 +69,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
       const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
       const __m128i vxk1 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk1, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
       const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
       vacc_lo =
@@ -80,7 +85,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
       const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
       const __m128i vxk2 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk2, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
       const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
       vacc_lo =
@@ -94,7 +101,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
       const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
       const __m128i vxk3 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk3, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
       const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
       vacc_lo =
@@ -108,7 +117,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
       const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
       const __m128i vxk4 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk4, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
       const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
       vacc_lo =
@@ -122,7 +133,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
       const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
       const __m128i vxk5 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk5, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
       const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
       vacc_lo =
@@ -136,7 +149,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
       const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
       const __m128i vxk6 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk6, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
       const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
       vacc_lo =
@@ -150,7 +165,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
       const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
       const __m128i vxk7 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk7, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
       const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
       vacc_lo =
@@ -164,7 +181,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
       const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
       const __m128i vxk8 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk8, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
       const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
       vacc_lo =
@@ -174,19 +193,21 @@
 
       w = (void*)((uintptr_t)w + 104);
 
-      const __m128 vmultiplier =
-          _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
+      const __m128 vmultiplier_lo =
+          _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c]);
+      const __m128 vmultiplier_hi =
+          _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c + 4]);
 
       vacc_lo = _mm_cvtps_epi32(
                     _mm_mul_ps(
                       _mm_cvtepi32_ps(vacc_lo),
-                      vmultiplier
+                      vmultiplier_lo
                       )
                     );
       vacc_hi = _mm_cvtps_epi32(
                     _mm_mul_ps(
                       _mm_cvtepi32_ps(vacc_hi),
-                      vmultiplier
+                      vmultiplier_hi
                       )
                     );
 
@@ -208,6 +229,9 @@
     if (c != 0) {
       const size_t i_predecrement = 8 - c;
       const __m128i vi_shift = _mm_cvtsi32_si128(8 * i_predecrement);
+      const __m128i vkernel_zero_point = _mm_loadl_epi64(
+          (const __m128i*)
+          &quantization_params->sse2.kernel_zero_points[channels - c]);
       i0 -= i_predecrement;
       i1 -= i_predecrement;
       i2 -= i_predecrement;
@@ -227,7 +251,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi0, vzero), va_zero_point);
       const __m128i vk0 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 32));
       const __m128i vxk0 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk0, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk0, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod0_odd = _mm_mullo_epi16(vxi0, vxk0);
       const __m128i vprod0_even = _mm_mulhi_epi16(vxi0, vxk0);
       vacc_lo =
@@ -241,7 +267,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi1, vzero), va_zero_point);
       const __m128i vk1 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 40));
       const __m128i vxk1 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk1, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk1, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod1_odd = _mm_mullo_epi16(vxi1, vxk1);
       const __m128i vprod1_even = _mm_mulhi_epi16(vxi1, vxk1);
       vacc_lo =
@@ -255,7 +283,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi2, vzero), va_zero_point);
       const __m128i vk2 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 48));
       const __m128i vxk2 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk2, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk2, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod2_odd = _mm_mullo_epi16(vxi2, vxk2);
       const __m128i vprod2_even = _mm_mulhi_epi16(vxi2, vxk2);
       vacc_lo =
@@ -269,7 +299,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi3, vzero), va_zero_point);
       const __m128i vk3 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 56));
       const __m128i vxk3 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk3, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk3, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod3_odd = _mm_mullo_epi16(vxi3, vxk3);
       const __m128i vprod3_even = _mm_mulhi_epi16(vxi3, vxk3);
       vacc_lo =
@@ -283,7 +315,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi4, vzero), va_zero_point);
       const __m128i vk4 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 64));
       const __m128i vxk4 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk4, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk4, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod4_odd = _mm_mullo_epi16(vxi4, vxk4);
       const __m128i vprod4_even = _mm_mulhi_epi16(vxi4, vxk4);
       vacc_lo =
@@ -297,7 +331,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi5, vzero), va_zero_point);
       const __m128i vk5 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 72));
       const __m128i vxk5 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk5, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk5, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod5_odd = _mm_mullo_epi16(vxi5, vxk5);
       const __m128i vprod5_even = _mm_mulhi_epi16(vxi5, vxk5);
       vacc_lo =
@@ -311,7 +347,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi6, vzero), va_zero_point);
       const __m128i vk6 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 80));
       const __m128i vxk6 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk6, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk6, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod6_odd = _mm_mullo_epi16(vxi6, vxk6);
       const __m128i vprod6_even = _mm_mulhi_epi16(vxi6, vxk6);
       vacc_lo =
@@ -325,7 +363,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi7, vzero), va_zero_point);
       const __m128i vk7 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 88));
       const __m128i vxk7 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk7, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk7, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod7_odd = _mm_mullo_epi16(vxi7, vxk7);
       const __m128i vprod7_even = _mm_mulhi_epi16(vxi7, vxk7);
       vacc_lo =
@@ -339,7 +379,9 @@
           sub_zero_point(_mm_unpacklo_epi8(vi8, vzero), va_zero_point);
       const __m128i vk8 = _mm_loadl_epi64((const __m128i*)((uintptr_t)w + 96));
       const __m128i vxk8 =
-          _mm_sub_epi16(_mm_unpacklo_epi8(vk8, vzero), vkernel_zero_point);
+          _mm_sub_epi16(
+              _mm_unpacklo_epi8(vk8, vzero),
+              _mm_unpacklo_epi8(vkernel_zero_point, vzero));
       const __m128i vprod8_odd = _mm_mullo_epi16(vxi8, vxk8);
       const __m128i vprod8_even = _mm_mulhi_epi16(vxi8, vxk8);
       vacc_lo =
@@ -347,19 +389,21 @@
       vacc_hi =
           _mm_add_epi32(vacc_hi, _mm_unpackhi_epi16(vprod8_odd, vprod8_even));
 
-      const __m128 vmultiplier =
-          _mm_set1_ps(quantization_params->sse2.requantization_scales[0]);
+      const __m128 vmultiplier_lo =
+          _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c]);
+      const __m128 vmultiplier_hi =
+          _mm_loadu_ps(&quantization_params->sse2.requantization_scales[channels - c + 4]);
 
       vacc_lo = _mm_cvtps_epi32(
                     _mm_mul_ps(
                       _mm_cvtepi32_ps(vacc_lo),
-                      vmultiplier
+                      vmultiplier_lo
                       )
                     );
       vacc_hi = _mm_cvtps_epi32(
                     _mm_mul_ps(
                       _mm_cvtepi32_ps(vacc_hi),
-                      vmultiplier
+                      vmultiplier_hi
                       )
                     );

Summary:
Due to potential perf issues with using the same depthwise conv kernels for per channel depthwise conv, we opt for replicating the kernels and adding per channel support to them.
Note that the large kernel files are largely a duplication of the original kernels. The main difference is that for each iteration of the loop (over a group of output channels) we need to obtain the corresponding zero points and requantization scales. The rest of the compute is the same.

Test Plan:
qnnpack tests.
q8dwconv-test
convolution-test

Differential Revision: [D21339042](https://our.internmc.facebook.com/intern/diff/D21339042)

[ghstack-poisoned]
.iterations(3)
.per_channel(true)
.testQ8(ConvolutionOperatorTester::Mode::Runtime);
}
Contributor

Should we test other input sizes too, apart from (15, 14)? Same comment for groups.

Contributor Author

Thanks for pointing it out. The original tests were similar in input size. I am not sure if it makes a big difference. I can try adding a few other sizes, but I tried to cover the same bases that the original tests did.
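
For reference, a hypothetical extra case might look like the sketch below. It assumes the tester exposes inputSize/kernelSize/groups setters the way the surrounding tests do; the specific sizes are illustrative and not part of this PR:

TEST(CONVOLUTION_OP, depthwise_3x3_per_channel_alt_input_size) {
  ConvolutionOperatorTester()
      .inputSize(27, 31) /* illustrative, different from the (15, 14) used elsewhere */
      .kernelSize(3, 3)
      .groups(27) /* depthwise: one group per channel */
      .iterations(3)
      .per_channel(true)
      .testQ8(ConvolutionOperatorTester::Mode::Runtime);
}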

@@ -205,14 +207,24 @@ class DWConvMicrokernelTester {

std::fill(packedWeights.begin(), packedWeights.end(), 0xA5);

size_t num_zero_points_padded = channels() + 8;
Contributor

Is the +8 for alignment?

Contributor Author

Since the access can otherwise go out of bounds in corner cases (the kernel always loads a full 8-lane vector).
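
In other words, the NEON kernel always loads 8 zero points at a time, e.g. vld1_u8(&kernel_zero_points[channels - c]); on the remainder path (c < 8) that load touches up to 7 elements past channels. A minimal sketch of the padded allocation, with illustrative names rather than the actual QNNPACK code:

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

/* Sketch only: over-allocate so the final 8-byte vector load of
 * kernel zero points stays in bounds even when channels % 8 != 0. */
uint8_t* alloc_kernel_zero_points(size_t channels, uint8_t zero_point) {
  const size_t padded = channels + 8; /* slack for the last vld1_u8 */
  uint8_t* buf = (uint8_t*)malloc(padded);
  if (buf == NULL) {
    return NULL;
  }
  /* Pad with a valid zero point so the extra lanes are harmless. */
  memset(buf, zero_point, padded);
  return buf;
}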


if (per_channel && ukernel_type == pytorch_qnnp_ukernel_type_xzp_gemm) {
pytorch_qnnp_log_error(
"Per channel quantized weights are not supported for XZP kernesl");
Contributor

kernesl -> kernels

size_t c = channels;
for (; c >= 8; c -= 8) {
const uint8x8_t vkernel_zero_point =
vld1_u8(&quantization_params->neon.kernel_zero_points[channels - c]);
Contributor

ARM has very good support for free indexing as part of its memory operations, as we have seen in the assembly, so I am not sure whether this will be a win, but one common practice in optimization is to aggressively factor loop invariants out of loops - in this case:

const uint8_t* const kernel_channel_zero_points = quantization_params->neon.kernel_zero_points + channels;

for ( ... ) {
  vld1_u8(kernel_channel_zero_points - c);
   ...
}

Contributor Author

@AshkanAliabadi so do you feel that the compiler cannot optimize this? I am assuming you don't actually end up accessing the params, then offsetting into them to get to kernel_zero_point, and then doing the indexing on it every time. If that were the case I can see that this would be inefficient, but things like this which are loop invariant, at least the base pointer kernel_zero_point, should be moved outside the loop. I don't have much experience with how efficient compilers are at such things, but that's what I am relying on.
And yes, without looking at the disassembly you wouldn't know whether the compiler is able to move it out or not.

Contributor Author

Anyway, if you prefer, I can make that change.

Contributor

I generally don't trust the compiler to do the right thing, but that might be an outdated line of thinking. I think the code is perfectly fine as is. Generally speaking, all my comments are suggestions of possible alternate implementations which may or may not be advantageous. Please feel free to pick and choose as you see fit.

Contributor Author

These are simple enough changes, so I made them anyway.

const float32x4_t requantization_scale_v_lo =
vld1q_f32(&quantization_params->neon.requantization_scales[channels - c]);
const float32x4_t requantization_scale_v_hi =
vld1q_f32(&quantization_params->neon.requantization_scales[channels - c + 4]);
Contributor

Likewise:

const float * const quantization_scales_channel = quantization_params->neon.requantization_scales + channels;

for (...) {
  const float * const ptr = quantization_scales_channel - c;
  vld1q_f32(ptr);
  vld1q_f32(ptr + 4);
}

Only looking at the disassembly can prove whether this is a win.

const __m128i vxk11 =
_mm_sub_epi16(
_mm_unpacklo_epi8(vk11, vzero),
_mm_unpacklo_epi8(vkernel_zero_point, vzero));
Contributor

The _mm_unpacklo_epi8(vkernel_zero_point, vzero)s are repeated several times below, it seems. Does anything change between each invocation? If not, you can consider factoring them out to the outermost loop at the point where the inputs to _mm_unpacklo_epi8 first become available.

Contributor Author

They are not loop invariant, since vkernel_zero_point is loaded every iteration, but yes, within the same iteration they don't need to be repeated. I am assuming that the compiler can at least do some common subexpression elimination. But honestly, since this is just SSE2 code, I did not worry about whether the compiler actually does this or not.
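
A sketch of what that hoisting could look like (SSE2 intrinsics; names mirror the diff, but this is illustrative rather than the committed code): widen the zero points u8 -> u16 once per 8-channel group, then reuse the result for all nine taps.

#include <emmintrin.h>
#include <stddef.h>
#include <stdint.h>

static inline __m128i widen_kernel_zero_point(
    const uint8_t* kernel_zero_points, size_t channels, size_t c) {
  const __m128i vzero = _mm_setzero_si128();
  const __m128i vkernel_zero_point = _mm_loadl_epi64(
      (const __m128i*)&kernel_zero_points[channels - c]);
  /* Computed once per channel group: the unpack does not change
   * between the nine kernel taps. */
  return _mm_unpacklo_epi8(vkernel_zero_point, vzero);
}

/* Each tap then becomes:
 *   const __m128i vxk0 = _mm_sub_epi16(
 *       _mm_unpacklo_epi8(vk0, vzero), vxkernel_zero_point);
 */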

STR r0, [sp, #-8]
STR r3, [sp, #-4]
STR r1, [sp, #-12]
STR r2, [sp, #-16]
Contributor

So you are pushing these four registers on the stack, right? Is there a reason you are not using PUSH?

Contributor Author

Mainly for this reason: there are other places in the code that do not use push/pop for stack manipulation. Parts of that code and parts of my modifications only do read-only accesses to some of those stack variables, so you don't really want to pop them off.
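
For illustration, a minimal sketch of the distinction (offsets here are illustrative): PUSH/POP move sp, so a popped slot is consumed, while STR/LDR at a fixed offset leave sp alone and the same slot can be re-read for every output pixel.

    PUSH {r4}            # sp -= 4; [sp] = r4
    POP  {r4}            # r4 = [sp]; sp += 4; the slot is consumed
    # versus:
    STR r4, [sp, #-28]   # stash the original pointer once
    LDR r4, [sp, #-28]   # reload it for this output pixel
    LDR r5, [sp, #-28]   # ...and reload it again later, read-only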

@kimishpatel
Contributor Author

@supriyar @AshkanAliabadi addressed your comments. @supriyar, I chose not to add more tests because I am not sure what we are trying to cover or uncover. Since the original tests are quite elaborate, I have kept only those. I will give this a little thought to see if there are any new corner cases that may have been introduced here that were not present before and hence are not covered.

@AshkanAliabadi AshkanAliabadi self-requested a review May 19, 2020 02:58
@AshkanAliabadi
Contributor

Sorry didn't mean to self-request a review. :) Meant to approve.

@facebook-github-bot
Contributor

@kimishpatel merged this pull request in 1c9a110.

@facebook-github-bot facebook-github-bot deleted the gh/kimishpatel/15/head branch May 24, 2020 14:15