In [1]:
import numpy as np
import sys

sys.path.append('../')

from mandala.nodecore import Node
from mandala.nodecore import Variable

from mandala.autodiff import basic_math
from mandala.autodiff import autodiff

In [2]:
def linear_forward(x, W):
    """Forward pass of a bias-free linear map: ``x`` times ``W`` transposed.

    Args:
        x: Input array; in this notebook a (batch, in_ch) array.
        W: Weight array exposing ``.T``; in this notebook (out_ch, in_ch).

    Returns:
        ``np.matmul(x, W.T)``.
    """
    return np.matmul(x, W.transpose())


def linear_backward_W(x, gy):
    """Gradient of ``y = matmul(x, W.T)`` with respect to the weight ``W``.

    Every axis except the last one is treated as a sample axis and flattened
    before the product, so the result is always 2-D. For plain 2-D inputs this
    is exactly ``matmul(gy.T, x)``; 1-D and N-D inputs (which the forward pass
    already accepts through matmul broadcasting) are handled as well.

    Args:
        x: Forward input with shape (..., in_ch).
        gy: Upstream gradient with shape (..., out_ch); its leading axes must
            hold the same number of samples as those of ``x``.

    Returns:
        Array of shape (out_ch, in_ch): the per-sample outer products of
        ``gy`` and ``x`` summed over all samples.
    """
    x_shape = np.shape(x)
    gy_shape = np.shape(gy)
    # Row counts are computed explicitly (instead of reshaping with -1) so that
    # zero-width feature axes still reshape without ambiguity.
    x_rows = np.reshape(x, (int(np.prod(x_shape[:-1])), x_shape[-1]))
    gy_rows = np.reshape(gy, (int(np.prod(gy_shape[:-1])), gy_shape[-1]))
    gW = np.matmul(gy_rows.T, x_rows)
    return gW


def linear_backward_x(W, gy):
    """Gradient of ``y = matmul(x, W.T)`` with respect to the input ``x``.

    Args:
        W: Weight array; in this notebook (out_ch, in_ch).
        gy: Upstream gradient; in this notebook (batch, out_ch).

    Returns:
        ``np.matmul(gy, W)``.
    """
    return np.matmul(gy, W)


class LinearFunction(autodiff.AutoDiff):
    """Bias-free linear map expressed as an ``AutoDiff`` function."""

    def forward(self, xs):
        """Wrap ``linear_forward`` in a ``Node``.

        Args:
            xs: Two-element sequence ``(x, W)``.

        Returns:
            The ``Node`` built from ``linear_forward`` and ``[x, W]``.
        """
        inp, weight = xs
        return Node(linear_forward, [inp, weight])

    def backward(self, xs, gy):
        """Wrap both gradient computations in ``Node`` objects.

        Args:
            xs: Two-element sequence ``(x, W)`` given to ``forward``.
            gy: Gradient with respect to the forward output.

        Returns:
            Tuple ``(gx, gW)`` in the same order as the inputs ``(x, W)``.
        """
        inp, weight = xs
        grad_weight = Node(linear_backward_W, [inp, gy])
        grad_input = Node(linear_backward_x, [weight, gy])
        return grad_input, grad_weight


class Linear(object):
    """Linear layer without bias, holding one weight ``Variable`` named ``W``."""

    def __init__(self, in_ch, out_ch):
        """Create the weight from zero-mean Gaussian samples cast to float32.

        Args:
            in_ch: Number of input features (columns of ``W``).
            out_ch: Number of output features (rows of ``W``); the standard
                deviation of the initial values is ``1 / out_ch``.
        """
        weight_shape = (out_ch, in_ch)
        initial = np.random.normal(loc=0, scale=1 / out_ch, size=weight_shape)
        self.W = Variable(initial.astype(np.float32))

    def __call__(self, x):
        """Apply the layer to ``x`` through a new ``LinearFunction``."""
        func = LinearFunction()
        return func([x, self.W])

In [3]:
def sum_forward(x):
    """Return the sum over all elements of ``x``."""
    total = np.sum(x, axis=None)
    return total


def sum_backward(x, gy):
    """Spread the upstream gradient ``gy`` over an array shaped like ``x``.

    Args:
        x: Input that was summed in the forward pass.
        gy: Gradient with respect to the sum.

    Returns:
        ``np.ones_like(x)`` multiplied by ``gy``.
    """
    ones = np.ones_like(x)
    return ones * gy


class SumFunction(autodiff.AutoDiff):
    """Full reduction by summation expressed as an ``AutoDiff`` function."""

    def forward(self, xs):
        """Wrap ``sum_forward`` applied to ``xs[0]`` in a ``Node``."""
        return Node(sum_forward, [xs[0]])

    def backward(self, xs, gy):
        """Wrap ``sum_backward`` in a ``Node`` and return it as a 1-tuple."""
        grad_input = Node(sum_backward, [xs[0], gy])
        return (grad_input,)


def _sum(x):
    """Sum all elements of ``x`` by calling a new ``SumFunction`` on ``[x]``."""
    reducer = SumFunction()
    return reducer([x])

In [4]:
# Three stacked linear layers: 5 -> 10 -> 10 -> 3 features.
# The generator builds them in order, so weights are drawn l0, l1, l2.
l0, l1, l2 = (
    Linear(n_in, n_out)
    for n_in, n_out in ((5, 10), (10, 10), (10, 3))
)

In [5]:
# True coefficients: a fixed (3, 5) float32 matrix holding 0..14 row by row.
W = np.arange(15, dtype=np.float32).reshape((3, 5))
# Disabled alternative: W = np.zeros((3, 5), dtype=np.float32)

In [6]:
batchsize = 32
# Random float32 inputs with 5 features per sample.
x = Variable(np.random.random(size=(batchsize, 5)).astype('float32'))
# Targets are the inputs mapped through the true coefficients W.
t = Variable(np.matmul(x.data, W.transpose()))

In [7]:
# Forward pass through the three linear layers.
h0 = l0(x)
h1 = l1(h0)
y  = l2(h1)
# Reduce the squared error to a scalar (summed, then divided by the batch
# size), the same loss the training loop uses, so that backward() starts
# from a single value instead of an unreduced (batch, 3) array.
loss = _sum((y - t) ** 2) / batchsize

In [8]:
# Backpropagate from the loss; the next cell reads the gradient left on l0.W.
loss.backward()

In [9]:
# Display the gradient array stored on the first layer's weight.
l0.W.grad.data

array([[-1.3350790e+00, -1.3097345e+00, -1.3445444e+00, -1.3944225e+00,
        -1.3555093e+00],
       [-1.2213749e+00, -1.1940767e+00, -1.2292883e+00, -1.2965100e+00,
        -1.2710457e+00],
       [ 1.7635702e+00,  1.7353930e+00,  1.7758858e+00,  1.8072665e+00,
         1.7380435e+00],
       [-2.4629545e+00, -2.4163306e+00, -2.4800611e+00, -2.5691853e+00,
        -2.4955053e+00],
       [-2.5102615e+00, -2.4677534e+00, -2.5285373e+00, -2.5926657e+00,
        -2.5049376e+00],
       [ 1.0786615e-02, -4.2770896e-03,  9.3357731e-03,  9.4534099e-02,
         1.3567208e-01],
       [-3.4154446e+00, -3.3447127e+00, -3.4382198e+00, -3.5946226e+00,
        -3.5080700e+00],
       [-7.3155270e+00, -7.1698360e+00, -7.3649578e+00, -7.6671047e+00,
        -7.4657254e+00],
       [ 9.9035621e-01,  9.7345650e-01,  9.9747097e-01,  1.0230319e+00,
         9.8844206e-01],
       [-5.0363865e+00, -4.9314995e+00, -5.0697474e+00, -5.3027792e+00,
        -5.1761112e+00]], dtype=float32)

In [10]:
lr = 1e-3
layers = (l0, l1, l2)

for i in range(100):
    # Draw a new random batch and build its targets from the true W.
    x = Variable(np.random.random(size=(batchsize, 5)).astype('float32'))
    t = Variable(np.matmul(x.data, W.transpose()))

    # Forward pass; intermediates stay named so they remain referenced.
    h0 = l0(x)
    h1 = l1(h0)
    y = l2(h1)

    # Scalar loss: summed squared error divided by the batch size.
    loss = _sum((y - t) ** 2) / batchsize

    # Reset every weight's stored gradient before backpropagating.
    for layer in layers:
        layer.W.grad = 0.

    loss.backward()

    # In-place gradient-descent step on each weight.
    for layer in layers:
        layer.W.data -= lr * layer.W.grad.data

    print(loss.data)

1289.57861328125
1203.0684814453125
1361.8775634765625
1216.796630859375
1148.502197265625
1198.65869140625
1184.680419921875
1080.9808349609375
1425.421630859375
1340.38916015625
1514.458251953125
1162.359130859375
1312.409912109375
1474.2393798828125
1161.37646484375
1255.52294921875
1263.897216796875
1322.7415771484375
1466.624267578125
1048.9349365234375
1373.959228515625
1160.9873046875
1380.9730224609375
1149.5445556640625
1227.605712890625
1141.93115234375
1136.90771484375
1103.2921142578125
1239.6630859375
885.6874389648438
830.4423217773438
651.9061279296875
448.90191650390625
274.09765625
185.13731384277344
50.04864501953125
17.464618682861328
9.264480590820312
4.739041328430176
3.9357776641845703
3.3026888370513916
1.9444721937179565
1.9236879348754883
1.513634204864502
2.148855686187744
1.191824197769165
1.459651231765747
1.2925808429718018
1.6675642728805542
1.4637925624847412
1.3674167394638062
1.693906307220459
1.6847611665725708
1.2400810718536377
1.2170233726501465
1.1

In [11]:
# Compare prediction and target for the first sample of the last batch.
y.data[0], t.data[0]

(array([ 5.104726, 17.702671, 30.313293], dtype=float32),
 array([ 5.27361 , 17.882418, 30.491226], dtype=float32))