In [5]:
OPERATORS = {
    '+': (1, 'L'),
    '-': (1, 'L'),
    '*': (2, 'L'),
    '/': (2, 'L'),
    '^': (3, 'R')
}
FUNCTIONS = ['sin', 'cos', 'tan', 'exp', 'ln','log']

In [6]:
class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right
    
    def __str__(self):
        if not self.left and not self.right:
            return str(self.value)
        return f"({str(self.left)} {self.value} {str(self.right)})"

    def __repr__(self):
        return self.__str__()

In [7]:
def infix_to_postfix(expression):
    stack = []
    output = []
    tokens = tokenize(expression)

    for token in tokens:
        if token.isalnum() or ('.' in token and token.replace('.', '', 1).isdigit()) or token == 'x':  # Operand
            output.append(token)
        elif token in FUNCTIONS:
            stack.append(token)
        elif token in OPERATORS:
            while (stack and stack[-1] in OPERATORS and
                   ((OPERATORS[token][1] == 'L' and OPERATORS[token][0] <= OPERATORS[stack[-1]][0]) or
                    (OPERATORS[token][1] == 'R' and OPERATORS[token][0] < OPERATORS[stack[-1]][0]))):
                output.append(stack.pop())
            stack.append(token)
        elif token == '(':
            stack.append(token)
        elif token == ')':
            while stack and stack[-1] != '(':
                output.append(stack.pop())
            stack.pop()

    while stack:
        output.append(stack.pop())

    return output

In [8]:
def tokenize(expression):
    tokens = []
    i = 0
    while i < len(expression):
        if expression[i].isspace():  # Skip spaces
            i += 1
            continue
        if expression[i].isdigit() or expression[i] == '.':  # Handle numbers
            token = expression[i]
            i += 1
            while i < len(expression) and (expression[i].isdigit() or expression[i] == '.'):
                token += expression[i]
                i += 1
            tokens.append(token)
        elif expression[i].isalpha():  # Handle variables or functions
            token = expression[i]
            i += 1
            while i < len(expression) and expression[i].isalpha():
                token += expression[i]
                i += 1
            tokens.append(token)
        elif expression[i] in OPERATORS or expression[i] in "()":  # Handle operators and parentheses
            tokens.append(expression[i])
            i += 1
        else:
            raise ValueError(f"Unexpected character: {expression[i]}")
    return tokens

In [11]:
def build_expression_tree(postfix):
    stack = []
    for token in postfix:
        if token.isalnum() or token == 'x' or ('.' in token and token.replace('.', '', 1).isdigit()):
            stack.append(Node(token))
        elif token in OPERATORS or token in FUNCTIONS:
            node = Node(token)
            if token in OPERATORS:  # Binary operators
                node.right = stack.pop()
                node.left = stack.pop()
            elif token in FUNCTIONS:  # Unary functions
                node.left = stack.pop()
            stack.append(node)
    return stack[0]

In [46]:
def differentiate(node):
    if node is None:
        return None

    if node.value == 'x':  # Derivative of x is 1
        return Node('1')
    elif node.value.isdigit():  # Derivative of a constant is 0
        return Node('0')
    elif node.value == '^':
        base = node.left
        exponent = node.right
        # for x^n not f^g
        if exponent.value.isdigit():
            n = float(exponent.value)
            # Differentiate x^n -> n * x^(n-1)
            if n == 1 or n == 0:
                return base  # x^1 -> 1 * x = x
            elif n == 0:
                return Node('0')  # x^0 -> 0
            new_exponent = str(n - 1)
            # Chain rule: n * x^(n-1)
            return Node('*', Node(exponent.value), Node('^', base, Node(new_exponent)))
    elif node.value in OPERATORS:
        if node.value == '+':  # f + g -> f' + g'
            return Node('+', differentiate(node.left), differentiate(node.right))
        elif node.value == '-':  # f - g -> f' - g'
            return Node('-', differentiate(node.left), differentiate(node.right))
        elif node.value == '*': # f * g -> f * g' + f' * g
            left_subtree = Node('*', node.left, differentiate(node.right))
            right_subtree = Node('*', differentiate(node.left), node.right)
            return Node('+', left_subtree, right_subtree)
        elif node.value == '/': # f / g -> (f * g' - f' * g)/g ^ 2
            left_subtree = Node('*', node.left, differentiate(node.right))
            right_subtree = Node('*', differentiate(node.left), node.right)
            numerator = Node('-', left_subtree, right_subtree)
            denominator = Node('*', node.right, node.right)
            return Node('/', numerator, denominator)
        elif node.value == '^': # f^g -> f^g ( (g' * ln(f)) + (g * f' / f))
            power_node = Node('^', node.left, node.right)
            left_term = Node('*', differentiate(node.right), Node('ln', node.left))
            right_term = Node('*', node.right, Node('/', differentiate(node.left), node.left))
            sum_node = Node('+', left_term, right_term)
        return Node('*', power_node, sum_node)
                
    elif node.value in FUNCTIONS:
        pass

In [47]:
def simplify_tree(node):
    if node is None:
        return None

    # Simplify the left and right subtrees first
    node.left = simplify_tree(node.left)
    node.right = simplify_tree(node.right)

    # Now, simplify based on the current node's value
    if node.value == '0':
        return None  # 0 should be removed

    # If it's an operator, look for simplifications
    
    if node.value in OPERATORS:
        # Handling *: any * 0 = 0, x * 1 = x, etc.
        if node.value == '*':
            if node.left == None or node.right == None:
                return Node('0')
            if node.left.value.isdigit():  # If left child is a constant
                constant = node.left.value
                new_node = Node('*', Node(constant), node.right)
                return new_node
            elif node.right.value.isdigit():  # If right child is a constant
                constant = node.right.value
                new_node = Node('*', node.left, Node(constant))
                return new_node
            if node.left.value == '0' or node.right.value == '0':  # x * 0 or 0 * x
                return Node('0')
            if node.left.value == '1':  # 1 * x
                return node.right
            if node.right.value == '1':  # x * 1
                return node.left

        # Handling +: x + 0 = x, 0 + x = x
        if node.value == '+':
            if node.left == None or node.right == None:
                return Node('0')
            if node.left.value == '0':  # 0 + x
                return node.right
            if node.right.value == '0':  # x + 0
                return node.left
        
        # Handling subtraction
        if node.value == '-':
            if node.right.value == '0':  # x - 0
                return node.left
            if node.left.value == '0':  # 0 - x
                return Node('-', node.right, None)  # Return negation (0 - x = -x)

        # Handling exponentiation: x^0 = 1, x^1 = 

    return node


In [48]:
postfix = infix_to_postfix("x^2")
print(postfix)
tree = build_expression_tree(postfix)
print(tree)
diff = differentiate(tree)
print(diff)
output = simplify_tree(diff)
print(output)

['x', '2', '^']
(x ^ 2)
(2 * (x ^ 1.0))
(2 * (x ^ 1.0))


# 