In [19]:

class Base(object):
    """Interface shared by every node of the computation graph.

    Subclasses override ``forward`` to compute the node's value and
    ``backward`` to turn an upstream gradient into gradients for the
    node's inputs.
    """

    def forward(self):
        """Compute the node's output; subclasses must override this."""
        raise NotImplementedError

    def backward(self, upstream_grad):
        """Propagate ``upstream_grad`` to the inputs; subclasses must override this."""
        raise NotImplementedError

class Multiple(Base):
    """Multiplication node: ``forward(x, y)`` returns ``x * y``.

    ``forward`` caches both operands so that ``backward`` can apply the
    product rule.
    """

    def __init__(self, x = None, y = None):
        # ``x`` and ``y`` are accepted but never used; operands are supplied
        # to ``forward`` instead. Kept so existing call sites keep working.
        self.ctx = None

    def forward(self, x, y):
        """Return ``x * y`` and remember both operands for ``backward``."""
        self.save_ctx(x, y)
        return x * y

    def backward(self, upstream_grad):
        """Return the gradients of ``x * y`` with respect to ``x`` and ``y``.

        d(xy)/dx = y and d(xy)/dy = x, each scaled by ``upstream_grad``
        (chain rule).

        Returns:
            Tuple ``(grad_x, grad_y)``.

        Raises:
            RuntimeError: if called before ``forward``.
        """
        x, y = self.saved_ctx()
        return y * upstream_grad, x * upstream_grad

    def save_ctx(self, x, y):
        """Store the forward operands as a pair."""
        self.ctx = (x, y)

    def saved_ctx(self):
        """Return the operand pair stored by the most recent ``forward``."""
        if self.ctx is None:
            # Give a clear message instead of "'NoneType' object is not subscriptable".
            raise RuntimeError("backward() called before forward()")
        return self.ctx[0], self.ctx[1]

class Add(Base):
    """Addition node computing ``x + y``.

    Both partial derivatives of a sum are 1, so ``backward`` simply hands
    the upstream gradient to each operand.
    """

    def __init__(self, x = None, y = None):
        # Constructor arguments are ignored; operands arrive via forward().
        self.ctx = None

    def forward(self, x, y):
        """Cache the operands, then return their sum."""
        self.save_ctx(x, y)
        total = x + y
        return total

    def backward(self, upstream_grad):
        """Return ``(upstream_grad, upstream_grad)`` as the grads for ``(x, y)``."""
        grad = upstream_grad
        return (grad, grad)

    def save_ctx(self, x, y):
        """Remember the operand pair from the most recent forward()."""
        self.ctx = x, y

    def saved_ctx(self):
        """Give back the operand pair recorded by save_ctx()."""
        stored = self.ctx
        first = stored[0]
        second = stored[1]
        return first, second

class One(Base):
    """Identity node: ``forward(x)`` returns ``x``, so its local gradient is 1."""

    def __init__(self, x = None):
        # ``x`` is accepted but never used; kept for call-site compatibility.
        self.ctx = None

    def forward(self, x):
        """Return ``x`` unchanged and remember it via ``save_ctx``."""
        self.save_ctx(x)
        return x

    def backward(self, upstream_grad):
        """Pass ``upstream_grad`` through unchanged (d(x)/dx = 1)."""
        return upstream_grad

    def save_ctx(self, x):
        """Store the single forward operand as-is (not wrapped in a tuple)."""
        self.ctx = x

    def saved_ctx(self):
        """Return the operand stored by the most recent ``forward``."""
        # save_ctx stores ``x`` directly, so indexing it (as the two-operand
        # nodes do) would fail for scalars and truncate sequences.
        return self.ctx

In [28]:
mul = Multiple()
add = Add()

x, y = 3, 4

# Forward pass: z = x * y + 5.
z = mul.forward(x, y)
z = add.forward(z, 5)

# Backward pass, seeded with dz/dz = 1.
grad_add_x, grad_add_y = add.backward(1)
# The product x * y was the FIRST operand of add.forward, so the gradient
# flowing back into mul is grad_add_x (grad_add_y belongs to the constant 5).
grad_mul_x, grad_mul_y = mul.backward(grad_add_x)
print(grad_mul_y)  # dz/dy = x = 3


3
