Fix USMP parallel to serial loop transform test (apache#9254)
Caused by apache#8469 being stale on merge when apache#9115 had changed the namespace for `tvm.script`.
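
In practice this means the test was still written against the old `tvm.script` namespace. Schematically (with a hypothetical function name `func`), the change applied throughout the file is:

# before: the pre-apache#9115 TVMScript namespace
from tvm.script import ty

@tvm.script.tir
def func(a: ty.handle) -> None:
    tir.func_attr({"global_symbol": "func"})

# after: the namespace introduced by apache#9115
from tvm.script import tir as T

@T.prim_func
def func(a: T.handle) -> None:
    T.func_attr({"global_symbol": "func"})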
Mousius authored and ylc committed Jan 7, 2022
1 parent ab35722 commit 695de94
Showing 1 changed file with 20 additions and 20 deletions.
@@ -17,31 +17,31 @@
 import pytest
 
 import tvm
-from tvm import tir, script
-from tvm.script import ty
+from tvm.script import tir as T
+from tvm.tir import stmt_functor
 
 # fmt: off
-@tvm.script.tir
-def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: ty.handle, placeholder_31: ty.handle, placeholder_32: ty.handle, T_cast_8: ty.handle) -> None:
+@T.prim_func
+def fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2(placeholder_30: T.handle, placeholder_31: T.handle, placeholder_32: T.handle, T_cast_8: T.handle) -> None:
     # function attr dict
-    tir.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
-    placeholder_33 = tir.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
-    placeholder_34 = tir.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
-    placeholder_35 = tir.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
-    T_cast_9 = tir.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    T.func_attr({"global_symbol": "fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2", "tir.noalias": True})
+    placeholder_33 = T.match_buffer(placeholder_30, [1, 28, 28, 192], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    placeholder_34 = T.match_buffer(placeholder_31, [1, 1, 192, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
+    placeholder_35 = T.match_buffer(placeholder_32, [1, 1, 1, 16], dtype="int32", elem_offset=0, align=128, offset_factor=1)
+    T_cast_9 = T.match_buffer(T_cast_8, [1, 28, 28, 16], dtype="int16", elem_offset=0, align=128, offset_factor=1)
     # body
-    PaddedInput_3 = tir.allocate([1, 28, 28, 192], "int16", "global")
-    for i0_i1_fused_3 in tir.parallel(0, 28):
-        for i2_3, i3_3 in tir.grid(28, 192):
-            tir.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), tir.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True)
-    for ax0_ax1_fused_ax2_fused_3 in tir.parallel(0, 784):
-        for ax3_2 in tir.serial(0, 16):
-            Conv2dOutput_3 = tir.allocate([1, 1, 1, 1], "int32", "global")
-            tir.store(Conv2dOutput_3, 0, 0, True)
-            for rc_3 in tir.serial(0, 192):
-                tir.store(Conv2dOutput_3, 0, (tir.load("int32", Conv2dOutput_3, 0) + (tir.cast(tir.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*tir.cast(tir.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True)
-            tir.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), tir.cast(tir.cast(tir.max(tir.min(tir.q_multiply_shift((tir.load("int32", Conv2dOutput_3, 0) + tir.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
+    PaddedInput_3 = T.allocate([1, 28, 28, 192], "int16", "global")
+    for i0_i1_fused_3 in T.parallel(0, 28):
+        for i2_3, i3_3 in T.grid(28, 192):
+            T.store(PaddedInput_3, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3), T.load("int16", placeholder_33.data, (((i0_i1_fused_3*5376) + (i2_3*192)) + i3_3)), True)
+    for ax0_ax1_fused_ax2_fused_3 in T.parallel(0, 784):
+        for ax3_2 in T.serial(0, 16):
+            Conv2dOutput_3 = T.allocate([1, 1, 1, 1], "int32", "global")
+            T.store(Conv2dOutput_3, 0, 0, True)
+            for rc_3 in T.serial(0, 192):
+                T.store(Conv2dOutput_3, 0, (T.load("int32", Conv2dOutput_3, 0) + (T.cast(T.load("int16", PaddedInput_3, ((ax0_ax1_fused_ax2_fused_3*192) + rc_3)), "int32")*T.cast(T.load("int16", placeholder_34.data, ((rc_3*16) + ax3_2)), "int32"))), True)
+            T.store(T_cast_9.data, ((ax0_ax1_fused_ax2_fused_3*16) + ax3_2), T.cast(T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_3, 0) + T.load("int32", placeholder_35.data, ax3_2)), 1764006585, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True)
 # fmt: on
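
The assertion half of the file is collapsed in this view, so only the TVMScript fixture is shown. For orientation, here is a minimal sketch of how such a test can check that every parallel loop was rewritten to a serial one, assuming the pass under test is tvm.tir.transform.ConvertForLoopsToSerial and using the stmt_functor visitor imported above (both the pass name and the test body are assumptions, not part of this diff):

def test_all_loops_become_serial():
    # Wrap the TVMScript prim_func above into an IRModule; from_expr names it "main".
    mod = tvm.IRModule.from_expr(fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_2)
    # Assumed pass name: the parallel-to-serial loop transform added by apache#8469.
    mod = tvm.tir.transform.ConvertForLoopsToSerial()(mod)

    def check(node):
        # Every remaining For node must be a serial loop.
        if isinstance(node, tvm.tir.For):
            assert node.kind == tvm.tir.ForKind.SERIAL

    stmt_functor.post_order_visit(mod["main"].body, check)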


