In [6]:
import tvm
from tvm import tensorir

def get_phase0(s, args, simple_mode=True):
    """Lower a schedule and return the statement handed to phase-0 passes.

    A pass-through hook is registered at phase 0 of ``tvm.lower`` via
    ``tvm.build_config(add_lower_pass=...)``. The hook records every
    statement it receives and returns it unchanged, so lowering itself is
    not altered.

    Parameters
    ----------
    s : tvm schedule
        Schedule to lower.
    args : list
        Tensors passed to ``tvm.lower`` as the function arguments.
    simple_mode : bool, optional
        Forwarded unchanged to ``tvm.lower``.

    Returns
    -------
    The first statement recorded by the phase-0 hook.

    Raises
    ------
    RuntimeError
        If lowering finished without ever invoking the phase-0 hook.
    """
    captured = []

    def fetch_pass(stmt):
        # Record the statement and hand it back unchanged so lowering proceeds.
        captured.append(stmt)
        return stmt

    with tvm.build_config(add_lower_pass=[(0, fetch_pass)]):
        tvm.lower(s, args, simple_mode=simple_mode)

    # Fail with a clear message instead of a bare IndexError from captured[0].
    if not captured:
        raise RuntimeError(
            "phase-0 lowering hook was never invoked; no statement captured")
    return captured[0]


# Problem size: every tensor below has shape (N, M, K).
N = M = K = 128
SHAPE = (N, M, K)
# Inner extent used when splitting B's innermost axis into outer/inner loops.
SPLIT_FACTOR = 16

# Three chained elementwise stages: B = A + 1, C = B * 2, D = C - 3.
A = tvm.placeholder(SHAPE, name='A')
B = tvm.compute(SHAPE, lambda i, j, k: A[i, j, k] + 1, name='B')
C = tvm.compute(SHAPE, lambda i, j, k: B[i, j, k] * 2, name='C')
D = tvm.compute(SHAPE, lambda i, j, k: C[i, j, k] - 3, name='D')

# Both C and D are schedule outputs; A, C and D are the lowered arguments.
s = tvm.create_schedule([C.op, D.op])
# Split only B's innermost axis (k -> k.outer, k.inner) by SPLIT_FACTOR.
s[B].split(s[B].op.axis[2], factor=SPLIT_FACTOR)
stmt = get_phase0(s, [A, C, D])

print(stmt)

// attr [compute(B, 0x24137b0)] realize_scope = ""
realize B([0, 128], [0, 128], [0, 128]) {
  produce B {
    for (i, 0, 128) {
      for (j, 0, 128) {
        for (k.outer, 0, 8) {
          for (k.inner, 0, 16) {
            B(i, j, (k.inner + (k.outer*16))) =(A(i, j, (k.inner + (k.outer*16))) + 1.000000f)
          }
        }
      }
    }
  }
  // attr [compute(C, 0x2408ae0)] realize_scope = ""
  realize C([0, 128], [0, 128], [0, 128]) {
    produce C {
      for (i, 0, 128) {
        for (j, 0, 128) {
          for (k, 0, 128) {
            C(i, j, k) =(B(i, j, k)*2.000000f)
          }
        }
      }
    }
    // attr [compute(D, 0x2415900)] realize_scope = ""
    realize D([0, 128], [0, 128], [0, 128]) {
      produce D {
        for (i, 0, 128) {
          for (j, 0, 128) {
            for (k, 0, 128) {
              D(i, j, k) =(C(i, j, k) - 3.000000f)
            }
          }
        }
      }
    }
  }
}



In [13]:
# Wrap the phase-0 statement from the first cell in a tensorir schedule.
s = tensorir.create_schedule(stmt)

# Statement handles, presumably in program order (the printed tree below shows
# B, C, D in that order). NOTE: this rebinds B, C, D, which held tvm tensors.
B, C, D = s.statements()
# NOTE(review): i, j, k are not used in this cell and are rebound by
# `s.axis(D)` in the next cell before any use; consider removing this line.
i, j, k = s.axis(C)

print(s.root)

for root = 0 to 1
  for i = 0 to 128
    for j = 0 to 128
      for k.outer = 0 to 8
        for k.inner = 0 to 16
          B(i, j, (k.inner + (k.outer*16))) =(A(i, j, (k.inner + (k.outer*16))) + 1.000000f)
  for i = 0 to 128
    for j = 0 to 128
      for k = 0 to 128
        C(i, j, k) =(B(i, j, k)*2.000000f)
  for i = 0 to 128
    for j = 0 to 128
      for k = 0 to 128
        D(i, j, k) =(C(i, j, k) - 3.000000f)



In [14]:
# Attach C's computation inside D's loop nest at D's second axis `j`. In the
# recorded output, C is then computed per (i, j) just before D's `k` loop.
# NOTE(review): only `j` is used here; `i` and `k` are unused.
i, j, k = s.axis(D)
s.compute_at(C, j)

# NOTE(review): in the recorded output, C's recomputed loop reads
# `axis_2 = 0 to 127` while D's consumer loop is `k = 0 to 128`. Confirm
# whether the printer shows an inclusive bound here, or whether the inferred
# region is one element short (D reads C(i, j, 127)).
print(s.root)

for root = 0 to 1
  for i = 0 to 128
    for j = 0 to 128
      for k.outer = 0 to 8
        for k.inner = 0 to 16
          B(i, j, (k.inner + (k.outer*16))) =(A(i, j, (k.inner + (k.outer*16))) + 1.000000f)
  for i = 0 to 128
    for j = 0 to 128
      for axis_2 = 0 to 127
        C(i, j, axis_2) =(B(i, j, axis_2)*2.000000f)
      for k = 0 to 128
        D(i, j, k) =(C(i, j, k) - 3.000000f)



In [15]:
# Inline B into its consumer. In the recorded output, B's loop nest is gone
# and C computes ((A(...) + 1.0f) * 2.0f) directly.
s.compute_inline(B)

print(s.root)

for root = 0 to 1
  for i = 0 to 128
    for j = 0 to 128
      for axis_2 = 0 to 127
        C(i, j, axis_2) =((A(i, j, axis_2) + 1.000000f)*2.000000f)
      for k = 0 to 128
        D(i, j, k) =(C(i, j, k) - 3.000000f)



In [16]:
# Shrink C's layout. In the recorded output, C is indexed only by the innermost
# index (C(axis_2) when written, C(k) when read by D) inside the (i, j) loops.
# NOTE(review): C is also one of the lowered arguments ([A, C, D] in the first
# cell). With a shrunk layout, it no longer holds the full (N, M, K) result.
# TODO: confirm this is intended for this demo.
s.shrink_layout(C)

print(s.root)

for root = 0 to 1
  for i = 0 to 128
    for j = 0 to 128
      for axis_2 = 0 to 127
        C(axis_2) =((A(i, j, axis_2) + 1.000000f)*2.000000f)
      for k = 0 to 128
        D(i, j, k) =(C(k) - 3.000000f)



In [None]:
# NOTE(review): this cell has not been executed (`In [None]`), so no output is
# recorded. Going by the method names, it presumably restores C's full layout
# (undoing shrink_layout above) and moves C back to the root level (undoing
# compute_at). TODO: run it and record the result.
s.expand_layout(C)
s.compute_root(C)

print(s.root)