# BatchNormalization

In [None]:
class BatchNormalization:
    """Batch normalization layer for 2-D inputs of shape ``(N, D)``.

    In training mode each feature column is normalized with the statistics
    of the current mini-batch, and exponential moving averages of those
    statistics are maintained.  In inference mode the stored moving
    averages are used instead.

    Attributes:
        gamma: Scale coefficient applied to the normalized input.
        beta: Offset added after scaling.
        momentum: Decay factor of the moving averages.
        running_mean: Mean used at test (inference) time.
        running_var: Variance used at test (inference) time.
        dgamma: Gradient w.r.t. ``gamma``; set by :meth:`backward`.
        dbeta: Gradient w.r.t. ``beta``; set by :meth:`backward`.
    """

    # Added to the variance before the square root to avoid division by zero.
    _EPS = 10e-7

    def __init__(self, gamma, beta, momentum=0.9, running_mean=None, running_var=None):
        """Store the parameters and optional pre-computed running statistics.

        Args:
            gamma: Scale coefficient.
            beta: Offset.
            momentum: Weight given to the previous running statistic when
                the moving average is updated.
            running_mean: Initial running mean; when ``None`` it is created
                as zeros on the first forward pass.
            running_var: Initial running variance; when ``None`` it is
                created as zeros on the first forward pass.
        """
        self.gamma = gamma
        self.beta = beta
        self.momentum = momentum
        # Kept for interface compatibility; this class never writes to it.
        self.input_shape = None

        self.running_mean = running_mean
        self.running_var = running_var

        # Intermediate data cached by forward() for use in backward().
        self.batch_size = None
        self.xc = None
        self.xn = None
        self.std = None
        self.dgamma = None
        self.dbeta = None

    def forward(self, x, train_flg=True):
        """Normalize ``x`` and apply the affine transform ``gamma * xn + beta``.

        Args:
            x: Input array of shape ``(N, D)``.
            train_flg: When true, use batch statistics, cache the
                intermediates needed by :meth:`backward` and update the
                running statistics.  When false, use the running statistics
                and leave all state untouched.

        Returns:
            Array with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 2-dimensional.
        """
        if x.ndim != 2:
            raise ValueError(
                "expected a 2-D input of shape (N, D), got shape %s" % (x.shape,)
            )

        feature_count = x.shape[1]
        if self.running_mean is None:
            self.running_mean = np.zeros(feature_count)
        if self.running_var is None:
            self.running_var = np.zeros(feature_count)

        if train_flg:
            mu = x.mean(axis=0)  # per-feature mean
            xc = x - mu  # centering
            var = np.mean(xc ** 2, axis=0)  # per-feature (biased) variance
            std = np.sqrt(var + self._EPS)  # scaling factor
            xn = xc / std

            self.batch_size = x.shape[0]
            self.xc = xc
            self.xn = xn
            self.std = std
            # Exponential moving averages of the batch mean and variance.
            self.running_mean = (
                self.momentum * self.running_mean + (1 - self.momentum) * mu
            )
            self.running_var = (
                self.momentum * self.running_var + (1 - self.momentum) * var
            )
        else:
            xc = x - self.running_mean
            xn = xc / np.sqrt(self.running_var + self._EPS)

        return self.gamma * xn + self.beta

    def backward(self, dout):
        """Back-propagate ``dout`` through the last training-mode forward pass.

        Stores the parameter gradients in ``self.dgamma`` and ``self.dbeta``.

        Args:
            dout: Gradient of the loss w.r.t. the layer output, with the same
                shape as the input of the preceding training-mode forward call.

        Returns:
            Gradient of the loss w.r.t. the layer input.
        """
        dbeta = dout.sum(axis=0)
        dgamma = np.sum(self.xn * dout, axis=0)

        # Walk the graph xn = xc / std, std = sqrt(var + eps),
        # var = mean(xc**2), xc = x - mu in reverse order.
        dxn = self.gamma * dout
        dxc = dxn / self.std
        dstd = -np.sum((dxn * self.xc) / (self.std * self.std), axis=0)
        dvar = 0.5 * dstd / self.std
        # Rebind rather than "+=" so the sum is never written into a buffer
        # that could alias another array.
        dxc = dxc + (2.0 / self.batch_size) * self.xc * dvar
        dmu = np.sum(dxc, axis=0)
        dx = dxc - dmu / self.batch_size

        self.dgamma = dgamma
        self.dbeta = dbeta

        return dx

    def forword(self, x, train_flg=True):
        """Original (misspelt) name of :meth:`forward`; kept for existing callers."""
        return self.forward(x, train_flg)

    def backword(self, dout):
        """Original (misspelt) name of :meth:`backward`; kept for existing callers."""
        return self.backward(dout)