In [577]:
import copy
import numpy as np
from cs231n.gradient_check import eval_numerical_gradient, eval_numerical_gradient_array

In [578]:
def print_mean_std(x,axis=0):
    """Print the mean and the standard deviation of ``x`` along ``axis``.

    Writes one "means" line, one "stds" line and then an empty line to
    stdout. Nothing is returned.
    """
    for label, reducer in (('  means: ', x.mean), ('  stds:  ', x.std)):
        print(label, reducer(axis=axis))
    print()
def rel_error(x, y):
    """Return the largest element-wise relative error between ``x`` and ``y``.

    Every absolute difference |x - y| is divided by |x| + |y|, with that
    divisor floored at 1e-8 so entries where both inputs are zero do not
    divide by zero.
    """
    abs_diff = np.abs(x - y)
    scale = np.maximum(1e-8, np.abs(x) + np.abs(y))
    return np.max(abs_diff / scale)

In [575]:
def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Batch-normalization forward pass, written as a chain of small graph nodes.

    Every column of ``x`` is centered by its mean over axis 0, divided by
    ``sqrt(variance + eps)``, multiplied by ``gamma`` and shifted by ``beta``.
    Each intermediate value is kept in the returned cache so that the
    backward pass can walk the same chain in reverse.

    Args:
        x: Input array; statistics are taken over axis 0 (rows are samples,
            columns are features).
        gamma: Per-feature scale, broadcast against the normalized input.
        beta: Per-feature shift, broadcast against the scaled input.
        eps: Constant added to the variance before the square root so the
            division stays finite for constant features. Defaults to 1e-5,
            the value that used to be hard-coded here.

    Returns:
        Tuple ``(out, cache)``: ``out`` is the normalized, scaled and shifted
        array; ``cache`` is a dict holding the input, every intermediate node,
        ``gamma`` and ``beta``.

    Note:
        The variance is formed as mean(x**2) - mean(x)**2. That one-pass form
        can lose precision when the mean is large compared with the spread of
        the data.
    """
    # x - N, M
    x_avg = np.mean(x, axis=0)
    x_sqwrd = x**2
    x_avg_sqwrd = x_avg**2
    x_sqwrd_avg = np.mean(x_sqwrd, axis=0)
    # Variance as E[x^2] - E[x]^2
    v = x_sqwrd_avg - x_avg_sqwrd
    v_shifted = v + eps
    denom = np.sqrt(v_shifted)
    x_cntrd = x - x_avg
    x_stndrd = x_cntrd/denom
    x_strchd = x_stndrd * gamma
    x_rstrd = x_strchd + beta
    cache = {"x": x,
            "x_avg":x_avg,
            "x_sqwrd":x_sqwrd,
            "x_avg_sqwrd":x_avg_sqwrd,
            "x_sqwrd_avg":x_sqwrd_avg,
            "v":v,
            "v_shifted":v_shifted,
            "denom":denom,
            "x_cntrd":x_cntrd,
            "x_stndrd":x_stndrd,
            "x_strchd":x_strchd,
            "x_rstrd":x_rstrd,
             "gamma":gamma,
             "beta":beta
            }
    return x_rstrd, cache

def batchnorm_backward(dout, cache):
    """Backward pass for ``batchnorm_forward``, reversing its steps one by one.

    Args:
        dout: Upstream gradient with respect to the forward output.
        cache: Dict of intermediates produced by ``batchnorm_forward``.

    Returns:
        Tuple ``(dx, dgamma, dbeta)`` of gradients with respect to the input,
        the scale and the shift.
    """
    # x - N, M
    x = cache["x"]
    mean = cache["x_avg"]
    centered = cache["x_cntrd"]
    std = cache["denom"]

    # out = x_stndrd * gamma + beta
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dout * cache["x_stndrd"], axis=0)
    g_norm = dout * cache["gamma"]

    # x_stndrd = x_cntrd / denom: one branch per operand
    g_centered = g_norm * (1/std)
    g_std = np.sum(g_norm * (-1*centered*std**(-2)), axis=0)

    # denom = sqrt(v_shifted); adding the constant passes the gradient through
    g_var = g_std * (0.5*(cache["v_shifted"]**(-0.5)))

    # v = x_sqwrd_avg - x_avg_sqwrd, and x_avg also feeds the centering step
    g_mean_sq = -1 * g_var
    g_mean = g_mean_sq * 2 * mean + np.sum(g_centered * (-1), axis = 0)

    # A mean over axis 0 hands every row an equal 1/N share of the gradient
    sq_shape = cache["x_sqwrd"].shape
    g_sq = g_var * np.ones(sq_shape)/sq_shape[0]

    # x reaches the output through x**2, through the mean and through centering
    via_square = g_sq * 2 * x
    via_mean = g_mean * np.ones(x.shape)/x.shape[0]
    dx = via_square + via_mean + g_centered
    return dx, dgamma, dbeta

In [570]:
def batchnorm_forward_1(x, gamma, beta, replacements=None, eps=1e-5):
    """Batch-normalization forward pass whose graph nodes can be overridden.

    The computation is the same chain of nodes as ``batchnorm_forward``. Any
    node whose name is a key of ``replacements`` takes the supplied value
    instead of the one computed from its inputs; everything downstream is then
    built from the supplied value. This allows a single intermediate node to
    be pinned and perturbed when checking gradients numerically.

    Args:
        x: Input array; statistics are taken over axis 0. Ignored when
            ``replacements`` contains ``"x"``.
        gamma: Per-feature scale. Ignored when ``replacements`` contains
            ``"gamma"``.
        beta: Per-feature shift. Ignored when ``replacements`` contains
            ``"beta"``.
        replacements: Optional dict mapping node names (the cache keys) to the
            values to force. ``None`` means no node is overridden.
        eps: Constant added to the variance before the square root.
            Defaults to 1e-5.

    Returns:
        Tuple ``(out, cache)``: ``out`` is the value of the final node and
        ``cache`` maps every node name to the value actually used.
    """
    # x - N, M
    # None instead of a shared mutable {} default
    if replacements is None:
        replacements = {}

    def pick(name, computed):
        # A forced value wins over the one derived from upstream nodes.
        return replacements.get(name, computed)

    x = pick("x", x)
    gamma = pick("gamma", gamma)
    beta = pick("beta", beta)

    x_avg = pick("x_avg", np.mean(x, axis=0))
    x_sqwrd = pick("x_sqwrd", x**2)
    x_avg_sqwrd = pick("x_avg_sqwrd", x_avg**2)
    x_sqwrd_avg = pick("x_sqwrd_avg", np.mean(x_sqwrd, axis=0))
    v = pick("v", x_sqwrd_avg - x_avg_sqwrd)
    v_shifted = pick("v_shifted", v + eps)
    denom = pick("denom", np.sqrt(v_shifted))
    x_cntrd = pick("x_cntrd", x - x_avg)
    x_stndrd = pick("x_stndrd", x_cntrd/denom)
    x_strchd = pick("x_strchd", x_stndrd * gamma)
    x_rstrd = pick("x_rstrd", x_strchd + beta)

    # Every local already holds the forced value where one was given.
    cache = {"x": x,
             "x_avg": x_avg,
             "x_sqwrd": x_sqwrd,
             "x_avg_sqwrd": x_avg_sqwrd,
             "x_sqwrd_avg": x_sqwrd_avg,
             "v": v,
             "v_shifted": v_shifted,
             "denom": denom,
             "x_cntrd": x_cntrd,
             "x_stndrd": x_stndrd,
             "x_strchd": x_strchd,
             "x_rstrd": x_rstrd,
             "gamma": gamma,
             "beta": beta
             }
    return x_rstrd, cache

def batchnorm_backward_1(dout, cache):
    """Backward pass for ``batchnorm_forward_1`` that reports every node's gradient.

    Args:
        dout: Upstream gradient with respect to the forward output.
        cache: Dict of node values produced by ``batchnorm_forward_1``.

    Returns:
        Dict mapping ``"d" + node_name`` to the gradient of that node, for the
        input, every intermediate node, ``gamma`` and ``beta``.
    """
    # x - N, M
    x = cache["x"]
    std = cache["denom"]
    rows = x.shape[0]
    sq_shape = cache["x_sqwrd"].shape

    # Final affine step: x_rstrd = x_stndrd * gamma + beta
    dx_strchd = dout
    dbeta = np.sum(dout, axis=0)
    dgamma = np.sum(dx_strchd * cache["x_stndrd"], axis=0)
    dx_stndrd = dx_strchd * cache["gamma"]

    # Division step: x_stndrd = x_cntrd / denom
    dx_cntrd = dx_stndrd * (1/std)
    ddenom = np.sum(dx_stndrd * (-1*cache["x_cntrd"]*std**(-2)), axis=0)

    # Square root, then the additive shift, which passes the gradient unchanged
    dv_shifted = ddenom * (0.5*(cache["v_shifted"]**(-0.5)))
    dv = dv_shifted

    # v = x_sqwrd_avg - x_avg_sqwrd
    dx_sqwrd_avg = dv
    dx_avg_sqwrd = -1 * dv

    # x_avg receives gradient from its square and from the centering step
    dx_avg = dx_avg_sqwrd * 2 * cache["x_avg"] + np.sum(dx_cntrd * (-1), axis = 0)

    # Undo the mean over axis 0: each row gets an equal share
    dx_sqwrd = dx_sqwrd_avg * np.ones(sq_shape)/sq_shape[0]

    # Collect the three routes from x to the output
    through_square = dx_sqwrd * 2 * x
    through_mean = dx_avg * np.ones(x.shape)/rows
    dx = through_square + through_mean + dx_cntrd

    return {
        "dx": dx,
        "dx_avg": dx_avg,
        "dx_sqwrd": dx_sqwrd,
        "dx_avg_sqwrd": dx_avg_sqwrd,
        "dx_sqwrd_avg": dx_sqwrd_avg,
        "dv": dv,
        "dv_shifted": dv_shifted,
        "ddenom": ddenom,
        "dx_cntrd": dx_cntrd,
        "dx_stndrd": dx_stndrd,
        "dx_strchd": dx_strchd,
        "dgamma": dgamma,
        "dbeta": dbeta,
    }



In [571]:
# check any
# Numerically check the analytic gradient of one node of the batchnorm graph.
# The node named by `inspected` is pinned through `replacements`, nudged one
# entry at a time by `delta`, and the resulting change of the scalar
# s = sum(out * w) is compared with the matching entry returned by
# batchnorm_backward_1.
np.random.seed(231)  # fixed seed so the reported error is repeatable
N = 10
D = 5
delta = 1e-7
replacements = {}

inspected = "x"
replacements[inspected] = (np.random.rand(N,D) + 1)/100

x_dummy = np.random.rand(N,D)
gamma = np.ones(D) + 0.1
beta = np.zeros(D) + 0.1

out, cache = batchnorm_forward_1(x_dummy, gamma, beta, replacements)
w = np.random.rand(N,D)
dout = w
s = np.sum(out*w)
grad = batchnorm_backward_1(dout, cache=cache)

# One loop covers a node of any rank (per-feature vectors and full matrices
# alike) instead of a separate copy of the code per rank.
inspected_value = replacements[inspected]
print(inspected_value.ndim)
grad_manual = np.zeros(inspected_value.shape)
for idx in np.ndindex(*inspected_value.shape):
    # Fresh copy each time so only one entry differs from the baseline
    replacements_mod = copy.deepcopy(replacements)
    replacements_mod[inspected][idx] = replacements_mod[inspected][idx] + delta
    out_mod, cache_mod = batchnorm_forward_1(x_dummy, gamma, beta, replacements_mod)
    s_mod = np.sum(out_mod*w)
    # Forward difference: first-order accurate in delta
    grad_manual[idx] = (s_mod - s)/delta
print(rel_error(grad_manual,grad["d"+inspected]))

2
4.8611279461042055e-05


In [576]:
# Sanity-check the training-time forward pass: compare the per-feature means
# and standard deviations before and after batch normalization.

# Produce activations with a small two-layer network
np.random.seed(231)
N, D1, D2, D3 = 200, 50, 60, 3
X = np.random.randn(N, D1)
W1 = np.random.randn(D1, D2)
W2 = np.random.randn(D2, D3)
hidden = np.maximum(0, X.dot(W1))
a = hidden.dot(W2)

print('Before batch normalization:')
print_mean_std(a, axis=0)

# Identity scale and shift: means should land near zero, stds near one
gamma = np.ones(D3)
beta = np.zeros(D3)
print('After batch normalization (gamma=1, beta=0)')
a_norm, _ = batchnorm_forward(a, gamma, beta)
print_mean_std(a_norm, axis=0)

# Non-trivial scale and shift: means should land near beta, stds near gamma
gamma = np.array([1.0, 2.0, 3.0])
beta = np.array([11.0, 12.0, 13.0])
print('After batch normalization (gamma=', gamma, ', beta=', beta, ')')
a_norm, _ = batchnorm_forward(a, gamma, beta)
print_mean_std(a_norm, axis=0)

Before batch normalization:
  means:  [ -2.3814598  -13.18038246   1.91780462]
  stds:   [27.18502186 34.21455511 37.68611762]

After batch normalization (gamma=1, beta=0)
  means:  [4.66293670e-17 3.55271368e-17 1.85962357e-17]
  stds:   [0.99999999 1.         1.        ]

After batch normalization (gamma= [1. 2. 3.] , beta= [11. 12. 13.] )
  means:  [11. 12. 13.]
  stds:   [0.99999999 1.99999999 2.99999999]



In [552]:
# Compare the analytic batchnorm gradients with numerically estimated ones
np.random.seed(231)
N, D = 4, 5
x = 5 * np.random.randn(N, D) + 12
gamma = np.random.randn(D)
beta = np.random.randn(D)
dout = np.random.randn(N, D)

bn_param = {'mode': 'train'}

# Forward pass as a function of one argument at a time, the others held fixed
def fx(x):
    return batchnorm_forward_1(x, gamma, beta)[0]

def fg(a):
    return batchnorm_forward_1(x, a, beta)[0]

def fb(b):
    return batchnorm_forward_1(x, gamma, b)[0]

dx_num = eval_numerical_gradient_array(fx, x, dout)
da_num = eval_numerical_gradient_array(fg, gamma.copy(), dout)
db_num = eval_numerical_gradient_array(fb, beta.copy(), dout)

_, cache = batchnorm_forward_1(x, gamma, beta)
grad_for_check = batchnorm_backward_1(dout, cache)
dx, dgamma, dbeta = (grad_for_check[key] for key in ("dx", "dgamma", "dbeta"))

# Relative errors are expected to fall between 1e-13 and 1e-8
for label, numeric, analytic in (
    ('dx error: ', dx_num, dx),
    ('dgamma error: ', da_num, dgamma),
    ('dbeta error: ', db_num, dbeta),
):
    print(label, rel_error(numeric, analytic))

dx error:  6.6664407728161185e-09
dgamma error:  2.2067811910924764e-12
dbeta error:  2.6853898000928757e-12
