In [1]:
import numpy as np

In [2]:
def softmax_forward(x):
    exps = np.exp(x)
    sexps = np.sum(exps, axis=1).reshape(-1,1)
    divexps = (1/sexps).reshape(-1,1)
    mul = exps * divexps
    return mul, {"exps":exps, "sexps":sexps, "divexps":divexps}

def softmax_backward(dout, cache, print_=False):
    softmax_grad = {}
    
    dexps_0 = cache["divexps"] * dout
    softmax_grad["dexps_0"] = dexps_0

    ddivexps = np.sum(cache["exps"] * dout, axis=1).reshape(-1,1)
    softmax_grad["ddivexps"] = ddivexps
    
    dsexps = -1.0/(cache["sexps"]**2) * ddivexps
    softmax_grad["dsexps"] = dsexps
    
    dexps_1 = dsexps/(np.sum(cache["exps"], axis=1).reshape(-1,1))*cache["exps"]
    softmax_grad["dexps_1"] = dexps_1
    
    dx = cache["exps"] * (dexps_0 + dexps_1)
    softmax_grad["dx"] = dx
    
    if print_:
        print("dexps_0")
        print(dexps_0.shape)
        print(dexps_0)
        print("ddivexps")
        print(ddivexps.shape)
        print(ddivexps)
        print("dsexps")
        print(dsexps.shape)
        print(dsexps)
        print("dexps_1")
        print(dexps_1.shape)
        print(dexps_1)
        print("dx")
        print(dx.shape)
        print(dx)

    return softmax_grad

In [3]:
a = np.asarray([[1., 2., 3.], [4., 5., 5.]])
N, D = a.shape[0], a.shape[1]
delta = 1e-7

In [18]:
softmax, cache = softmax_forward(a)
softmax_grad = softmax_backward(dout = np.ones(a.shape), cache=cache)

# softmax_mod = {}
# cache_mod = {}
grad_manual = np.zeros((N,D))
for i in range(N):
    for j in range(D):
        a_mod = a.copy()
        a_mod[i, j] = a_mod[i, j] + delta
        softmax_mod, cache_mod = softmax_forward(a_mod)
        grad_manual[i,j] = (softmax_mod[i,j] - softmax[i,j])/delta

In [19]:
softmax_grad

{'dexps_0': array([[0.0331204 , 0.0331204 , 0.0331204 ],
        [0.00284556, 0.00284556, 0.00284556]]),
 'ddivexps': array([[ 30.19287485],
        [351.42446824]]),
 'dsexps': array([[-0.0331204 ],
        [-0.00284556]]),
 'dexps_1': array([[-0.00298185, -0.0081055 , -0.02203304],
        [-0.00044209, -0.00120173, -0.00120173]]),
 'dx': array([[0.08192507, 0.18483645, 0.22269543],
        [0.13122493, 0.24396563, 0.24396563]])}

In [6]:
print("dexps_0 expected to be:")
print(grad_manual/cache["exps"] - softmax_grad["dexps_1"])
print("dexps_1 expected to be:")
dexps_1_exp = (grad_manual/cache["exps"] - softmax_grad["dexps_0"])
print(dexps_1_exp)

dexps_0 expected to be:
[[0.0331204  0.0331204  0.0331204 ]
 [0.00284556 0.00284556 0.00284556]]
dexps_1 expected to be:
[[-0.00298185 -0.0081055  -0.02203304]
 [-0.00044209 -0.00120173 -0.00120173]]


In [17]:
x_copy = a.copy()
x_copy[0,0] += 1e-7
(softmax_forward(x_copy)[0] - softmax_forward(a)[0])/1e-7

array([[ 0.08192507, -0.02203305, -0.05989203],
       [ 0.        ,  0.        ,  0.        ]])

In [15]:
softmax_forward(x_copy)[0]

array([[0.09003058, 0.24472847, 0.66524095],
       [0.1553624 , 0.4223188 , 0.4223188 ]])