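# Custom computation graph

A minimal expression graph built by operator overloading, with symbolic differentiation (first- and second-order partial derivatives).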

In [15]:
import operator
from numbers import Number
import numpy as np

# Display symbol for every callable used as a node's `op`, looked up by
# Node._str when printing. Unary minus and binary subtraction share "-".
op_map = dict([
    (operator.neg, "-"),
    (operator.add, "+"),
    (operator.mul, "*"),
    (operator.sub, "-"),
    (operator.truediv, "/"),
    (operator.pow, "**"),
    (np.reciprocal, "reciprocal"),
    (np.exp, "exp"),
    (np.sqrt, "sqrt"),
])

In [18]:
class Node:
    """Base class for every expression-graph node.

    Overloads Python's arithmetic operators so that an expression such as
    ``x**2 * y + 2`` builds a graph of operator nodes instead of computing a
    value. Plain numbers on the left of an operator are wrapped in ``Const``.
    Subclasses set a class attribute ``op`` (a key of ``op_map``) so that
    ``_str`` can look up their display symbol.
    """

    def __pow__(a, p): return Pow(a, p)
    def __neg__(a):    return Negative(a)
    def __mul__(a, b): return Mul(a, b)
    def __add__(a, b): return Add(a, b)
    # a - b is represented as a + (-b); there is no dedicated Sub node.
    def __sub__(a, b): return Add(a, Negative(b))
    # a / b is represented as a * (1 / b).
    def __truediv__(a, b): return a * Reciprocal(b)

    # Reflected operators run only when the left operand is not a Node.
    # Unsupported left operands return NotImplemented (the data-model
    # protocol) so Python can try other fallbacks and then raise the standard
    # TypeError, instead of a misleading NotImplementedError.
    def __rtruediv__(a, b):
        if not isinstance(b, Number):
            return NotImplemented
        return Mul(Const(b), Reciprocal(a))

    def __rsub__(a, b):
        if not isinstance(b, Number):
            return NotImplemented
        return Add(Const(b), Negative(a))

    def __radd__(a, b):
        if not isinstance(b, Number):
            return NotImplemented
        return Add(Const(b), a)

    def __rmul__(a, b):
        if not isinstance(b, Number):
            return NotImplemented
        return Mul(Const(b), a)

    def _str(self):
        """Return the display symbol for this node's ``op`` from ``op_map``."""
        return op_map[self.op]

In [19]:
class Const(Node):
    """Leaf node holding a fixed value."""

    def __init__(self, value):
        self.value = value

    def evaluate(self):
        """Return the stored value."""
        return self.value

    def __str__(self):
        return str(self.value)

    # Match Var and the operator nodes, whose repr equals their str, so a
    # bare Const shows its value in notebook output and inside containers.
    __repr__ = __str__

    def gradient(self, var):
        """Derivative of a constant: always ``Const(0)``, whatever ``var`` is."""
        return Const(0)
    
class Var(Node):
    """Named leaf variable; its ``value`` can be reassigned between evaluations."""

    def __init__(self, name, init_value=0):
        self.name = name
        self.value = init_value

    def evaluate(self):
        """Return the variable's current value."""
        return self.value

    def __str__(self):
        return self.name

    __repr__ = __str__

    def gradient(self, var):
        """d(self)/d(var): 1 when ``var`` is this very object, else 0."""
        same_variable = self is var
        return Const(1 if same_variable else 0)
    
def maybe_cast(a):
    """Coerce an operand into a graph node.

    Nodes pass through unchanged and numbers are wrapped in ``Const``.
    Anything else raises NotImplementedError.
    """
    if isinstance(a, Node):
        return a
    if not isinstance(a, Number):
        raise NotImplementedError
    return Const(a)

In [20]:
class Unary(Node):
    """Node applying a one-argument callable ``op`` to a single child ``a``.

    Subclasses provide ``op`` and ``derivative(arg)``, which returns the
    derivative of ``op`` at the argument node ``arg`` as a graph node.
    """

    def __init__(self, a):
        self.a = maybe_cast(a)

    def __str__(self):
        return "{}({})".format(self._str(), self.a)

    __repr__ = __str__

    def evaluate(self):
        inner_value = self.a.evaluate()
        return self.op(inner_value)

    def gradient(self, var):
        # Chain rule: d op(a)/d var = op'(a) * da/d var
        outer = self.derivative(self.a)
        inner = self.a.gradient(var)
        return Mul(outer, inner)

class Reciprocal(Unary):
    """Reciprocal node, 1 / a."""
    op = np.reciprocal

    def derivative(self, var):
        # ``var`` is the argument node (Unary.gradient passes self.a).
        # d(1/u)/du = -(1/u * 1/u)
        return -(Reciprocal(var) * Reciprocal(var))
    
class Negative(Unary):
    """Unary minus node, -a."""
    op = operator.neg

    def derivative(self, var):
        # d(-u)/du is the constant -1, represented as Negative(Const(1)).
        return -Const(1)
    
class Exp(Unary):
    """Exponential node, exp(a)."""
    op = np.exp

    def derivative(self, var):
        # exp is its own derivative: wrap the argument node in the same type.
        return type(self)(var)
    
class Sqrt(Unary):
    """Square-root node, sqrt(a)."""
    op = np.sqrt

    def derivative(self, var):
        # d sqrt(u)/du = 1 / (2 * sqrt(u))
        denominator = Const(2) * Sqrt(var)
        return Reciprocal(denominator)

In [21]:
class BinaryOperator(Node):
    """Node combining two children ``a`` and ``b`` with a two-argument ``op``.

    Number operands are wrapped in ``Const`` via ``maybe_cast``. Subclasses
    set ``op`` (a key of ``op_map``) and implement ``gradient``.
    """

    def __init__(self, a, b):
        self.a = maybe_cast(a)
        self.b = maybe_cast(b)

    @staticmethod
    def _operand_str(node):
        # Parenthesize nested binary operations so the printed form reflects
        # the actual tree, e.g. "(x + 1) * y" rather than "x + 1 * y".
        if isinstance(node, BinaryOperator):
            return f"({node})"
        return str(node)

    def __str__(self):
        # The operator symbol comes from Node._str (an op_map lookup).
        return "{} {} {}".format(
            self._operand_str(self.a), self._str(), self._operand_str(self.b)
        )

    __repr__ = __str__

    def evaluate(self):
        """Evaluate both children, then combine them with ``op``."""
        return self.op(self.a.evaluate(), self.b.evaluate())

class Add(BinaryOperator):
    """Sum node, a + b."""
    op = operator.add

    def gradient(self, var):
        # Sum rule: (a + b)' = a' + b'
        left_grad = self.a.gradient(var)
        right_grad = self.b.gradient(var)
        return Add(left_grad, right_grad)

class Mul(BinaryOperator):
    """Product node, a * b."""
    op = operator.mul

    def gradient(self, var):
        # Product rule, built as a * b' + a' * b.
        db = self.b.gradient(var)
        da = self.a.gradient(var)
        return Add(Mul(self.a, db), Mul(da, self.b))

def _depends_on(node, var):
    """Return True if ``var`` occurs anywhere in the subgraph rooted at ``node``.

    Recurses through the ``a``/``b`` child attributes used by Unary and
    BinaryOperator nodes; leaves (Const, Var) have neither, which ends the walk.
    """
    if node is var:
        return True
    children = (getattr(node, "a", None), getattr(node, "b", None))
    return any(child is not None and _depends_on(child, var) for child in children)


class Log(Unary):
    """Natural-logarithm node, log(a); needed for the general power rule."""
    op = np.log

    def _str(self):
        # np.log is not a key of op_map, so supply the symbol directly.
        return "log"

    def derivative(self, var):
        # d(log u)/du = 1/u   (``var`` is the argument node)
        return Reciprocal(var)


class Pow(BinaryOperator):
    """Power node, a ** b."""
    op = operator.pow

    def gradient(self, var):
        """Gradient of a**b with respect to ``var``.

        Uses the general power rule:
            d(a**b) = b * a**(b-1) * da + a**b * log(a) * db
        The log term is added only when the exponent actually depends on
        ``var``. For a constant exponent it would be multiplied by zero, but
        it would still evaluate log(a), turning results for a <= 0 into NaN.
        """
        power_term = self.b * Pow(self.a, self.b - 1) * self.a.gradient(var)
        if not _depends_on(self.b, var):
            return power_term
        return power_term + self * Log(self.a) * self.b.gradient(var)

In [22]:
# f(x, y) = x²y + y + 2, with the variables set to (x, y) = (3, 4).
x = Var("x", 3.0)
y = Var("y", 4.0)
f = x**2 * y + y + 2

# Symbolic first-order partial derivatives.
dfdx = f.gradient(x)  # ∂f/∂x = 2xy
dfdy = f.gradient(y)  # ∂f/∂y = x² + 1

In [23]:
# Evaluate the gradient at (x, y) and check it against the analytic (2xy, x² + 1).
grad_f = (dfdx.evaluate(), dfdy.evaluate())
assert np.allclose(grad_f, (2 * x.value * y.value, x.value**2 + 1))
grad_f

(24.0, 10.0)

In [24]:
# Differentiate each first-order partial again to get the Hessian entries.
d2fdxdx, d2fdxdy = dfdx.gradient(x), dfdx.gradient(y)  # 2y, 2x
d2fdydx, d2fdydy = dfdy.gradient(x), dfdy.gradient(y)  # 2x, 0

In [25]:
hessian = [[d2fdxdx.evaluate(), d2fdxdy.evaluate()],
           [d2fdydx.evaluate(), d2fdydy.evaluate()]]
# Analytic Hessian of x²y + y + 2 is [[2y, 2x], [2x, 0]]; mixed partials must agree.
assert np.allclose(hessian, [[2 * y.value, 2 * x.value], [2 * x.value, 0.0]])
hessian

[[8.0, 6.0], [6.0, 0.0]]