In [1]:
import nnvm.compiler
import nnvm.symbol as sym
import tvm

In [2]:
# Keep a handle on the *original* tvm.lower so my_modified_lower can delegate to it.
# Guard against re-running this cell after tvm.lower has already been replaced by
# my_modified_lower: capturing the wrapper here would make it call itself forever
# (RecursionError) on the next lowering.
if getattr(tvm.lower, "__name__", None) != "my_modified_lower":
    tvm_lower_old = tvm.lower

In [3]:
# Lowering inputs recorded by my_modified_lower, keyed by the lowered function's name:
#   {name: {'sch': schedule, 'args': args, 'binds': binds}}
# A later lowering with the same name overwrites the earlier entry; re-running this
# cell clears the record.
calls = {}

def my_modified_lower(sch,
                      args,
                      name="default_function",
                      binds=None,
                      simple_mode=False):
    """Drop-in replacement for ``tvm.lower`` that records and prints each lowering.

    Calls with ``simple_mode=True`` are forwarded to ``tvm_lower_old`` untouched
    (simple mode is usually used for debugging; don't debug the debugging).

    For every other call:
      * the inputs are stored in the global ``calls`` dict under ``name``,
      * the real lowering is performed with ``tvm_lower_old``,
      * the same inputs are lowered a second time in simple mode and the resulting
        IR is printed, because the full result is not easily printable.

    Returns whatever ``tvm_lower_old`` returns for the original arguments.
    """
    if simple_mode:
        return tvm_lower_old(sch, args, name, binds, simple_mode)

    calls[name] = dict(sch=sch, args=args, binds=binds)
    print("Lowering " + name)
    lowered = tvm_lower_old(sch, args, name, binds, simple_mode)
    # Second lowering purely for display: simple_mode=True skips wrapping the result
    # in an API-ready function and skips loop partitioning.
    readable_ir = tvm_lower_old(sch, args, name, binds, simple_mode=True)
    print("result:\n" + str(readable_ir))
    print("")
    return lowered

# Monkey-patch tvm.lower: every lowering performed from now on in this kernel (including
# the ones triggered by the NNVM compilation below) goes through the logging wrapper.
# To undo: tvm.lower = tvm_lower_old
tvm.lower = my_modified_lower

In [4]:
# Toy network: 1x1 conv + ReLU with a residual add, then a sum reduction.
# NOTE(review): dtype=1 appears to select float64 (the lowered IR allocates float64
# buffers). At 1x13x10000x10000 each such buffer is ~10 GB; later outputs in this
# notebook show 100x100 loops, so the shape was presumably changed between runs.
# Confirm which size is intended.
x = sym.Variable("x", shape=(1, 13, 10000, 10000), dtype=1)
conv = sym.conv2d(data=x, channels=13, kernel_size=(1, 1), padding=[0, 0])
activated = sym.relu(data=conv)
residual = x + activated
# TODO: Doesn't work without keepdims which might be a bug in nnvm, investigate
z = sym.sum(residual, keepdims=True)

graph = nnvm.graph.create(z)
print(graph.ir())

Graph(%x, %conv2d0_weight, %conv2d0_bias) {
  %3 = conv2d(%x, %conv2d0_weight, %conv2d0_bias, channels='13', padding='[0, 0]', kernel_size='(1, 1)')
  %4 = relu(%3)
  %5 = broadcast_add(%x, %4)
  %6 = sum(%5, keepdims='True')
  ret %6
}


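To dump the IR after each TVM lowering pass, wrap the compilation in the block below. This writes many numbered `*_ir.cc` files into the current working directory (see the `%ls` listing further down).

    with tvm.build_config(dump_pass_ir=True):
        ...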

In [5]:
# Compile the graph for an LLVM (CPU) target. Fused kernels get lowered during
# GraphFuseCompile, so the patched tvm.lower prints their IR here.
TARGET = "llvm"
graph._set_json_attr("target", TARGET, "str")
with tvm.target.create(TARGET):
    compiled = graph
    for pass_name in ("InferType", "InferShape", "GraphFusePartition", "GraphFuseCompile"):
        compiled = compiled.apply(pass_name)

Lowering fuse_conv2d_relu_broadcast_add
result:
// attr [pad_temp] storage_scope = "global"
allocate pad_temp[float64 * 1 * 13 * 10000 * 10000]
// attr [compute] storage_scope = "global"
allocate compute[float64 * 1 * 13 * 10000 * 10000]
produce pad_temp {
  parallel (i0.i1.fused, 0, 13) {
    for (i2, 0, 10000) {
      for (i3, 0, 10000) {
        pad_temp[((((i0.i1.fused*10000) + i2)*10000) + i3)] = input0[((((i0.i1.fused*10000) + i2)*10000) + i3)]
      }
    }
  }
}
produce compute {
  parallel (nn.ff.fused, 0, 13) {
    for (yy.init, 0, 10000) {
      for (xx.outer.init, 0, 625) {
        compute[ramp((((((nn.ff.fused*10000) + yy.init)*625) + xx.outer.init)*16), 1, 16)] = x16(0.000000)
      }
    }
    for (rc, 0, 13) {
      for (yy, 0, 10000) {
        for (xx.outer, 0, 625) {
          compute[ramp((((((nn.ff.fused*10000) + yy)*625) + xx.outer)*16), 1, 16)] = (compute[ramp((((((nn.ff.fused*10000) + yy)*625) + xx.outer)*16), 1, 16)] + (pad_temp[ramp((((((rc*10000) + yy)*625) + 

In [6]:
# List the working directory.
# NOTE(review): the numbered *_ir.cc files are per-pass IR dumps, as produced by
# tvm.build_config(dump_pass_ir=True). No cell above enables that option, so they are
# presumably left over from an earlier run and will not reappear on Restart & Run All.
%ls

0_ScheduleOps_ir.cc              51_UnrollLoop_ir.cc
10_Simplify_ir.cc                52_Simplify_ir.cc
11_LowerStorageAccessInfo_ir.cc  53_LowerStorageAccessInfo_ir.cc
12_RemoveNoOp_ir.cc              54_RemoveNoOp_ir.cc
13_RewriteUnsafeSelect_ir.cc     55_RewriteUnsafeSelect_ir.cc
14_MakeAPI_ir.cc                 56_ThreadSync_ir.cc
15_ScheduleOps_ir.cc             57_ThreadSync_ir.cc
16_InjectPrefetch_ir.cc          58_LowerThreadAllreduce_ir.cc
17_StorageFlatten_ir.cc          59_SplitHostDevice_ir.cc
18_CanonicalSimplify_ir.cc       5_VectorizeLoop_ir.cc
19_VectorizeLoop_ir.cc           60_ThreadSync_ir.cc
1_InjectPrefetch_ir.cc           61_ThreadSync_ir.cc
20_InjectVirtualThread_ir.cc     62_LowerThreadAllreduce_ir.cc
21_InjectDoubleBuffer_ir.cc      63_SplitHostDevice_ir.cc
22_StorageRewrite_ir.cc          64_BindDeviceType_ir.cc
23_UnrollLoop_ir.cc              65_BindDeviceType_ir.cc
24_Simplify_ir.cc                66_LowerTVMBuiltin_ir.cc
25_LowerStorageAcc

In [13]:
# Pick the schedule captured for the fused reduction kernel. Fused-kernel names depend on
# how NNVM partitioned the graph (e.g. "fuse_conv2d_relu_broadcast_add" above), and
# `calls` is only filled once the compile cell has run in this kernel. So fail with the
# list of names that were actually lowered instead of a bare KeyError.
kernel_name = 'fuse_sum'
if kernel_name not in calls:
    raise KeyError("{!r} was not lowered in this session; lowered kernels: {}".format(
        kernel_name, sorted(calls)))
sch = calls[kernel_name]['sch']

KeyError: 'fuse_sum'

In [11]:
# Rebuild, by hand, the initial statement tree for the captured schedule so its raw IR
# can be inspected. Presumably this mirrors the first step of tvm.lower (cf.
# 0_ScheduleOps_ir.cc as the first file of the pass dump); confirm.
sch = sch.normalize()
inferred_bounds = tvm.schedule.InferBound(sch)
stmt = tvm.schedule.ScheduleOps(sch, inferred_bounds)

In [12]:
# Display the statement built above; its repr renders the IR.
stmt

// attr [compute(input0_red, 0x217d190)] realize_scope = ""
realize input0_red([0, 1], [0, 1], [0, 1], [0, 1]) {
  produce input0_red {
    input0_red(0, 0, 0, 0) =0.000000
    for (k1, 0, 13) {
      for (k2, 0, 100) {
        for (k3, 0, 100) {
          input0_red(0, 0, 0, 0) =(input0_red(0, 0, 0, 0) + input0(0, k1, k2, k3))
        }
      }
    }
  }
}

In [None]:
# Explore which IR passes tvm.ir_pass exposes, e.g. to apply them to `stmt` by hand.
# This replaces an unfinished `tvm.ir_pass.` tab-completion line, which is a SyntaxError
# and would break Run All.
sorted(name for name in dir(tvm.ir_pass) if not name.startswith('_'))