In [116]:
def _get_output(x):
    """Resolve ``x`` to a concrete value.

    ``Node`` instances are called (which evaluates them, or returns their
    cached output); any other object is returned unchanged.
    """
    return x() if isinstance(x, Node) else x

class Node(object):
    """A lazily evaluated node in a computation graph.

    Calling the node applies ``func`` to its arguments (any argument that is
    itself a ``Node`` is evaluated first) and caches the result in
    ``output``. ``None`` marks "not yet computed", so a ``func`` that returns
    ``None`` is re-evaluated on every call.
    """

    def __init__(self, func, args):
        self.func = func
        # A single non-sequence argument is wrapped so that ``func`` can
        # always be invoked by unpacking ``self.args``.
        self.args = args if isinstance(args, (tuple, list)) else [args]
        self.output = None  # cache; None means "not evaluated yet"

    def __call__(self):
        if self.output is not None:
            return self.output
        resolved = [_get_output(arg) for arg in self.args]
        self.output = self.func(*resolved)
        return self.output

    @property
    def data(self):
        """The evaluated output of this node (equivalent to calling it)."""
        return self.__call__()

In [117]:
class Variable(Node):
    """A leaf node that simply holds a value.

    ``x`` is always stored as the single argument of an identity function,
    even when ``x`` is itself a list or tuple, so evaluating the variable
    yields ``x`` (or ``x``'s output when ``x`` is a ``Node``).
    """

    def __init__(self, x):
        # Pass an explicit one-element list so Node.__init__ never unpacks x.
        super(Variable, self).__init__(lambda value: value, [x])

In [119]:
def identity(x):
    """Return the argument unchanged."""
    result = x
    return result

def add(a, b):
    """Return ``a + b`` for any operands that support ``+``."""
    total = a + b
    return total

def mul(a, b):
    """Return ``a * b`` for any operands that support ``*``."""
    product = a * b
    return product

In [135]:
# Build (lazily) the graph for y = x*x + 2*x + 1 with x = 5.
x = Variable(5)

h1 = Node(mul, [Variable(2), x])  # 2 * x
h0 = Node(mul, [x, x])            # x * x
h2 = Node(add, [h0, h1])          # x*x + 2*x
y = Node(add, [h2, Variable(1)])  # x*x + 2*x + 1

In [170]:
import numpy as np

class Linear(object):
    """Fully connected layer computing ``y = x @ W.T`` as graph nodes.

    Parameters
    ----------
    in_ch : int
        Number of input features (columns of ``W``).
    out_ch : int
        Number of output features (rows of ``W``).
    """

    def __init__(self, in_ch, out_ch):
        # NOTE(review): normal(loc=-1, scale=1) draws weights with mean -1.
        # This may have been meant as uniform(-1, 1) or normal(0, 1);
        # confirm before training with this layer.
        self.W = Variable(
            np.random.normal(-1, 1, (out_ch, in_ch)).astype('f')
        )

    def forward(self, x):
        """Return a lazy output node for input ``x`` (a Node or raw array).

        ``x`` and the output node are stored for use in ``backward``.
        """
        self.x = x
        self.y = Node(linear_forward, [x, self.W])
        return self.y

    def backward(self):
        """Attach gradient nodes to ``W`` (and ``x`` when it is a Node).

        The caller must set ``self.y.grad`` before calling this; nothing in
        this class creates it. For y = x @ W.T the gradients are
        gW = gy.T @ x and gx = gy @ W.
        """
        gy = self.y.grad
        self.W.grad = Node(linear_backward_W, [self.x, gy])
        # A raw ndarray input cannot carry a ``grad`` attribute, and it has
        # no upstream graph to propagate into.
        if isinstance(self.x, Node):
            self.x.grad = Node(linear_backward_x, [self.W, gy])
        return None

def linear_forward(x, W):
    """Return ``x @ W.T``.

    ``W`` is laid out (out_ch, in_ch), as built by ``Linear``. ``x`` is
    presumably (N, in_ch), giving an (N, out_ch) result.
    """
    return np.matmul(x, W.T)

def linear_backward_W(x, gy):
    """Gradient of ``x @ W.T`` with respect to ``W``: ``gy.T @ x``.

    With gy of shape (N, out_ch) and x of shape (N, in_ch) the result is
    (out_ch, in_ch), the same layout as ``W``.
    """
    return np.matmul(gy.T, x)

def linear_backward_x(W, gy):
    """Gradient of ``x @ W.T`` with respect to ``x``: ``gy @ W``.

    With gy of shape (N, out_ch) and W of shape (out_ch, in_ch) the result
    is (N, in_ch), the same layout as ``x``.
    """
    return np.matmul(gy, W)

In [151]:
# Dummy arrays for checking the linear-layer matmul shapes:
# batch N=8, in_ch=5, out_ch=10.
x, W, gy = np.ones((8, 5)), np.ones((10, 5)), np.ones((8, 10))

In [149]:
# Forward pass shape check: (N, in) @ (in, out) -> (N, out).
np.matmul(x, W.T).shape

(8, 10)

In [169]:
# Gradient w.r.t. x shape check: (N, out) @ (out, in) -> (N, in).
np.matmul(gy, W).shape

(8, 5)

In [144]:
# Smoke test: constructing a layer should not raise.
Linear(in_ch=2, out_ch=3)

<__main__.Linear at 0x22cc2c532b0>

In [30]:
# Rebinds the name ``x`` used by earlier cells to a trivial identity node.
x = Node(func=identity, args=[1])