In [1]:
import numpy as np

from __future__ import annotations

In [37]:
class BatchNorm1d:
    def __init__(self, dim, eps=1e-5, momentum=0.01):
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        self.training = True

        # scale and shift
        self.alpha = np.ones(dim)
        self.beta = np.zeros(dim)

        # buffers
        self.running_m = np.zeros(dim)
        self.running_v = np.ones(dim)

    def __call__(self, X):
        if self.training:
            xmean = X.mean(0, keepdims=True)
            xvar = X.var(0, keepdims=True)

        else:
            xmean = self.running_m
            xvar = self.running_v

        xhat =  (X - xmean) / np.sqrt(xvar * self.eps) 
        self.out = self.alpha * xhat + self.beta

        if self.training:
            self.running_m = (1 - self.momentum) * self.running_m + xmean * self.momentum
            self.running_v = (1 - self.momentum) * self.running_v + xmean * self.momentum
    
        return self.out

    def parameters(self):
        return [self.alpha, self.beta]

In [54]:
dim = 10
X = np.random.rand(10, dim)

In [55]:
batchnorm = BatchNorm1d(dim)

In [58]:
batchnorm(X) # backprop trains alpha and beta to scale and shift the distribution.

(1, 10)


array([[ -73.94154373, -477.09240825, -633.04240822, -562.78663677,
        -168.91455215,  314.90142252, -220.25589394,  449.79367929,
         216.07911272, -312.06563321],
       [ 292.33120788,  -80.80236981,  150.76935706,  534.16450269,
         315.72279614, -517.36361463, -270.83692471,  -94.70184021,
         416.74598972,  133.28015306],
       [  90.80324785,  331.24602773,   29.66031185,  181.87270948,
        -291.9782515 ,  211.30144452, -145.80574459, -291.59967123,
         354.50403873,  510.36388209],
       [-448.89481368, -352.9054246 ,  -43.66418039, -261.65853319,
         197.48382546, -119.82594463,  457.62234955,  221.61740045,
          21.69036161,   86.59271904],
       [  67.75167971,  -58.00247915, -242.13617595, -143.62208535,
        -203.23515387, -140.47823494, -478.79486499, -457.02715749,
        -148.65778578, -318.55600237],
       [-490.88058651,  490.48234639,  -95.53701694, -207.56404908,
         356.81486901,  107.12960518,  276.72627218,   99

In [60]:
batchnorm.parameters() # currently 1's and 0's and no grad is used.

[array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])]