From b3e4950e8e7b80d72423b3cde8a1f14b76c2b121 Mon Sep 17 00:00:00 2001 From: Adam Straw Date: Wed, 20 Oct 2021 08:14:07 -0700 Subject: [PATCH] Adjust Hexagon conv2d schedule to split channel out (k) and move to outer loop (#9287) * Adjust Hexagon conv2d schedule to split channel out (k) and move to outermost loop * add missing reference data verify --- tests/python/contrib/test_hexagon/README.md | 448 +++++++++++------- .../test_hexagon/test_conv2d_blocked.py | 86 +++- 2 files changed, 346 insertions(+), 188 deletions(-) diff --git a/tests/python/contrib/test_hexagon/README.md b/tests/python/contrib/test_hexagon/README.md index 1d6a298d48d6..a47c3438bf57 100644 --- a/tests/python/contrib/test_hexagon/README.md +++ b/tests/python/contrib/test_hexagon/README.md @@ -29,14 +29,14 @@ Documents manual TE schedule to illustrate Hexagon operator slicing. * Added spacing and line breaks * Naming conventions * Using input (instead of activation) - * Using kernel (instead of weight, filter) + * Using filter (instead of weight, kernel) * Using `k` to denote channel-out and `c` or `rc` (reduction channel) to denote channel-in - * Using `rh` and `rw` (reduction height / width) to denote kernel height and width + * Using `rh` and `rw` (reduction height / width) to denote filter height and width # Calling Convention TODO: Map this packed string to parameters -conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm +conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm # Baseline conv2d @@ -44,70 +44,80 @@ This is a baseline 1x1 conv2d schedule for Hexagon. ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-1-1-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 1x1 | +| Filter | 1x1 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | ## Assumptions -* Microkernels will compute "full depth" in channel-out (k) dimension. - * The compute schedule (see TIR below) - * Places the outer channel-out loop over `ko` inside the outer width loop over `wo` - * Encodes the assumption that Hexagon microkernels will compute "full depth" in the channel-out (k) dimension +* Pattern matching for microkernels is not senstive to cache reads and writes between the outer height (ho) and outer width (wo) loops. ## To Do -* Adjust compute schedule and add kernel cache read once Hexagon microkernel semantics are understood - +* n/a + ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), // NHWC8h8w32c - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending layout RFC) - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { allocate(input.cache: Pointer(global float32), float32, [32768]), storage_scope = global; - allocate(output.cache: Pointer(global float32), float32, [32768]), storage_scope = global; + allocate(filter.cache: Pointer(global float32), float32, [2048]), storage_scope = global; + allocate(output.cache: Pointer(global float32), float32, [16384]), storage_scope = global; + + for (ko.outer: int32, 0, 4) { + for (ho.outer: int32, 0, 8) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } - for (ho.outer: int32, 0, 8) { - // cache read - // NHWC -> NHWC8h8w32c (pending layout RFC) - for (wo: int32, 0, 8) { + // filter cache read for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[(((((wo*4096) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[((((((ho.outer*32768) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[((((co*1024) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[(((((ko.outer*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] } } } } - } - // compute - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - + // compute + for (wo.c: int32, 0, 8) { + // init output cache for (hi.c.init: int32, 0, 8) { for (wi.c.init: int32, 0, 8) { for (ki.c.init: int32, 0, 32) { - output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + output.cache[((((wo.c*2048) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 } } } @@ -118,173 +128,220 @@ primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () for (wi.c: int32, 0, 8) { for (ki.c: int32, 0, 32) { for (rc.inner: int32, 0, 32) { - output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = + output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + ( (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] ) ) } } } } - } // end rc.outer - } // end ko.c - } // end wo.c + } + } // end wo.c - // cache write - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { + // cache write + for (wo: int32, 0, 8) { for (hi: int32, 0, 8) { for (wi: int32, 0, 8) { for (ki: int32, 0, 32) { - output_pointer[((((((ho.outer*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[(((((wo*4096) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)] } } } } - } - } + } // end ho.outer + } // end ko.outer } ``` -# Split on Height - "Full Output Slice" +# Split on Channel Out and Height - "Full Output Slice" -Adds a new parameter `h_split` which creates a loop split on the height `h` dimension. The cache reads and writes are moved to the outer of the two loops created by that split - the loop over `ho.outer`. This increases cache usage by a factor equivalent to `h_split`. The compute is still "full width" and "full depth" in the channel-out dimension and now over multiple slices in the height `h` dimension. +Adds new parameters `k_split` and `h_split` which creates a loop split on the outer channel out `ko` and height `ho` loops creating `outer` and `inner` loops for each split. The cache reads and writes are computed at `ho.outer` which means that cache allocation grow in proportion to `k_split` and `h_split` factors. -The key changes in TIR versus the baseline are ... +The key changes in TIR versus the above are... 1) Increased cache allocations: ``` + // input cache grows by factor of h_split = 2 allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; ``` -2) The loop split on the `h` dimension: +2) Outer loop splits using k_split and h_split factors ``` - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 2) { + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { +``` + +3) Inner loop splits in both cache read / write and compute schedules. This is taken from the compute schedule e.g. +``` + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { ``` ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 1x1 | +| Filter | 1x1 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | +| k_split | 2 | | h_split | 2 | ## Assumptions -Same as baseline +* n/a - With the loop splits on `ko` and `ho` the compute schedule is now over `ko.inner` `ho.inner` `wo` etc. This should fit the pattern matching for microkernels. ## To Do -Same as baseline +* n/a ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []), - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - - // increased cache usage due to h_split parameter + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + + // input cache grows by factor of h_split = 2 allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // filter cache grows by factor of k_split = 2 + allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global; + + // output cache grows by factor of h_split * k_split = 4 allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; + + // ko.outer = outer loop split on ko using k_split factor + for (ko.outer: int32, 0, 2) { + // ho.outer = outer loop split on ho using h_split factor + for (ho.outer: int32, 0, 4) { + + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } + } + } + } + } + } // end ho.inner - // loop split ho.outer vs. ho.inner based on h_split parameter - for (ho.outer: int32, 0, 4) { - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { + // filter cache read + for (ko.inner: int32, 0, 2) { for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] } } } } - } - } - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } // end ko.inner + + // compute + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + + // init output cache + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } } } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + + // convolution + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] * + (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) ) - ) + } } } } } - } - } - } - } - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + + // cache write + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } } } } - } - } - } - } + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer } ``` # 3x3 conv2d (no padding) -Change from a 1x1 kernel to a 3x3 kernel. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 kernel will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. +Change from a 1x1 filter to a 3x3 filter. The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output. This is due to the fact that the 3x3 filter will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache. The key changes in TIR versus the above are... 1) Increased input cache size to hold the vertically adjacent slice ``` + // input cache grows to hold vertically adjacent slice allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; ``` @@ -298,19 +355,33 @@ The key changes in TIR versus the above are... The `if` statement above indicates NOT to prefetch the vertically adjacent slice at the "bottom" of the input since it does not exist. + +3) Increased filter cache size to hold 3x3 filter + +``` + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; +``` + +4) Loops over `rh` and `rw` the kernel spatial dimensions: +``` + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { +``` + ## Command -pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-1-64-64-64-llvm]" +pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-3-1-0-float32-2-2-1-64-64-128-llvm]" ## Parameters | Parameter | Value | | --------- | ----------- | | Batch | 1 | -| Kernel | 3x3 | +| Filter | 3x3 | | Spatial | 64x64 | | Input Ch | 64 | -| Output Ch | 64 | +| Output Ch | 128 | | Stride | 1 | | Padding | 0 | | Layout | NHWC8h8w32c | @@ -318,12 +389,10 @@ pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2d ## Assumptions -Same as above +* n/a ## To Do -Same as above, and ... - There may be some opportunity to optimize cache reuse in this case. Consider the loops over `ho.outer` and `ho.inner` and the index calculation `ho.outer * 64k + ho.inner * 32k` into the input pointer: | ho.outer | ho.inner | ho.outer * 64k + ho.inner * 32k | @@ -346,86 +415,103 @@ Noe that the vertically adjacent slice in loop N (i.e. the loop where `ho.outer` ## Annotated TIR ``` -primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> () +primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]} - buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []), - kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 3, 3, 8, 32, 4], []), - input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} - buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} { - - // increased input cache size to hold vertically adjacent slice + buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c + filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 3, 3, 8, 32, 4], []), // OIHW8i32o4i + input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC) + buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} { + // input cache grows to hold vertically adjacent slice allocate(input.cache: Pointer(global float32), float32, [98304]), storage_scope = global; + // filter cache grows to hold larger 3x3 filter + allocate(filter.cache: Pointer(global float32), float32, [36864]), storage_scope = global; allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global; - for (ho.outer: int32, 0, 4) { - - // iterate over h_split + 1 = 3 input slices - for (ho.inner: int32, 0, 3) { - - // don't prefetch the vertically adjacent slice at the "bottom" of the input - if (((ho.outer*2) + ho.inner) < 8) { - for (wo: int32, 0, 8) { - for (co: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ci: int32, 0, 32) { - input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = - (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + for (ko.outer: int32, 0, 2) { + for (ho.outer: int32, 0, 4) { + // input cache read + // NHWC -> NHWC8h8w32c (pending RFC) + for (ho.inner: int32, 0, 3) { + if (((ho.outer*2) + ho.inner) < 8) { + for (wo: int32, 0, 8) { + for (co: int32, 0, 2) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ci: int32, 0, 32) { + input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = + (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)] + } } } } } } } - } - for (ho.c.inner: int32, 0, 2) { - for (wo.c: int32, 0, 8) { - for (ko.c: int32, 0, 2) { - for (hi.c.init: int32, 0, 8) { - for (wi.c.init: int32, 0, 8) { - for (ki.c.init: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + // filter cache read + for (ko.inner: int32, 0, 2) { + for (co: int32, 0, 2) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ci8: int32, 0, 8) { + for (ki: int32, 0, 32) { + for (ci4: int32, 0, 4) { + filter.cache[(((((((ko.inner*18432) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] = + (float32*)filter_pointer[((((((((ko.outer*36864) + (ko.inner*18432)) + (co*9216)) + (rh*3072)) + (rw*1024)) + (ci8*128)) + (ki*4)) + ci4)] + } + } + } + } // end rw + } // end rh + } + } + for (ko.c.inner: int32, 0, 2) { + for (ho.c.inner: int32, 0, 2) { + for (wo.c: int32, 0, 8) { + for (hi.c.init: int32, 0, 8) { + for (wi.c.init: int32, 0, 8) { + for (ki.c.init: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32 + } } } - } - for (rc.outer: int32, 0, 2) { - for (hi.c: int32, 0, 8) { - for (wi.c: int32, 0, 8) { - for (rh: int32, 0, 3) { - for (rw: int32, 0, 3) { - for (ki.c: int32, 0, 32) { - for (rc.inner: int32, 0, 32) { - output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = - ( - (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + for (rc.outer: int32, 0, 2) { + for (hi.c: int32, 0, 8) { + for (wi.c: int32, 0, 8) { + for (rh: int32, 0, 3) { + for (rw: int32, 0, 3) { + for (ki.c: int32, 0, 32) { + for (rc.inner: int32, 0, 32) { + output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = ( - (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * - (float32*)kernel_pointer[(((((((ko.c*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + + ( + (float32*)input.cache[((((((((floordiv((hi.c + rh), 8)*32768) + (ho.c.inner*32768)) + (floordiv((wi.c + rw), 8)*4096)) + (wo.c*4096)) + (rc.outer*2048)) + (floormod((hi.c + rh), 8)*256)) + (floormod((wi.c + rw), 8)*32)) + rc.inner)] * + (float32*)filter.cache[(((((((ko.c.inner*18432) + (rc.outer*9216)) + (rh*3072)) + (rw*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))] + ) ) - ) + } } - } - } + } // end rw + } // end rh } } } - } - } - } - } - for (ho.inner: int32, 0, 2) { - for (wo: int32, 0, 8) { - for (ko: int32, 0, 2) { - for (hi: int32, 0, 8) { - for (wi: int32, 0, 8) { - for (ki: int32, 0, 32) { - output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = - (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] + } // end wo.c + } // end ho.c.inner + } // end ko.c.inner + for (ko.inner: int32, 0, 2) { + for (ho.inner: int32, 0, 2) { + for (wo: int32, 0, 8) { + for (hi: int32, 0, 8) { + for (wi: int32, 0, 8) { + for (ki: int32, 0, 32) { + output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = + (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] + } } } } - } - } - } - } -} -``` \ No newline at end of file + } // end ho.inner + } // end ko.inner + } // end ho.outer + } // end ko.outer +}``` \ No newline at end of file diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py index 37a623b613f8..1304d341eda2 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_blocked.py @@ -162,6 +162,7 @@ def conv2d_packed_filter( stride, padding, dtype, + k_split_factor, h_split_factor, storage_scope="global", ): @@ -263,6 +264,7 @@ def compute(n, ho, wo, ko, hi, wi, ki): # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) @@ -277,7 +279,9 @@ def compute(n, ho, wo, ko, hi, wi, ki): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) hoo, hoi = s[Y].split(ho, factor=h_split_factor) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) s[Yl].compute_at(s[Y], hoo) #################### @@ -297,9 +301,11 @@ def compute(n, ho, wo, ko, hi, wi, ki): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) hoo, hoi = s[Yl].split(ho, factor=h_split_factor) - s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) binds = {} if storage_scope and storage_scope != "global": @@ -318,6 +324,7 @@ def conv2d_packed_filter_nhwhwc( stride, padding, dtype, + k_split_factor, h_split_factor, storage_scope="global", ): @@ -406,6 +413,7 @@ def compute(n, ho, wo, hi, wi, k): # cache read for the input / activation (X) Xl = s.cache_read(X_packed, storage_scope, [Y]) + Fl = s.cache_read(filt_packed, storage_scope, [Y]) # cache write for the output (Y) Yl = s.cache_write(Y, storage_scope) @@ -423,8 +431,9 @@ def compute(n, ho, wo, hi, wi, k): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Y].split(ko, factor=k_split_factor) hoo, hoi = s[Y].split(ho, factor=h_split_factor) - s[Y].reorder(n, hoo, hoi, wo, ko, hi, wi, ki) + s[Y].reorder(n, koo, hoo, koi, hoi, wo, hi, wi, ki) s[Yl].compute_at(s[Y], hoo) #################### @@ -445,9 +454,11 @@ def compute(n, ho, wo, hi, wi, k): # loop split h and compute cache write at outer loop split # to increase cache usage by factor of h_split_factor + koo, koi = s[Yl].split(ko, factor=k_split_factor) hoo, hoi = s[Yl].split(ho, factor=h_split_factor) - s[Yl].reorder(n, hoo, hoi, wo, ko, rco, hi, wi, ki, rci) + s[Yl].reorder(n, koo, hoo, koi, hoi, wo, rco, hi, wi, ki, rci) s[Xl].compute_at(s[Yl], hoo) + s[Fl].compute_at(s[Yl], hoo) ####################### # cache read schedule # @@ -474,12 +485,13 @@ def compute(n, ho, wo, hi, wi, k): class BaseConv2d: batch = tvm.testing.parameter(1) in_size = tvm.testing.parameter(8, 56, 64) - in_channel = tvm.testing.parameter(64) - out_channel = tvm.testing.parameter(64) + in_channel = tvm.testing.parameter(64, 128) + out_channel = tvm.testing.parameter(64, 128) kernel = tvm.testing.parameter(1, 3) stride = tvm.testing.parameter(1) pad = tvm.testing.parameter(0, 1) dtype = tvm.testing.parameter("float32") + k_split_factor = tvm.testing.parameter(1, 2) h_split_factor = tvm.testing.parameter(1, 2) @@ -504,7 +516,30 @@ def test_conv2d(self, shape_nhwc, shape_oihw, kernel, stride, pad, dtype, target padding=(pad, pad, pad, pad), dtype=dtype, ) - return output, ref_output + + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) class TestConv2dPackedFilter(BaseConv2d): @@ -522,6 +557,7 @@ def test_conv2d( pad, dtype, target, + k_split_factor, h_split_factor, ): inputs = [ @@ -543,9 +579,45 @@ def test_conv2d( stride=(stride, stride), padding=(pad, pad, pad, pad), dtype=dtype, + k_split_factor=k_split_factor, h_split_factor=h_split_factor, ) - return output, ref_output + + # nhwc8h8w32c + if len(output.shape) == 7: + # nhwc8h8w32c -> nhwc + output = output.transpose(0, 1, 4, 2, 5, 3, 6).reshape( + output.shape[0], + output.shape[1] * output.shape[4], + output.shape[2] * output.shape[5], + output.shape[3] * output.shape[6], + ) + + # nhwhwc + else: + # nhwhwc -> nhwc + output = output.transpose(0, 1, 3, 2, 4, 5).reshape( + output.shape[0], + output.shape[1] * output.shape[3], + output.shape[2] * output.shape[4], + output.shape[5], + ) + + # slice output to match ref_output shape + # e.g. 8x8 spatial 3x3 filter = 6x6 ref output + # but still 8x8 output given the blocked layout + output = output[ + 0 : ref_output.shape[0] : 1, + 0 : ref_output.shape[1] : 1, + 0 : ref_output.shape[2] : 1, + 0 : ref_output.shape[3] : 1, + ] + + if "int" in dtype: + tol = {"atol": 0, "rtol": 0} + elif dtype == "float32": + tol = {"rtol": 1e-4, "atol": 2e-4} + tvm.testing.assert_allclose(output, ref_output, **tol) if __name__ == "__main__":