In [1]:
import nnvm.compiler
import nnvm.symbol as sym
import tvm

In [2]:
# Keep a handle on the unpatched tvm.lower so the wrapper installed later in this
# notebook can delegate to it.
# The capture is guarded: if this cell is re-run after tvm.lower has already been
# replaced, an unconditional assignment would store the wrapper itself here and the
# wrapper would then call itself recursively.
if "tvm_lower_old" not in globals():
    tvm_lower_old = tvm.lower

In [3]:
# Registry of intercepted lowerings, keyed by function name; filled in by my_modified_lower.
calls = dict()

def my_modified_lower(sch,
                      args,
                      name="default_function",
                      binds=None,
                      simple_mode=False):
    """Stand-in for ``tvm.lower`` that records and prints each regular lowering.

    Parameters are forwarded unchanged to the saved original (``tvm_lower_old``).

    When ``simple_mode`` is falsy the call is remembered in ``calls[name]`` as a
    dict with keys ``'sch'``, ``'args'`` and ``'binds'``, the function name is
    printed, and a second lowering with ``simple_mode=True`` is performed purely
    so that its string form can be printed.

    Returns whatever ``tvm_lower_old`` returns for the requested mode.
    """
    # Simple mode is mostly used for debugging already, so pass it straight
    # through without recording or printing anything.
    if simple_mode:
        return tvm_lower_old(sch, args, name, binds, simple_mode)

    calls[name] = {'sch': sch, 'args': args, 'binds': binds}
    print("Lowering " + name)
    lowered = tvm_lower_old(sch, args, name, binds, simple_mode)

    # No way was found to print the regular result directly, so the same input is
    # lowered a second time in simple mode. Per the original author's note, simple
    # mode does not wrap the result in an API-ready function and does not apply
    # loop partitioning.
    printable = tvm_lower_old(sch, args, name, binds, simple_mode=True)
    print("result:\n" + str(printable))
    print("")
    return lowered

# Monkey-patch: from here on, any code that looks up tvm.lower at call time goes
# through the recording/printing wrapper instead of the original implementation.
tvm.lower = my_modified_lower

In [4]:
# Build a small graph: 1x1 conv -> relu -> residual add with the input -> global sum.
# NOTE(review): dtype=1 is a raw numeric dtype code; the lowered IR printed further
# down allocates float64 buffers, so it presumably selects float64 -- TODO confirm.
x = sym.Variable("x", shape=(1, 13, 100, 100), dtype=1)
conv_out = sym.conv2d(data=x, channels=13, kernel_size=(1, 1), padding=[0, 0])
relu_out = sym.relu(data=conv_out)
residual = x + relu_out
# TODO: Doesn't work without keepdims which might be a bug in nnvm, investigate
z = sym.sum(residual, keepdims=True)

graph = nnvm.graph.create(z)
print(graph.ir())

Graph(%x, %conv2d0_weight, %conv2d0_bias) {
  %3 = conv2d(%x, %conv2d0_weight, %conv2d0_bias, channels='13', padding='[0, 0]', kernel_size='(1, 1)')
  %4 = relu(%3)
  %5 = broadcast_add(%x, %4)
  %6 = sum(%5, keepdims='True')
  ret %6
}


In [5]:
# Tag the graph with its target and run the compilation passes one after another
# inside the matching target context. The passes run in the order listed.
TARGET = "llvm"
graph._set_json_attr("target", TARGET, "str")
with tvm.target.create(TARGET):
    graph_stage = graph
    for pass_name in ('InferType', 'InferShape', 'GraphFusePartition', 'GraphFuseCompile'):
        graph_stage = graph_stage.apply(pass_name)
    compiled = graph_stage

Lowering fuse_conv2d_relu_broadcast_add
result:
// attr [pad_temp] storage_scope = "global"
allocate pad_temp[float64 * 1 * 13 * 100 * 100]
// attr [compute] storage_scope = "global"
allocate compute[float64 * 1 * 13 * 100 * 100]
produce pad_temp {
  parallel (i0.i1.fused, 0, 13) {
    for (i2, 0, 100) {
      for (i3, 0, 100) {
        pad_temp[((((i0.i1.fused*100) + i2)*100) + i3)] = input0[((((i0.i1.fused*100) + i2)*100) + i3)]
      }
    }
  }
}
produce compute {
  parallel (nn.ff.fused, 0, 13) {
    for (yy.init, 0, 100) {
      for (xx.outer.init, 0, 7) {
        for (xx.inner.init.s, 0, 16) {
          if (likely(((xx.outer.init*16) < (100 - xx.inner.init.s)))) {
            compute[(((((nn.ff.fused*100) + yy.init)*100) + (xx.outer.init*16)) + xx.inner.init.s)] = 0.000000
          }
        }
      }
    }
    for (rc, 0, 13) {
      for (yy, 0, 100) {
        for (xx.outer, 0, 7) {
          for (xx.inner.s, 0, 16) {
            if (likely(((xx.outer*16) < (100 - xx.inner.s))

In [9]:
# Pull out the schedule that was recorded when the fused kernel was lowered.
fused_lowering = calls['fuse_conv2d_relu_broadcast_add']
sch = fused_lowering['sch']