In [1]:
import numpy as np
import sys

sys.path.append('../')

from mandala.nodecore import Node
from mandala.nodecore import Variable

from mandala import basic_math
from mandala.autodiff import autodiff

# Run mandala's own installers before defining the AutoDiff-backed operators
# in this notebook. NOTE(review): presumably these patch operator overloads and
# a backward() method onto Node -- their bodies are not visible here; confirm
# against mandala.basic_math / mandala.autodiff.
basic_math.install_node_arithmetics()
autodiff.install_node_backward()

## AutoDiff supporters
以下の関数は、backward における Node の生成をサポートする。Node を受け取り Node を返す点に注意。

In [50]:
class Add(autodiff.AutoDiff):
    """Addition with a backward rule.

    ``forward`` delegates to ``basic_math.add``; ``backward`` hands the
    upstream gradient to both operands unchanged.
    """

    def forward(self, xs):
        """Return ``basic_math.add`` applied to the operands in ``xs``."""
        return basic_math.add(*xs)

    def backward(self, xs, gy):
        """Return ``(gy, gy)``: d(x0 + x1)/dx0 = d(x0 + x1)/dx1 = 1."""
        grad_lhs = gy
        grad_rhs = gy
        return grad_lhs, grad_rhs

def add(lhs, rhs):
    """Build the differentiable sum ``lhs + rhs`` via a fresh ``Add`` op."""
    operation = Add()
    return operation([lhs, rhs])


class Mul(autodiff.AutoDiff):
    """Multiplication with a backward rule.

    ``forward`` delegates to ``basic_math.mul``; ``backward`` applies the
    product rule, pairing each operand's gradient with the *other* operand.
    """

    def forward(self, xs):
        """Return ``basic_math.mul`` applied to the operands in ``xs``."""
        return basic_math.mul(*xs)

    def backward(self, xs, gy):
        """Return ``(x1 * gy, x0 * gy)`` built with ``basic_math.mul``."""
        first, second = xs[0], xs[1]
        grad_first = basic_math.mul(second, gy)
        grad_second = basic_math.mul(first, gy)
        return grad_first, grad_second

def mul(lhs, rhs):
    """Build the differentiable product ``lhs * rhs`` via a fresh ``Mul`` op."""
    operation = Mul()
    return operation([lhs, rhs])


class Sub(autodiff.AutoDiff):
    """Subtraction with a backward rule.

    ``forward`` delegates to ``basic_math.sub``; ``backward`` passes the
    upstream gradient to the minuend and its negation to the subtrahend.
    """

    def forward(self, xs):
        """Return ``basic_math.sub`` applied to the operands in ``xs``."""
        return basic_math.sub(*xs)

    def backward(self, xs, gy):
        """Return ``(gy, -gy)``, the negation built with ``basic_math.neg``."""
        negated = basic_math.neg(gy)
        return gy, negated


def sub(lhs, rhs):
    """Build the differentiable difference ``lhs - rhs`` via a fresh ``Sub`` op."""
    operation = Sub()
    return operation([lhs, rhs])


def rsub(rhs, lhs):
    """Reflected subtraction: build ``lhs - rhs``.

    The first parameter is the right-hand operand, so the operands are
    swapped back into natural order before being handed to ``Sub``.
    """
    operands = [lhs, rhs]
    return Sub()(operands)


class Div(autodiff.AutoDiff):
    """Division with a backward rule.

    ``forward`` delegates to ``basic_math.div``; ``backward`` wraps the two
    partial-derivative formulas in ``Node`` objects.
    """

    def forward(self, xs):
        """Return ``basic_math.div`` applied to the operands in ``xs``."""
        return basic_math.div(*xs)

    def backward(self, xs, gy):
        """Return gradient nodes for the numerator and the denominator.

        For ``y = x0 / x1`` the nodes compute ``gy / x1`` and
        ``-x0 / x1**2 * gy`` respectively. Each ``Node`` is given the
        operands followed by ``gy`` as its argument list.
        """

        def _numerator_grad(x0, x1, gy):
            # x0 is unused but kept so both helpers share one signature.
            return 1 / x1 * gy

        def _denominator_grad(x0, x1, gy):
            return - x0 / (x1 ** 2) * gy

        # Separate argument lists per node, as in a literal [*xs, gy] each time.
        grad_numerator = Node(_numerator_grad, list(xs) + [gy])
        grad_denominator = Node(_denominator_grad, list(xs) + [gy])
        return grad_numerator, grad_denominator


def div(lhs, rhs):
    """Build the differentiable quotient ``lhs / rhs`` via a fresh ``Div`` op."""
    operation = Div()
    return operation([lhs, rhs])


def rdiv(rhs, lhs):
    """Reflected division: build ``lhs / rhs``.

    The first parameter is the right-hand operand (the divisor), so the
    operands are swapped back into natural order before reaching ``Div``.
    """
    operands = [lhs, rhs]
    return Div()(operands)


def floordiv(lhs, rhs):
    """Placeholder for ``lhs // rhs``; always raises ``NotImplementedError``."""
    raise NotImplementedError()


def rfloordiv(rhs, lhs):
    """Placeholder for reflected ``lhs // rhs``; always raises ``NotImplementedError``."""
    raise NotImplementedError()


class Pow(autodiff.AutoDiff):
    """Exponentiation ``x0 ** x1`` with a backward rule.

    ``forward`` delegates to ``basic_math.pow``; ``backward`` wraps the two
    partial-derivative formulas in ``Node`` objects.
    """

    def forward(self, xs):
        """Return ``basic_math.pow`` applied to the operands in ``xs``."""
        y = basic_math.pow(*xs)
        return y

    def backward(self, xs, gy):
        """Return gradient nodes for the base and the exponent.

        For ``y = x0 ** x1``:

        * base:     ``x1 * x0 ** (x1 - 1) * gy``
        * exponent: ``x0 ** x1 * log(x0) * gy``

        Both are scaled by the upstream gradient ``gy`` (chain rule). Each
        ``Node`` is given the operands followed by ``gy`` as its arguments.
        """

        def _calc_gx0(x0, x1, gy):
            return x1 * (x0 ** (x1 - 1)) * gy

        def _calc_gx1(x0, x1, gy):
            # The upstream gradient must be applied here as well; without
            # it the exponent gradient is only right when gy == 1.
            return x0 ** x1 * np.log(x0) * gy

        gx0 = Node(_calc_gx0, [*xs, gy])
        gx1 = Node(_calc_gx1, [*xs, gy])
        return gx0, gx1


def pow(lhs, rhs):
    """Build the differentiable power ``lhs ** rhs`` via a fresh ``Pow`` op."""
    operation = Pow()
    return operation([lhs, rhs])


def rpow(rhs, lhs):
    """Reflected power: build ``lhs ** rhs``.

    The first parameter is the right-hand operand (the exponent), so the
    operands are swapped back into natural order before reaching ``Pow``.
    """
    operands = [lhs, rhs]
    return Pow()(operands)


class Neg(autodiff.AutoDiff):
    """Unary negation with a backward rule.

    ``forward`` delegates to ``basic_math.neg``; ``backward`` negates the
    upstream gradient.
    """

    def forward(self, xs):
        """Return ``basic_math.neg`` applied to the operand(s) in ``xs``."""
        return basic_math.neg(*xs)

    def backward(self, xs, gy):
        """Return the one-element tuple ``(-gy,)``."""
        negated = basic_math.neg(gy)
        return (negated,)


def neg(x):
    """Build the differentiable negation ``-x`` via a fresh ``Neg`` op."""
    operation = Neg()
    return operation([x])



def absolute(a):
    """Placeholder for ``abs(a)``; always raises ``NotImplementedError``."""
    raise NotImplementedError()


def matmul(lhs, rhs):
    """Placeholder for ``lhs @ rhs``; always raises ``NotImplementedError``."""
    raise NotImplementedError()


def rmatmul(lhs, rhs):
    """Placeholder for reflected ``@``; always raises ``NotImplementedError``.

    NOTE(review): the parameters are named ``(lhs, rhs)``, whereas the other
    reflected helpers in this file (``rsub``, ``rdiv``, ``rpow``) use
    ``(rhs, lhs)``. Revisit the naming when this is implemented.
    """
    raise NotImplementedError()


def install_node_arithmetics():
    """Attach this module's operator functions to ``Node`` as special methods.

    Commutative operators (``+``, ``*``) reuse the same function for the
    reflected form; non-commutative ones use the dedicated ``r*`` helpers.
    The Python 2 names ``__div__`` / ``__rdiv__`` are set alongside the
    ``__truediv__`` / ``__rtruediv__`` pair.
    """
    operator_table = {
        '__neg__': neg,
        '__abs__': absolute,
        '__add__': add,
        '__radd__': add,
        '__sub__': sub,
        '__rsub__': rsub,
        '__mul__': mul,
        '__rmul__': mul,
        '__div__': div,
        '__truediv__': div,
        '__rdiv__': rdiv,
        '__rtruediv__': rdiv,
        '__floordiv__': floordiv,
        '__rfloordiv__': rfloordiv,
        '__pow__': pow,
        '__rpow__': rpow,
        '__matmul__': matmul,
        '__rmatmul__': rmatmul,
    }
    for dunder_name, implementation in operator_table.items():
        setattr(Node, dunder_name, implementation)


In [51]:
# Bind the AutoDiff-backed operators defined in this notebook to Node.
install_node_arithmetics()

In [52]:
# Input variable for the demo, created with the value 5.0.
x = Variable(5.)

In [53]:
# y = -(x^2 + 4x - 2). At x = 5 this gives -(25 + 20 - 2) = -43,
# and dy/dx = -(2x + 4) = -14.
y = - (x ** 2 + 4 * x - 2)

In [54]:
# Forward value of y (expected -43.0 for x = 5).
y.data

-43.0

In [55]:
# Backpropagate from y; the resulting gradient is read from x.grad afterwards.
y.backward()

In [56]:
# Gradient dy/dx at x = 5 (expected -(2*5 + 4) = -14.0).
x.grad.data

-14.0

In [36]:
# NOTE(review): no __isub__ is installed in this notebook, so this presumably
# falls back to __sub__ and rebinds x to a new node rather than mutating the
# Variable in place -- confirm against Node. The execution counts from this
# cell on (36, 11, 12) are out of order relative to the cells above, so the
# saved outputs below may be stale; re-run from a fresh kernel.
x -= 1

In [11]:
# Value of x after the subtraction (5.0 - 1 = 4.0).
x.data

4.0

In [12]:
# Inspect the argument nodes recorded on y.
# NOTE(review): the saved output lists two Variable objects and this cell's
# execution count (12) predates the definition cells (50-55), so the output
# may come from an earlier session -- re-run to confirm.
y.args

(<mandala.nodecore.Variable at 0x17811cfe4e0>,
 <mandala.nodecore.Variable at 0x17811cfcbe0>)