# Compute definition

In [2]:
import torch
import d2ltvm
import tvm
from tvm import te
import numpy as np

In [6]:
def pool(pool_type, c, nh, nw, kh, kw, ph=0, pw=0, sh=1, sw=1):
    """2D pooling over a (c, nh, nw) input.

        pool_type : pooling type, 'max' or 'avg'
        c : channels
        nh, nw : input height and width
        kh, kw : kernel height and width
        ph, pw : height and width padding sizes, default 0
        sh, sw : height and width strides, default 1

    Returns (X, Y, PaddedX): the input placeholder, the pooled output of
    shape (c, oh, ow), and the tensor the reduction reads from (X itself
    when ph and pw are both 0).

    Raises ValueError if pool_type is neither 'max' nor 'avg'.
    """
    # reduction axes spanning the pooling window
    rkh = te.reduce_axis((0, kh), name='rkh')
    rkw = te.reduce_axis((0, kw), name='rkw')
    X = te.placeholder((c, nh, nw), name='X')
    # output height and width
    oh = d2ltvm.conv_out_size(nh, kh, ph, sh)
    ow = d2ltvm.conv_out_size(nw, kw, pw, sw)
    # Only insert a padding stage when some padding is requested
    # (equivalent to the former `ph | pw != 0` for integer paddings).
    needs_padding = bool(ph or pw)

    # The compute lambdas reuse the name `c` for the channel index; TVM takes
    # the lambda argument names as loop-variable names in the lowered IR.
    if pool_type == 'max':
        # Pad with the dtype's lowest value so padded cells never win the max.
        PaddedX = (d2ltvm.padding(X, ph, pw, val=te.min_value(X.dtype))
                   if needs_padding else X)
        Y = te.compute((c, oh, ow),
                       lambda c, i, j: te.max(
                           PaddedX[c, i * sh + rkh, j * sw + rkw],
                           axis=[rkh, rkw]),
                       tag='pool_max', name='PoolMax')
    elif pool_type == 'avg':
        # Zero padding; padded cells still count in the kh * kw divisor
        # (the same convention as torch's avg_pool2d default).
        PaddedX = d2ltvm.padding(X, ph, pw, val=0) if needs_padding else X
        tSum = te.compute((c, oh, ow),
                          lambda c, i, j: te.sum(
                              PaddedX[c, i * sh + rkh, j * sw + rkw],
                              axis=[rkh, rkw]),
                          tag='pool_avg1', name='PoolSum')
        Y = te.compute((c, oh, ow),
                       lambda c, i, j: tSum[c, i, j] / (kh * kw),
                       tag='pool_avg2', name='PoolAvg')
    else:
        raise ValueError("pool type should be 'avg' or 'max'")
    return X, Y, PaddedX

In [8]:
c = 4  # channels
n = 12  # input height (== width)
k = 3  # kernel size
p = 1  # padding
s = 1  # stride
# Define max pooling, apply the default schedule, compile it and show the IR.
X, Y, _ = pool(pool_type='max', c=c, nh=n, nw=n, kh=k, kw=k,
               ph=p, pw=p, sh=s, sw=s)
sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Y])
print(tvm.lower(sch, [X, Y], simple_mode=True))
# Input and output buffers come from the convolution data helper (the
# weights are discarded); the compiled function fills out_max.
data, _, out_max = d2ltvm.get_conv_data(c, c, n, k, p, s, tvm.nd.array)
mod(data, out_max)

@main = primfn(X_1: handle, PoolMax_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {X: Buffer(X_2: Pointer(float32), float32, [576], []),
             PoolMax: Buffer(PoolMax_2: Pointer(float32), float32, [576], [])}
  buffer_map = {X_1: X, PoolMax_1: PoolMax}
  preflattened_buffer_map = {X_1: X_3: Buffer(X_2, float32, [4, 12, 12], []), PoolMax_1: PoolMax_3: Buffer(PoolMax_2, float32, [4, 12, 12], [])} {
  allocate(PaddedX: Pointer(global float32), float32, [784]), storage_scope = global {
    for (i0: int32, 0, 4) {
      for (i1: int32, 0, 14) {
        for (i2: int32, 0, 14) {
          PaddedX_1: Buffer(PaddedX, float32, [784], [])[(((i0*196) + (i1*14)) + i2)] = @tir.if_then_else(((((i1 < 1) || (13 <= i1)) || (i2 < 1)) || (13 <= i2)), -3.40282e+38f32, X[((((i0*144) + (i1*12)) + i2) - 13)], dtype=float32)
        }
      }
    }
    for (c: int32, 0, 4) {
      for (i: int32, 0, 12) {
        for (j: int32, 0, 12

In [9]:
# Same pipeline for average pooling, reusing c, n, k, p, s from above.
X, Y, _ = pool(pool_type='avg', c=c, nh=n, nw=n, kh=k, kw=k,
               ph=p, pw=p, sh=s, sw=s)
sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Y])
print(tvm.lower(sch, [X, Y], simple_mode=True))
# The compiled function writes the pooled result into out_avg.
data, _, out_avg = d2ltvm.get_conv_data(c, c, n, k, p, s, tvm.nd.array)
mod(data, out_avg)

@main = primfn(X_1: handle, PoolAvg_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {X: Buffer(X_2: Pointer(float32), float32, [576], []),
             PoolAvg: Buffer(PoolAvg_2: Pointer(float32), float32, [576], [])}
  buffer_map = {X_1: X, PoolAvg_1: PoolAvg}
  preflattened_buffer_map = {X_1: X_3: Buffer(X_2, float32, [4, 12, 12], []), PoolAvg_1: PoolAvg_3: Buffer(PoolAvg_2, float32, [4, 12, 12], [])} {
  allocate(PaddedX: Pointer(global float32), float32, [784]), storage_scope = global;
  allocate(PoolSum: Pointer(global float32), float32, [576]), storage_scope = global {
    for (i0: int32, 0, 4) {
      for (i1: int32, 0, 14) {
        for (i2: int32, 0, 14) {
          PaddedX_1: Buffer(PaddedX, float32, [784], [])[(((i0*196) + (i1*14)) + i2)] = @tir.if_then_else(((((i1 < 1) || (13 <= i1)) || (i2 < 1)) || (13 <= i2)), 0f32, X[((((i0*144) + (i1*12)) + i2) - 13)], dtype=float32)
        }
      }
    }
    for (c

# Torch Baseline

In [18]:
def get_pool_data_torch(c, n, k, p, s, ctx='cpu'):
    """Return the pooling input as a torch tensor on device `ctx`.

    Reuses d2ltvm.get_conv_data (with c input and c output channels),
    keeps only its data array, and prepends a batch axis of size 1 as
    torch's 2D pooling functions expect. The data layout is presumably
    (c, n, n) before batching — confirm against d2ltvm.
    """
    target = torch.device(ctx)
    data, _weight, _out = d2ltvm.get_conv_data(
        c, c, n, k, p, s, lambda arr: torch.tensor(arr, device=target))
    return data.unsqueeze(0)

In [19]:
def pool_torch(pool_type, data, k, p, s):
    """Run torch 2D pooling on `data` as a reference implementation.

    pool_type : 'avg' or 'max'
    data : batched input tensor
    k, p, s : kernel size, padding and stride

    Raises ValueError for any other pool_type.
    """
    functional = torch.nn.functional
    if pool_type == 'avg':
        pool_fn = functional.avg_pool2d
    elif pool_type == 'max':
        pool_fn = functional.max_pool2d
    else:
        raise ValueError("pool type should be 'avg' or 'max'")
    return pool_fn(data, kernel_size=k, stride=s, padding=p)

In [20]:
# Build the torch input once: both pooling ops only read it, and the
# comparison below already relies on get_conv_data producing the same data
# on every call, so regenerating it for the second baseline is redundant.
data = get_pool_data_torch(c, n, k, p, s)
out_max_torch = pool_torch('max', data, k, p, s)
out_avg_torch = pool_torch('avg', data, k, p, s)

In [22]:
# Check the TVM results against the PyTorch baselines; [0] drops the batch
# axis that get_pool_data_torch added. NDArray.numpy() replaces the
# deprecated asnumpy(), and .cpu() keeps the torch side working for GPU ctx.
np.testing.assert_allclose(out_max.numpy(), out_max_torch[0].cpu().numpy(), atol=1e-5)
np.testing.assert_allclose(out_avg.numpy(), out_avg_torch[0].cpu().numpy(), atol=1e-5)

# Summary
1. 2D pooling accesses the data in a similar way to 2D convolution, but the computation itself is much lighter, so its performance tends to be dominated by IO.\
2. We can easily define max pooling and average pooling using TVM expressions.