<a href="https://colab.research.google.com/github/uwsampl/tutorial/blob/master/notebook/autotvm_conv2d_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Tuning High Performance Convolution on NVIDIA GPUs
=========================================================================
**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_

Adapted by `Eddie Yan <https://github.com/eqy>`_

This is an advanced tutorial for writing high performance tunable template for
NVIDIA GPU. By running auto-tuner on this template, we can outperform the
vendor provided library CuDNN in many cases.



Please run the following block to ensure TVM is setup for *this notebook*, each notebook may have its own runtime.



In [2]:
! gsutil cp "gs://tvm-fcrc-binariesd5fce43e-8373-11e9-bfb6-0242ac1c0002/tvm.tar.gz" /tmp/tvm.tar.gz
! mkdir -p /tvm
! tar -xf /tmp/tvm.tar.gz --strip-components=4 --directory /tvm
! ls -la /tvm
# Move this block after we are done with pkg step
! bash /tvm/package.sh
import sys
sys.path.append('/tvm/python')
sys.path.append('/tvm/topi/python')

Copying gs://tvm-fcrc-binariesd5fce43e-8373-11e9-bfb6-0242ac1c0002/tvm.tar.gz...
- [1 files][112.9 MiB/112.9 MiB]                                                
Operation completed over 1 objects/112.9 MiB.                                    
total 164
drwxr-xr-x 21 root root  4096 Jun  5 00:07 .
drwxr-xr-x  1 root root  4096 Jun  5 00:07 ..
drwx------  8 root root  4096 May 31 08:14 3rdparty
drwx------ 12 root root  4096 May 31 08:14 apps
drwx------  3 root root  4096 Jun  4 09:46 build
drwx------  4 root root  4096 May 31 08:14 cmake
-rw-------  1 root root 10406 May 31 08:14 CMakeLists.txt
drwx------  6 root root  4096 May 31 08:14 conda
-rw-------  1 root root  5673 May 31 08:14 CONTRIBUTORS.md
drwx------  3 root root  4096 May 31 08:14 docker
drwx------ 11 root root  4096 May 31 08:14 docs
drwx------  4 root root  4096 May 31 08:14 golang
drwx------  3 root root  4096 May 31 08:14 include
-rw-------  1 root root 10027 May 31 08:14 Jenkinsfile
drwx------  6 root root  4096 May 31 

Import packages:

In [0]:
import logging
import sys
import numpy as np

import tvm
import topi
from topi.testing import conv2d_nchw_python

from tvm import autotvm

Step 0: Vanilla direct 2D convolution implementation without a tunable template
---------------------------------------------------------------------------------------------

Default Schedule:

In [5]:
# the last layer in resnet
N, H, W, CO, CI, KH, KW, stride, padding = 1, 7, 7, 512, 512, 3, 3, (1, 1), (1, 1)
assert N == 1, "Only consider batch_size = 1 in this template"

data = tvm.placeholder((N, CI, H, W), name='data')
kernel = tvm.placeholder((CO, CI, KH, KW), name='kernel')
conv = topi.nn.conv2d_nchw(data, kernel, stride, padding, dilation=1, out_dtype='float32')
s = tvm.create_schedule([conv.op])
print("Default Schedule:")
print(tvm.lower(s, [data, kernel, conv], simple_mode=True))

n, f, y, x = s[conv].op.axis
rc, ry, rx = s[conv].op.reduce_axis

Default Schedule:
// attr [pad_temp] storage_scope = "global"
allocate pad_temp[float32 * 41472]
produce pad_temp {
  for (i1, 0, 512) {
    for (i2, 0, 9) {
      for (i3, 0, 9) {
        pad_temp[((((i1*9) + i2)*9) + i3)] = tvm_if_then_else(((((1 <= i2) && (i2 < 8)) && (1 <= i3)) && (i3 < 8)), data[(((((i1*7) + i2)*7) + i3) + -8)], 0.000000f)
      }
    }
  }
}
produce compute {
  for (ff, 0, 512) {
    for (yy, 0, 7) {
      for (xx, 0, 7) {
        compute[((((ff*7) + yy)*7) + xx)] = 0.000000f
        for (rc, 0, 512) {
          for (ry, 0, 3) {
            for (rx, 0, 3) {
              compute[((((ff*7) + yy)*7) + xx)] = (compute[((((ff*7) + yy)*7) + xx)] + (pad_temp[((((((rc*9) + yy) + ry)*9) + xx) + rx)]*kernel[((((((ff*512) + rc)*3) + ry)*3) + rx)]))
            }
          }
        }
      }
    }
  }
}



Inline padding and create cache stages:

In [6]:
# inline padding
pad_data = s[conv].op.input_tensors[0]
s[pad_data].compute_inline()
input = data
data, raw_data = pad_data, data

output = conv
OL = s.cache_write(conv, 'local')

# create cache stage
AA = s.cache_read(data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
AL = s.cache_read(AA, 'local', [OL])
WL = s.cache_read(WW, 'local', [OL])

print(tvm.lower(s, [input, kernel, conv], simple_mode=True))

// attr [pad_temp.shared] storage_scope = "shared"
allocate pad_temp.shared[float32 * 41472]
// attr [pad_temp.shared.local] storage_scope = "local"
allocate pad_temp.shared.local[float32 * 41472]
// attr [kernel.shared] storage_scope = "shared"
allocate kernel.shared[float32 * 2359296]
// attr [kernel.shared.local] storage_scope = "local"
allocate kernel.shared.local[float32 * 2359296]
// attr [compute.local] storage_scope = "local"
allocate compute.local[float32 * 25088]
produce pad_temp.shared {
  for (ax1, 0, 512) {
    for (ax2, 0, 9) {
      for (ax3, 0, 9) {
        pad_temp.shared[((((ax1*9) + ax2)*9) + ax3)] = tvm_if_then_else(((((1 <= ax2) && (ax2 < 8)) && (1 <= ax3)) && (ax3 < 8)), data[(((((ax1*7) + ax2)*7) + ax3) + -8)], 0.000000f)
      }
    }
  }
}
produce pad_temp.shared.local {
  for (ax1, 0, 512) {
    for (ax2, 0, 9) {
      for (ax3, 0, 9) {
        pad_temp.shared.local[((((ax1*9) + ax2)*9) + ax3)] = pad_temp.shared[((((ax1*9) + ax2)*9) + ax3)]
      }
    }
  }
}

Define and split output spatial axes:

In [7]:
# tile spatial axes
n, f, y, x = s[output].op.axis
tile_f_factors = [32, 32, 32, 1]
tile_x_factors = [1, 1, 1, 1]
tile_y_factors = [7, 7, 7, 1]

bf, vf = s[output].split(f, factor=tile_f_factors[1])
vf, tf = s[output].split(vf, factor=tile_f_factors[2])
tf, fi = s[output].split(tf, factor=tile_f_factors[3])

by, vy = s[output].split(y, factor=tile_y_factors[1])
vy, ty = s[output].split(vy, factor=tile_y_factors[2])
ty, yi = s[output].split(ty, factor=tile_y_factors[3])

bx, vx = s[output].split(x, factor=tile_x_factors[1])
vx, tx = s[output].split(vx, factor=tile_x_factors[2])
tx, xi, = s[output].split(tx, factor=tile_x_factors[3])

kernel_scope = n  # this is the scope to attach global config inside this kernel

s[output].reorder(n, bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
print(tvm.lower(s, [input, kernel, conv], simple_mode=True))

// attr [pad_temp.shared] storage_scope = "shared"
allocate pad_temp.shared[float32 * 41472]
// attr [pad_temp.shared.local] storage_scope = "local"
allocate pad_temp.shared.local[float32 * 41472]
// attr [kernel.shared] storage_scope = "shared"
allocate kernel.shared[float32 * 2359296]
// attr [kernel.shared.local] storage_scope = "local"
allocate kernel.shared.local[float32 * 2359296]
// attr [compute.local] storage_scope = "local"
allocate compute.local[float32 * 25088]
produce pad_temp.shared {
  for (ax1, 0, 512) {
    for (ax2, 0, 9) {
      for (ax3, 0, 9) {
        pad_temp.shared[((((ax1*9) + ax2)*9) + ax3)] = tvm_if_then_else(((((1 <= ax2) && (ax2 < 8)) && (1 <= ax3)) && (ax3 < 8)), data[(((((ax1*7) + ax2)*7) + ax3) + -8)], 0.000000f)
      }
    }
  }
}
produce pad_temp.shared.local {
  for (ax1, 0, 512) {
    for (ax2, 0, 9) {
      for (ax3, 0, 9) {
        pad_temp.shared.local[((((ax1*9) + ax2)*9) + ax3)] = pad_temp.shared[((((ax1*9) + ax2)*9) + ax3)]
      }
    }
  }
}

Bind Axes:

In [8]:
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(by, tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(tf, tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tx)
print(tvm.lower(s, [input, kernel, output], simple_mode=True))

// attr [pad_temp.shared] storage_scope = "shared"
allocate pad_temp.shared[float32 * 4608]
// attr [pad_temp.shared.local] storage_scope = "local"
allocate pad_temp.shared.local[float32 * 4608]
// attr [kernel.shared.local] storage_scope = "local"
allocate kernel.shared.local[float32 * 4608]
produce pad_temp.shared {
  for (ax1, 0, 512) {
    for (ax2, 0, 3) {
      for (ax3, 0, 3) {
        pad_temp.shared[((((ax1*3) + ax2)*3) + ax3)] = tvm_if_then_else((((((1 - threadIdx.y) <= ax2) && (ax2 < (8 - threadIdx.y))) && ((1 - blockIdx.x) <= ax3)) && (ax3 < (8 - blockIdx.x))), data[(((((((ax1*7) + ax2) + threadIdx.y)*7) + ax3) + blockIdx.x) + -8)], 0.000000f)
      }
    }
  }
}
produce pad_temp.shared.local {
  for (ax1, 0, 512) {
    for (ax2, 0, 3) {
      for (ax3, 0, 3) {
        pad_temp.shared.local[((((ax1*3) + ax2)*3) + ax3)] = pad_temp.shared[((((ax1*3) + ax2)*3) + ax3)]
      }
    }
  }
}
produce kernel.shared {
  for (ax1, 0, 512) {
    for (ax2, 0, 3) {
      for (ax3, 0, 3) 

Tile reduction and define shared memory load location:

In [9]:
# tile reduction axes
n, f, y, x = s[OL].op.axis
rc, ry, rx = s[OL].op.reduce_axis
rc_factors = [512, 32, 1]
rx_factors = [3, 1, 1]
ry_factors = [3, 1, 1]
rco, rcm = s[OL].split(rc, factor=rc_factors[1])
rcm, rci = s[OL].split(rcm, factor=rc_factors[2])
ryo, rym = s[OL].split(ry, factor=ry_factors[1])
rym, ryi = s[OL].split(rym, factor=ry_factors[2])
rxo, rxm = s[OL].split(rx, factor=rx_factors[1])
rxm, rxi = s[OL].split(rxm, factor=rx_factors[2])
#rco, rcm, rci = cfg['tile_rc'].apply(s, OL, rc)
#ryo, rym, ryi = cfg['tile_rx'].apply(s, OL, ry)
#rxo, rxm, rxi = cfg['tile_ry'].apply(s, OL, rx)
s[OL].reorder(rco, ryo, rxo, rcm, rym, rxm, rci, ryi, rxi, n, f, y, x)

s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
s[AL].compute_at(s[OL], rxm)
s[WL].compute_at(s[OL], rxm)

print(tvm.lower(s, [input, kernel, output], simple_mode=True))

produce compute {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 16
  // attr [compute.local] storage_scope = "local"
  allocate compute.local[float32 * 1]
  // attr [pad_temp.shared] storage_scope = "shared"
  allocate pad_temp.shared[float32 * 224]
  // attr [kernel.shared] storage_scope = "shared"
  allocate kernel.shared[float32 * 1024]
  // attr [pad_temp.shared.local] storage_scope = "local"
  allocate pad_temp.shared.local[float32 * 1]
  // attr [kernel.shared.local] storage_scope = "local"
  allocate kernel.shared.local[float32 * 1]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 7
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 32
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 7
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
  produce compute.local {
    compute.local[0] = 0.000000f
    for (rc.outer, 0, 16) {
      for (ry.oute

Cooperative fetching:

In [10]:
# cooperative fetching
for load in [AA, WW]:
    n, f, y, x = s[load].op.axis 
    fused = s[load].fuse(n, f, y, x)
    tz, fused = s[load].split(fused, nparts=tile_f_factors[2])
    ty, fused = s[load].split(fused, nparts=tile_y_factors[2])
    tx, fused = s[load].split(fused, nparts=tile_x_factors[2])
    s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
    s[load].bind(ty, tvm.thread_axis("threadIdx.y"))
    s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
print(tvm.lower(s, [input, kernel, output], simple_mode=True))
# tune unroll
#s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
#s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)

produce compute {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 16
  // attr [compute.local] storage_scope = "local"
  allocate compute.local[float32 * 1]
  // attr [pad_temp.shared] storage_scope = "shared"
  allocate pad_temp.shared[float32 * 224]
  // attr [kernel.shared] storage_scope = "shared"
  allocate kernel.shared[float32 * 1024]
  // attr [pad_temp.shared.local] storage_scope = "local"
  allocate pad_temp.shared.local[float32 * 1]
  // attr [kernel.shared.local] storage_scope = "local"
  allocate kernel.shared.local[float32 * 1]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 7
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 32
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 7
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1
  produce compute.local {
    compute.local[0] = 0.000000f
    for (rc.outer, 0, 16) {
      for (ry.oute

Compile and run:

In [14]:
# check correctness
a_np = np.random.uniform(size=(N, CI, H, W)).astype(np.float32)
w_np = np.random.uniform(size=(CO, CI, KH, KW)).astype(np.float32)
c_np = conv2d_nchw_python(a_np, w_np, stride, padding)

with tvm.target.create('cuda'):
    manual_conv2d = tvm.build(s, [input, kernel, output])
    
ctx = tvm.gpu()
a_tvm = tvm.nd.array(a_np, ctx=ctx)
w_tvm = tvm.nd.array(w_np, ctx=ctx)
c_tvm = tvm.nd.empty(c_np.shape, ctx=ctx)
manual_conv2d(a_tvm, w_tvm, c_tvm)

tvm.testing.assert_allclose(c_np, c_tvm.asnumpy(), rtol=1e-2)

evaluator = manual_conv2d.time_evaluator(manual_conv2d.entry_name, ctx, number=400)
mean = evaluator(a_tvm, w_tvm, c_tvm).mean
print("Time cost of this operator: %f" % mean)
print("GFLOPS:", (autotvm.task.task.compute_flop(s)/mean)/10e9)

Time cost of this operator: 0.001189
GFLOPS: 19.475072378822485
