In [585]:
import numpy

class Gate:
    """Base node of a toy computational graph with forward/backward passes.

    A gate has two inputs ``x`` and ``y`` and one output ``z``. Each input is
    either a plain value or a child gate (``x_callback`` / ``y_callback``). When
    a child is given, the child's ``z`` is used as the input, and the child is
    wired to call this gate's ``forward`` whenever it recomputes.

    Subclasses implement ``compute_forward`` (set ``self.z`` from ``self.x`` /
    ``self.y``) and ``compute_backward`` (set ``self.dx`` / ``self.dy`` from
    ``self.dz``).

    Args:
        x, y: plain input values; ignored when the matching callback is given.
        x_callback, y_callback: child gates whose ``z`` feeds ``x`` / ``y``.
        weight_gate: if True, ``update`` modifies ``self.x`` when this gate is
            reached as a leaf during ``backward``.
    """

    # Set to True (on the class or on an instance) to print inputs and local
    # gradients each time a non-root gate propagates forward to its parent.
    debug = False

    def __init__(self, x=None, y=None, x_callback=None, y_callback=None, weight_gate=False):
        self.dx = None
        self.dy = None
        self.dz = None
        self.z = None
        self.weight_gate = weight_gate
        self.final = None  # NOTE(review): not read anywhere in the visible code
        self.x_callback = x_callback
        self.y_callback = y_callback
        # The parent's forward(). Only one parent can be registered: a later
        # parent built on top of this gate overwrites an earlier one.
        self.f_callback = None
        self.x = self._bind_input(x_callback, x)
        self.y = self._bind_input(y_callback, y)
        self.compute_forward()

    def _bind_input(self, callback, value):
        """Return one input's value, wiring ``callback`` to this gate if given.

        With no callback the plain ``value`` is used. Otherwise the child is
        (re)computed, registered to notify this gate, and its ``z`` returned.
        """
        if callback is None:
            return value
        callback.compute_forward()
        callback.f_callback = self.forward
        return callback.z

    def forward(self):
        """Recompute this gate's output and propagate it up to the parent.

        Returns:
            The freshly computed ``self.z`` (for root and non-root gates alike).
        """
        self.refresh_inputs()
        self.compute_forward()
        if self.f_callback is not None:
            self.gradients()
            self.f_callback()
        return self.z

    def refresh_inputs(self):
        """Re-read ``x`` / ``y`` from child gates, printing ``old -> new``."""
        if self.x_callback is not None:
            print(self.x, '->', self.x_callback.z)
            self.x = self.x_callback.z
        if self.y_callback is not None:
            print(self.y, '->', self.y_callback.z)
            self.y = self.y_callback.z

    def compute_forward(self):
        """Set ``self.z`` from the inputs; overridden by subclasses."""
        pass

    def backward(self, dz=1):
        """Backpropagate the upstream gradient ``dz`` through this subtree.

        Local gradients are computed once, then passed to each child gate. A
        leaf (no child gates) applies its ``update`` and immediately re-runs
        ``forward``, which re-propagates through every ancestor. A gate with
        several leaf descendants is therefore recomputed once per leaf.
        """
        self.dz = dz
        self.compute_backward()
        if self.x_callback is not None:
            self.x_callback.backward(self.dx)
        if self.y_callback is not None:
            self.y_callback.backward(self.dy)
        if self.x_callback is None and self.y_callback is None:
            self.update()
            self.forward()

    def compute_backward(self):
        """Set ``self.dx`` / ``self.dy`` from ``self.dz``; overridden by subclasses."""
        pass

    def gradients(self):
        """Debug hook run before propagating forward; prints only if ``debug`` is set."""
        if self.debug:
            print(self.x, self.y)
            print('dx =', self.dx, 'dy =', self.dy)

    def update(self):
        """Apply the weight update to ``self.x`` (weight gates only).

        NOTE(review): the rule is ``x + dx / x``. For positive ``x`` this moves
        in the gradient's direction (ascent), scales by ``1/x`` instead of a
        learning rate, and raises ZeroDivisionError when ``x == 0``. Plain
        gradient descent would be ``x - lr * dx``. Confirm which is intended.
        """
        if self.weight_gate:
            wt_upd = self.x + self.dx / self.x
            print(self.x , '-->', wt_upd)
            self.x = wt_upd
    
class AddGate(Gate):
    """Sum gate: ``z = x + y``.

    Backward passes the upstream gradient unchanged to both inputs.
    """

    def __init__(self, x, y, x_callback=None, y_callback=None):
        super().__init__(x, y, x_callback=x_callback, y_callback=y_callback)

    def compute_forward(self):
        print('forward+', self.x, self.y)
        self.z = self.x + self.y

    def compute_backward(self):
        print('backward', self.x, self.y)
        # d(x + y)/dx == d(x + y)/dy == 1
        self.dx = self.dy = self.dz
        
class MultiplyGate(Gate):
    """Product gate: ``z = x * y``.

    Backward scales the upstream gradient by the *other* input.
    """

    def __init__(self, x, y, x_callback=None, y_callback=None):
        super().__init__(x, y, x_callback=x_callback, y_callback=y_callback)

    def compute_forward(self):
        print('forward*', self.x, self.y)
        self.z = self.x * self.y

    def compute_backward(self):
        upstream = self.dz
        # Product rule: d(x*y)/dx = y, d(x*y)/dy = x.
        self.dx, self.dy = self.y * upstream, self.x * upstream
        
class PundirGate(Gate):
    """Experimental gate: forward is ``x + y``.

    Backward adds ``dz - z`` to the *other* input.
    """

    def __init__(self, x, y, x_callback=None, y_callback=None):
        super().__init__(x, y, x_callback=x_callback, y_callback=y_callback)

    def compute_forward(self):
        self.z = self.x + self.y

    def compute_backward(self):
        # Custom rule (not the true derivative of x + y).
        shift = self.dz - self.z
        self.dx = self.y + shift
        self.dy = self.x + shift
        
class WeightGate(Gate):
    """Leaf gate holding a trainable value in ``x``; its output is ``x`` unchanged.

    ``weight_gate`` defaults to True, so ``Gate.update`` rewrites ``x`` when
    this gate is reached as a leaf during ``backward``.
    """

    def __init__(self, x=None, y=None, x_callback=None, y_callback=None, weight_gate=True):
        super().__init__(
            x=x,
            y=y,
            x_callback=x_callback,
            y_callback=y_callback,
            weight_gate=weight_gate,
        )

    def compute_forward(self):
        print('weight-forward')
        self.z = self.x  # identity

    def compute_backward(self):
        self.dx = self.dz  # d(x)/dx == 1

In [588]:
# Graph: a3 = (w1 * 3) * (w2 + 6), with w1 and w2 as trainable leaves.
w1 = WeightGate(x=2)
a1 = MultiplyGate(None, 3, x_callback=w1)
w2 = WeightGate(x=4)
a2 = AddGate(None, 6, x_callback=w2)
a3 = MultiplyGate(None, None, x_callback=a1, y_callback=a2)
a3.z

weight-forward
weight-forward
forward* 2 3
weight-forward
weight-forward
forward+ 4 6
forward* 2 3
forward+ 4 6
forward* 6 10


60

In [589]:
# TODO: redesign the Gate API now that the full picture is clear (sketch it on paper first).
# backward() is seeded with the output value itself. Each weight leaf re-runs forward()
# after its update, which recomputes every gate above it, so a3 is recomputed once per leaf.
a3.backward(a3.z)
for label, node in (('a3', a3), ('a2', a2), ('w2', w2), ('a1', a1), ('w1', w1)):
    print(label, node.z, node.x, node.y, node.dz)
del label, node

2 --> 902.0
weight-forward
2 -> 902.0
forward* 902.0 3
6 -> 2706.0
10 -> 10
forward* 2706.0 10
backward 4 6
4 --> 94.0
weight-forward
4 -> 94.0
forward+ 94.0 6
2706.0 -> 2706.0
10 -> 100.0
forward* 2706.0 100.0
a3 270600.0 2706.0 100.0 60
a2 100.0 94.0 6 360
w2 94.0 94.0 None 360
a1 2706.0 902.0 3 600
w1 902.0 902.0 None 1800
