In [1]:
import tvm
from tvm import te
import numpy as np

In [19]:
import tvm.topi as topi

def batch_norm(c, n, eps=1e-5):
    """batch normalization

    c : channels
    N : input width and height
    eps : small positive value to prevent divide 0
    """

    X = te.placeholder((c, n, n), name='X')
    Mean = te.placeholder((c, 1, 1), name='Mean')
    Var = te.placeholder((c, 1, 1), name='Var')
    Gamma = te.placeholder((c, 1, 1), name='Gamma')
    Beta = te.placeholder((c, 1, 1), name='Beta')
    C1 = X - Mean
    C2 = topi.sqrt(Var + eps)
    Y = C1 / C2 * Gamma + Beta
    return X, Mean, Var, Gamma, Beta, Y

In [20]:
c = 32
n = 28
X, Mean, Var, Gamma, Beta, Y = batch_norm(c, n)

sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Mean, Var, Gamma, Beta, Y])

print(tvm.lower(sch, [X, Mean, Var, Gamma, Beta], simple_mode=True))

@main = primfn(X_1: handle, Mean_1: handle, Var_1: handle, Gamma_1: handle, Beta_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {X: Buffer(X_2: Pointer(float32), float32, [25088], []),
             Mean: Buffer(Mean_2: Pointer(float32), float32, [32], []),
             Var: Buffer(Var_2: Pointer(float32), float32, [32], []),
             Gamma: Buffer(Gamma_2: Pointer(float32), float32, [32], []),
             Beta: Buffer(Beta_2: Pointer(float32), float32, [32], [])}
  buffer_map = {X_1: X, Mean_1: Mean, Var_1: Var, Gamma_1: Gamma, Beta_1: Beta}
  preflattened_buffer_map = {Var_1: Var_3: Buffer(Var_2, float32, [32, 1, 1], []), Beta_1: Beta_3: Buffer(Beta_2, float32, [32, 1, 1], []), Gamma_1: Gamma_3: Buffer(Gamma_2, float32, [32, 1, 1], []), X_1: X_3: Buffer(X_2, float32, [32, 28, 28], []), Mean_1: Mean_3: Buffer(Mean_2, float32, [32, 1, 1], [])} {
  allocate(T_subtract: Pointer(global float32), float32, [25088]), 

In [21]:
def get_bn_data(c, n, constructor=None):
    """Return the batch norm data, mean, variance, gamma and beta tensors.
       Also return the empty tensor for output.

    c : channels
    n : input width and height
    constructor : user-defined tensor constructor
    """
    np.random.seed(0)
    data = np.random.normal(size=(c, n, n)).astype('float32')
    mean = np.random.normal(size=(c, 1, 1)).astype('float32')
    # move the mean of the normal distribution to be 1
    var = np.random.normal(loc=1.0, size=(c, 1, 1)).astype('float32')
    # make sure all variance numbers are not negative
    var = np.absolute(var)
    gamma = np.random.normal(size=(c, 1, 1)).astype('float32')
    beta = np.random.normal(size=(c, 1, 1)).astype('float32')
    out = np.empty((c, n, n), dtype='float32')
    if constructor:
        data, mean, var, gamma, beta, out = \
        (constructor(x) for x in [data, mean, var, gamma, beta, out])
    return data, mean, var, gamma, beta, out

data, mean, var, gamma, beta, out = get_bn_data(c, n, tvm.nd.array)
mod(data, mean, var, gamma, beta, out)

In [22]:
import mxnet as mx

def get_bn_data_mxnet(c, n, ctx='cpu'):
    ctx = getattr(mx, ctx)()
    data, mean, var, gamma, beta, out = get_bn_data(c, n,
                                      lambda x: mx.nd.array(x, ctx=ctx))
    data, out = data.expand_dims(axis=0), out.expand_dims(axis=0)
    return data, mean, var, gamma, beta, out

def batch_norm_mxnet(data, mean, var, gamma, beta, out, eps=1e-5):
    # use_global_stats=True to use the input mean and var instead of computing
    # the mean and var of the input data.
    # fix_gamma=False so that gamma won't be set to 1.
    mx.nd.BatchNorm(data, gamma, beta, mean, var, eps,
                    use_global_stats=True, fix_gamma=False, out=out)

data, mean, var, gamma, beta, out_mx = get_bn_data_mxnet(c, n)
batch_norm_mxnet(data, mean, var, gamma, beta, out_mx)

In [None]:
np.testing.assert_allclose(out_mx[0].asnumpy(), out.asnumpy(), atol=1e-5)