Skip to content

Commit

Permalink
[CUDA] Swap block x and z dimension for conv2d NHWC schedule (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi authored and ylc committed Jan 13, 2022
1 parent 9e4d937 commit 0743e21
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/tvm/topi/cuda/conv2d_nhwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,14 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv):

# Schedule for output
ni, hi, wi, fi = s[output].op.axis
bz = s[output].fuse(hi, wi)
bx = s[output].fuse(hi, wi)
tx, fi = s[output].split(fi, factor=tile_c)
txz, tx = s[output].split(tx, factor=num_thread_c)
bx, txz = s[output].split(txz, factor=vthread_c)
bz, txz = s[output].split(txz, factor=vthread_c)
ty, ni = s[output].split(ni, factor=tile_n)
tyz, ty = s[output].split(ty, factor=num_thread_n)
by, tyz = s[output].split(tyz, factor=vthread_n)
s[output].reorder(bz, by, bx, tyz, txz, ty, tx, ni, fi)
s[output].reorder(bx, by, bz, tyz, txz, ty, tx, ni, fi)
s[output].bind(bz, block_z)
s[output].bind(by, block_y)
s[output].bind(bx, block_x)
Expand Down

0 comments on commit 0743e21

Please sign in to comment.