<a href="https://colab.research.google.com/github/quxiaojing1985/OpenSearch/blob/main/MyConv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!python3 -m  pip install mlc-ai-nightly -f https://mlc.ai/wheels

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://mlc.ai/wheels
Collecting mlc-ai-nightly
  Downloading https://github.com/mlc-ai/utils/releases/download/v0.9.dev0/mlc_ai_nightly-0.9.dev1956%2Bge3f218d71-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (44.2 MB)
[K     |████████████████████████████████| 44.2 MB 1.8 MB/s 
Collecting synr==0.6.0
  Downloading synr-0.6.0-py3-none-any.whl (18 kB)
Installing collected packages: synr, mlc-ai-nightly
Successfully installed mlc-ai-nightly-0.9.dev1956+ge3f218d71 synr-0.6.0


In [2]:
import tvm
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy as np

In [3]:
# Problem size: batch N, input channels CI, input height H and width W,
# output channels CO, and square kernel size K.
N, CI, H, W, CO, K = 1, 1, 8, 8, 2, 3
# Spatial size of the output when the kernel slides without padding, one step at a time.
OUT_H, OUT_W = H - K + 1, W - K + 1
# The TVM kernel in this notebook declares its buffers as "int64", so pin the
# dtype explicitly instead of relying on NumPy's platform-dependent default
# integer type.
data = np.arange(N * CI * H * W, dtype=np.int64).reshape(N, CI, H, W)
Weight = np.arange(CO * CI * K * K, dtype=np.int64).reshape(CO, CI, K, K)
print(data)

[[[[ 0  1  2  3  4  5  6  7]
   [ 8  9 10 11 12 13 14 15]
   [16 17 18 19 20 21 22 23]
   [24 25 26 27 28 29 30 31]
   [32 33 34 35 36 37 38 39]
   [40 41 42 43 44 45 46 47]
   [48 49 50 51 52 53 54 55]
   [56 57 58 59 60 61 62 63]]]]


In [4]:
import torch

# Reference result computed with PyTorch; the TVM kernel is checked against it.
# torch.Tensor(...) copies the integer NumPy arrays into floating-point tensors
# (presumably float32, the default tensor type — confirm), which is why the
# result is cast back to int64 at the end.
data_torch = torch.Tensor(data)
weight_torch = torch.Tensor(Weight)
# No stride/padding arguments are passed, so conv2d runs with its defaults.
conv_torch = (
    torch.nn.functional.conv2d(data_torch, weight_torch)
    .numpy()
    .astype(np.int64)
)
conv_torch

array([[[[ 474,  510,  546,  582,  618,  654],
         [ 762,  798,  834,  870,  906,  942],
         [1050, 1086, 1122, 1158, 1194, 1230],
         [1338, 1374, 1410, 1446, 1482, 1518],
         [1626, 1662, 1698, 1734, 1770, 1806],
         [1914, 1950, 1986, 2022, 2058, 2094]],

        [[1203, 1320, 1437, 1554, 1671, 1788],
         [2139, 2256, 2373, 2490, 2607, 2724],
         [3075, 3192, 3309, 3426, 3543, 3660],
         [4011, 4128, 4245, 4362, 4479, 4596],
         [4947, 5064, 5181, 5298, 5415, 5532],
         [5883, 6000, 6117, 6234, 6351, 6468]]]])

In [5]:
# TensorIR module holding one primitive function, `conv`: a direct 2-D
# convolution (no padding, unit stride) over int64 tensors.
# Shapes are fixed literals matching the NumPy arrays built earlier:
#   D (input)  : (1, 1, 8, 8)  -> (N, CI, H, W)
#   W (weight) : (2, 1, 3, 3)  -> (CO, CI, K, K)
#   C (output) : (1, 2, 6, 6)  -> (N, CO, OUT_H, OUT_W), written in place
@tvm.script.ir_module
class MyConv:
  @T.prim_func
  def conv(D: T.Buffer[(1, 1, 8, 8), "int64"],
           W: T.Buffer[(2, 1, 3, 3), "int64"],
           C: T.Buffer[(1, 2, 6, 6), "int64"]):
    T.func_attr({"global_symbol": "conv", "tir.noalias": True})
    # Loop nest: batch, output channel, output row, output column,
    # then input channel and the two kernel offsets.
    for b, k, i, j, q, di, dj in T.grid(1, 2, 6, 6, 1, 3, 3):
      with T.block("C"):
        # First four axes are spatial (S); the last three are reduction (R) axes.
        vb, vk, vi, vj, vq, vdi, vdj = T.axis.remap("SSSSRRR", [b, k, i, j, q, di, dj])
        with T.init():
          # Zero the accumulator before the first reduction step.
          C[vb, vk, vi, vj] = T.int64(0)
        # Index D only through block iterators (vj, not the outer loop variable j)
        # so the block body does not depend on the surrounding loop nest.
        C[vb, vk, vi, vj] = C[vb, vk, vi, vj] + D[vb, vq, vi + vdi, vj + vdj] * W[vk, vq, vdi, vdj]

In [6]:
# Compile the TensorIR module with the LLVM target.
rt_lib = tvm.build(MyConv, target="llvm")
# Wrap the NumPy inputs as TVM NDArrays.
data_tvm = tvm.nd.array(data)
weight_tvm = tvm.nd.array(Weight)
# The output buffer may start uninitialized: the kernel's T.init() zeroes each
# element before accumulating into it.
conv_tvm = tvm.nd.array(np.empty((N, CO, OUT_H, OUT_W), dtype=np.int64))
rt_lib["conv"](data_tvm, weight_tvm, conv_tvm)
# Both results are int64, so require exact equality rather than a
# floating-point tolerance.
np.testing.assert_array_equal(conv_tvm.numpy(), conv_torch)
print(conv_tvm)

[[[[ 474  510  546  582  618  654]
   [ 762  798  834  870  906  942]
   [1050 1086 1122 1158 1194 1230]
   [1338 1374 1410 1446 1482 1518]
   [1626 1662 1698 1734 1770 1806]
   [1914 1950 1986 2022 2058 2094]]

  [[1203 1320 1437 1554 1671 1788]
   [2139 2256 2373 2490 2607 2724]
   [3075 3192 3309 3426 3543 3660]
   [4011 4128 4245 4362 4479 4596]
   [4947 5064 5181 5298 5415 5532]
   [5883 6000 6117 6234 6351 6468]]]]
