In [13]:
from __future__ import annotations
from typing import Callable
from dataclasses import dataclass
import math


@dataclass
class Node:
    def diff(self, var: Symbol) -> Node:
        raise NotImplementedError("Differentiation not implemented for this node.")
    
    def __str__(self) -> str:
        raise NotImplementedError


@dataclass
class Constant(Node):
    value: float

    def diff(self, var: Symbol) -> Node:
        return Constant(0)

    def __str__(self) -> str:
        return str(self.value)


@dataclass
class Symbol(Node):
    name: str

    def diff(self, var: Symbol) -> Node:
        return Constant(1) if self.name == var.name else Constant(0)

    def __str__(self) -> str:
        return self.name


@dataclass
class Operation(Node):
    operator: str
    operands: list[Node]

    def __str__(self) -> str:
        operand_str = f" {self.operator} ".join(map(str, self.operands))
        return f"({operand_str})"


@dataclass
class BinaryOperation(Operation):
    @property
    def left(self) -> Node:
        return self.operands[0]

    @property
    def right(self) -> Node:
        return self.operands[-1]


@dataclass
class Add(BinaryOperation):
    def __init__(self, left: Node, right: Node) -> None:
        super().__init__('+', [left, right])

    def diff(self, var: Symbol) -> Node:
        return Add(self.left.diff(var), self.right.diff(var))


@dataclass
class Subtract(BinaryOperation):
    def __init__(self, left: Node, right: Node) -> None:
        super().__init__('-', [left, right])

    def diff(self, var: Symbol) -> Node:
        return Subtract(self.left.diff(var), self.right.diff(var))


@dataclass
class Multiply(BinaryOperation):
    def __init__(self, left: Node, right: Node) -> None:
        super().__init__('*', [left, right])

    def diff(self, var: Symbol) -> Node:
        return Add(Multiply(self.left.diff(var), self.right), Multiply(self.left, self.right.diff(var)))


@dataclass
class Divide(BinaryOperation):
    def __init__(self, left: Node, right: Node) -> None:
        super().__init__('/', [left, right])

    def diff(self, var: Symbol) -> Node:
        numerator = Subtract(Multiply(self.left.diff(var), self.right), Multiply(self.left, self.right.diff(var)))
        denominator = Multiply(self.right, self.right)
        return Divide(numerator, denominator)


@dataclass
class Exponentiation(BinaryOperation):
    def __init__(self, base: Node, exponent: Node) -> None:
        super().__init__('^', [base, exponent])

    def diff(self, var: Symbol) -> Node:
        u = self.left
        v = self.right
        if isinstance(v, Constant):
            return Multiply(Multiply(v, Exponentiation(u, Constant(v.value - 1))), u.diff(var))
        else:
            return Multiply(Exponentiation(u, v), Add(Multiply(v.diff(var), Function("ln", [u])), Multiply(v, Divide(u.diff(var), u))))


@dataclass
class Function(Node):
    func_name: str
    arguments: list[Node]

    def diff(self, var: Symbol) -> Node:
        if self.func_name == "exp":
            return Multiply(Function("exp", self.arguments), self.arguments[0].diff(var))
        elif self.func_name == "sqrt":
            sqrt_term = Function("sqrt", self.arguments)
            return Multiply(Divide(Constant(1), Multiply(Constant(2), sqrt_term)), self.arguments[0].diff(var))
        elif self.func_name == "ln":
            return Multiply(Divide(Constant(1), self.arguments[0]), self.arguments[0].diff(var))
        raise NotImplementedError(f"Differentiation not implemented for function {self.func_name}")

    def __str__(self) -> str:
        arg_str = ", ".join(map(str, self.arguments))
        return f"{self.func_name}({arg_str})"


### Simplification System ###

@dataclass
class SimplificationSystem:
    rules: list[Callable[[Node], Node]]

    def apply(self, node: Node) -> Node:
        simplified_node = node
        while True:
            new_node = self._apply_rules_recursively(simplified_node)
            if new_node == simplified_node:
                break
            simplified_node = new_node
        return simplified_node

    def _apply_rules_recursively(self, node: Node) -> Node:
        # First, apply rules to the operands (recursive step)
        if isinstance(node, Operation):
            node.operands = [self._apply_rules_recursively(op) for op in node.operands]

        # Now, apply the rules to the current node
        for rule in self.rules:
            simplified_node = rule(node)
            if simplified_node != node:
                # If a simplification happened, restart to apply all rules on the new simplified node
                return self._apply_rules_recursively(simplified_node)

        return node



### Simplification Rule Functions ###

def simplify_multiply_by_zero(node: Node) -> Node:
    if isinstance(node, Multiply):
        if isinstance(node.left, Constant) and node.left.value == 0:
            return Constant(0)
        if isinstance(node.right, Constant) and node.right.value == 0:
            return Constant(0)
    return node

def simplify_divide_by_zero_numerator(node: Node) -> Node:
    if isinstance(node, Divide):
        if isinstance(node.left, Constant) and node.left.value == 0:
            return Constant(0)
    return node

def simplify_multiply_by_one(node: Node) -> Node:
    if isinstance(node, Multiply):
        if isinstance(node.left, Constant) and node.left.value == 1:
            return node.right
        if isinstance(node.right, Constant) and node.right.value == 1:
            return node.left
    return node

def simplify_add_zero(node: Node) -> Node:
    if isinstance(node, Add):
        if isinstance(node.left, Constant) and node.left.value == 0:
            return node.right
        if isinstance(node.right, Constant) and node.right.value == 0:
            return node.left
    return node

def simplify_subtract_zero(node: Node) -> Node:
    if isinstance(node, Subtract):
        if isinstance(node.right, Constant) and node.right.value == 0:
            return node.left
    return node

def simplify_exponentiation(node: Node) -> Node:
    if isinstance(node, Exponentiation):
        if isinstance(node.right, Constant):
            if node.right.value == 0:
                return Constant(1)
            if node.right.value == 1:
                return node.left
    return node

def simplify_fractional_multiplication(node: Node) -> Node:
    if isinstance(node, Multiply):
        # Check if one operand is a division of the same symbol (a / a)
        if isinstance(node.left, Divide) and node.left.left == node.left.right:
            return node.right  # a/a * b = b
        if isinstance(node.right, Divide) and node.right.left == node.right.right:
            return node.left  # b * a/a = b
    elif isinstance(node, Divide):
        # Check if division comes first and is then multiplied (b * a/a = b)
        if node.left == node.right:
            return Constant(1)  # Simplify a / a = 1
    return node

def simplify_divide_with_common_factor(node: Node) -> Node:
    if isinstance(node, Divide):
        numerator = node.left
        denominator = node.right
        
        # If numerator is a multiplication, check for common factors
        if isinstance(numerator, Multiply):
            if numerator.left == denominator:
                return numerator.right  # a * b / a = b
            if numerator.right == denominator:
                return numerator.left  # b * a / a = b
    return node



# Example usage with the SimplificationSystem
# Normal distribution representation using the symbolic math graph
def normal_distribution(mu: Node, sigma: Node, x: Node) -> Node:
    two = Constant(2)
    pi = Constant(math.pi)
    half = Constant(0.5)

    sqrt_2_pi = Function("sqrt", [Multiply(two, pi)])
    denom = Multiply(sigma, sqrt_2_pi)
    exponent = Multiply(half, Function("exp", [Multiply(Constant(-1), Exponentiation(Divide(Subtract(x, mu), sigma), Constant(2)))]))
    return Multiply(Divide(Constant(1), denom), exponent)


# Example usage
mu = Symbol("mu")
sigma = Symbol("sigma")
x = Symbol("x")

# Create the normal distribution expression
normal_dist = normal_distribution(mu, sigma, x)
print(f"Normal distribution expression: {normal_dist}")

# Differentiate the normal distribution with respect to x
normal_dist_diff = normal_dist.diff(x)
print(f"Derivative of normal distribution w.r.t. x: {normal_dist_diff}")

# Simplification System
simplification_system = SimplificationSystem([
    simplify_multiply_by_zero,
    simplify_divide_by_zero_numerator,
    simplify_multiply_by_one,
    simplify_add_zero,
    simplify_subtract_zero,
    simplify_exponentiation,
    simplify_fractional_multiplication,
    simplify_divide_with_common_factor,
])

# Simplify the derivative
simplified_diff = simplification_system.apply(normal_dist_diff)
print(f"Simplified derivative: {simplified_diff}")


Normal distribution expression: ((1 / (sigma * sqrt((2 * 3.141592653589793)))) * (0.5 * exp((-1 * (((x - mu) / sigma) ^ 2)))))
Derivative of normal distribution w.r.t. x: (((((0 * (sigma * sqrt((2 * 3.141592653589793)))) - (1 * ((0 * sqrt((2 * 3.141592653589793))) + (sigma * ((1 / (2 * sqrt((2 * 3.141592653589793)))) * ((0 * 3.141592653589793) + (2 * 0))))))) / ((sigma * sqrt((2 * 3.141592653589793))) * (sigma * sqrt((2 * 3.141592653589793))))) * (0.5 * exp((-1 * (((x - mu) / sigma) ^ 2))))) + ((1 / (sigma * sqrt((2 * 3.141592653589793)))) * ((0 * exp((-1 * (((x - mu) / sigma) ^ 2)))) + (0.5 * (exp((-1 * (((x - mu) / sigma) ^ 2))) * ((0 * (((x - mu) / sigma) ^ 2)) + (-1 * ((2 * (((x - mu) / sigma) ^ 1)) * ((((1 - 0) * sigma) - ((x - mu) * 0)) / (sigma * sigma))))))))))
Simplified derivative: ((1 / (sigma * sqrt((2 * 3.141592653589793)))) * (0.5 * (exp((-1 * (((x - mu) / sigma) ^ 2))) * (-1 * ((2 * ((x - mu) / sigma)) * (sigma / (sigma * sigma)))))))


In [14]:
a = Symbol("a")
b = Symbol("b")
expr = Divide(Multiply(a, b), a)

print(simplification_system.apply(expr))

b


In [12]:
print(expr)


((a * b) / a)


In [9]:
simplified_diff.left.left.left

Subtract(operator='-', operands=[Constant(value=0), Add(operator='+', operands=[Constant(value=0), Multiply(operator='*', operands=[Symbol(name='sigma'), Multiply(operator='*', operands=[Divide(operator='/', operands=[Constant(value=1), Multiply(operator='*', operands=[Constant(value=2), Function(func_name='sqrt', arguments=[Multiply(operator='*', operands=[Constant(value=2), Constant(value=3.141592653589793)])])])]), Add(operator='+', operands=[Constant(value=0), Constant(value=0)])])])])])

In [14]:
simplification_system.apply(Add(Multiply(Constant(0), Constant(2)), Constant(3)))

Add(operator='+', operands=[Constant(value=0), Constant(value=3)])