In [7]:
import numpy as np

In [8]:
# Copied from https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/common/functions.py

def softmax(x):
    """Return the softmax of ``x`` computed along its last axis.

    The per-row maximum is subtracted before exponentiating so that
    ``np.exp`` is never applied to a large positive number (overflow guard).
    The output has the same shape as ``x``.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)


def cross_entropy_error(y, t):
    """Return the mean cross-entropy of predictions ``y`` against targets ``t``.

    ``y`` is either one sample (1-D) or a batch (first axis = samples).
    ``t`` is either one-hot encoded (same number of elements as ``y``) or an
    array of class indices. A small constant (1e-7) is added inside the
    logarithm so that a predicted probability of 0 does not yield ``-inf``.
    """
    # A single sample is promoted to a batch of one.
    if y.ndim == 1:
        y = y.reshape(1, y.size)
        t = t.reshape(1, t.size)

    # Equal element counts mean the targets are one-hot vectors:
    # convert them to the index of the correct class.
    if t.size == y.size:
        t = t.argmax(axis=1)

    n_samples = y.shape[0]
    sample_rows = np.arange(n_samples)
    # Probability that each sample assigned to its correct class.
    picked = y[sample_rows, t]
    log_picked = np.log(picked + 1e-7)
    return -np.sum(log_picked) / n_samples

In [9]:
# Copied from https://github.com/oreilly-japan/deep-learning-from-scratch/blob/master/common/gradient.py

def numerical_gradient(f, x):
    """Estimate the gradient of ``f`` at ``x`` using central differences.

    Every element of ``x`` is perturbed in place by ``+h`` and then ``-h``;
    ``f(x)`` is evaluated after each perturbation and the slope
    ``(f(x+h) - f(x-h)) / (2*h)`` is stored at the same index of the result.

    Args:
        f: Callable invoked as ``f(x)``. It receives the same array object
            that was passed in, so it may also observe the perturbation
            through any other reference to that array.
        x: NumPy array. It is modified in place while the gradient is being
            computed and each element is put back to its original value
            afterwards, including when ``f`` raises.

    Returns:
        An array created with ``np.zeros_like(x)`` holding the estimated
        partial derivative for every element of ``x``. An array with no
        elements yields an empty gradient.

    Raises:
        Whatever ``f`` raises; ``x`` is restored before the exception
        propagates.
    """
    h = 1e-4  # 0.0001
    grad = np.zeros_like(x)

    for idx in np.ndindex(*x.shape):
        original = x[idx]
        try:
            x[idx] = original + h
            f_plus = f(x)   # f(x+h)

            x[idx] = original - h
            f_minus = f(x)  # f(x-h)
        finally:
            # Restore the element even if f failed, so the caller's array
            # (often live model weights) is never left perturbed.
            x[idx] = original

        grad[idx] = (f_plus - f_minus) / (2 * h)

    return grad

In [10]:
class simpleNet:
    """Minimal single-layer network with one 2x3 weight matrix and no bias."""

    def __init__(self):
        # Weights are drawn from a standard normal distribution, shape (2, 3).
        self.W = np.random.randn(2, 3)

    def predict(self, x):
        """Return the raw scores ``np.dot(x, W)`` for input ``x``."""
        scores = np.dot(x, self.W)
        return scores

    def loss(self, x, t):
        """Return the cross-entropy loss of the softmax prediction for ``x``.

        The scores from ``predict`` are passed through ``softmax`` and then
        compared with the target ``t`` by ``cross_entropy_error``.
        """
        probabilities = softmax(self.predict(x))
        return cross_entropy_error(probabilities, t)

In [11]:
# Fixed input sample and its one-hot target (index 2 is the correct class).
x = np.array([0.6, 0.9])
t = np.array([0, 0, 1])

# Build the network and show its randomly initialised weights.
net = simpleNet()
print(net.W)

# Raw scores for x, followed by the index of the largest score.
p = net.predict(x)
print(p)
print(p.argmax())

# Cross-entropy loss of the prediction against t.
print(net.loss(x, t))

[[-2.32232647 -0.58241634  1.6354068 ]
 [ 1.54783475  2.81513125  0.68314506]]
[-3.44607467e-04  2.18416833e+00  1.59607463e+00]
1
1.0996692604175307


In [12]:
def f(W):
    """Return the loss of ``net`` for the fixed input ``x`` and target ``t``.

    The argument ``W`` is not used directly. ``net.loss`` calls
    ``net.predict``, which reads ``self.W``; because ``numerical_gradient``
    is handed ``net.W`` itself and perturbs it in place, the value returned
    here still varies with the weights, so the result is the gradient of the
    loss with respect to ``net.W``.
    """
    current_loss = net.loss(x, t)
    return current_loss


# Gradient of the loss with respect to each weight in net.W.
dW = numerical_gradient(f, net.W)
print(dW)

[[ 0.04048132  0.35972991 -0.40021122]
 [ 0.06072198  0.53959486 -0.60031683]]
