# Compute definition

In [2]:
import d2ltvm
import numpy as np
import tvm
from tvm import te

successd...



In [9]:
from tvm import topi
def batch_norm(c, n, eps=1e-5):
    """ batch normalization
        c : channels
        N : input width and height
        eps : small positive value to prevent divide 0
    """
    X = te.placeholder((c, n, n), name='X')
    Mean = te.placeholder((c, 1, 1), name='Mean')
    Var = te.placeholder((c, 1, 1), name='Var')
    Gamma = te.placeholder((c, 1, 1), name='Gamma')
    Beta = te.placeholder((c, 1, 1), name='Beta')
    C1 = X - Mean
    C2 = topi.sqrt(Var + eps)
    Y = C1 / C2 * Gamma + Beta
    return X, Mean, Var, Gamma, Beta, Y

In [10]:
c, n = 28, 28
X, Mean, Var, Gamma, Beta, Y = batch_norm(c, n)
sch = te.create_schedule(Y.op)
mod = tvm.build(sch, [X, Mean, Var, Gamma, Beta, Y])
print(tvm.lower(sch, [X, Mean, Var, Gamma, Beta], simple_mode=True))

@main = primfn(X_1: handle, Mean_1: handle, Var_1: handle, Gamma_1: handle, Beta_1: handle) -> ()
  attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
  buffers = {X: Buffer(X_2: Pointer(float32), float32, [21952], []),
             Mean: Buffer(Mean_2: Pointer(float32), float32, [28], []),
             Var: Buffer(Var_2: Pointer(float32), float32, [28], []),
             Gamma: Buffer(Gamma_2: Pointer(float32), float32, [28], []),
             Beta: Buffer(Beta_2: Pointer(float32), float32, [28], [])}
  buffer_map = {X_1: X, Mean_1: Mean, Var_1: Var, Gamma_1: Gamma, Beta_1: Beta}
  preflattened_buffer_map = {Var_1: Var_3: Buffer(Var_2, float32, [28, 1, 1], []), Gamma_1: Gamma_3: Buffer(Gamma_2, float32, [28, 1, 1], []), Beta_1: Beta_3: Buffer(Beta_2, float32, [28, 1, 1], []), Mean_1: Mean_3: Buffer(Mean_2, float32, [28, 1, 1], []), X_1: X_3: Buffer(X_2, float32, [28, 28, 28], [])} {
  allocate(T_subtract: Pointer(global float32), float32, [21952]), 

In [12]:
def get_bn_data(c, n, constructor=None):
    """ Return the batch norm data, mean, variance, gamma and beta tensors.
        Also return the empty tensor for output.
        c : channels
        n : input width and height
        constructor : user-defined tensor constructor
    """
    np.random.seed(0)
    data = np.random.normal(size=(c, n, c)).astype('float32')
    mean = np.random.normal(size=(c, 1, 1)).astype('float32')
    
    var = np.random.normal(size=(c, 1, 1)).astype('float32')
    var = np.absolute(var)
    
    gamma = np.random.normal(size=(c, 1, 1)).astype('float32')
    beta = np.random.normal(size=(c, 1, 1)).astype('float32')
    out = np.empty(shape=(c, n, n)).astype('float32')
    
    if constructor is not None:
        data, mean, var, gamma, beta, out = \
            [constructor(x) for x in (data, mean, var, gamma, beta, out)]
    
    return data, mean, var, gamma, beta, out

data, mean, var, gamma, beta, out = get_bn_data(c, n, tvm.nd.array)
mod(data, mean, var, gamma, beta, out)

# Torch baseline

In [21]:
import torch
def get_bn_data_torch(c, n, ctx='cpu'):
    device = torch.device(ctx)
    data, mean, var, gamma, beta, _ = get_bn_data(c, n, lambda x: torch.tensor(x, device=device))
    data = data[None, ...]
    mean, var, gamma, beta = [x[:, 0, 0] for x in (mean, var, gamma, beta)]
    return data, mean, var, gamma, beta

def batch_norm_torch(data, mean, var, gamma, beta, eps=1e-5):
    return torch.nn.functional.batch_norm(data, mean, var, gamma, beta, eps=eps)

data, mean, var, gamma, beta = get_bn_data_torch(c, n)
out_torch = batch_norm_torch(data, mean, var, gamma, beta)
np.testing.assert_allclose(out.asnumpy(), out_torch[0].numpy(), atol=1e-5)

# Summary
1.From the computation perspective, batch_norm is a combination of a number of broadcast and element-wise simple operators, which can be easily attained from TVM’s Tensor OPerator Inventory(TOPI).\
2.In inference, mean and var of batch_norm are pre-defined.