In [40]:
from sympy import *
from sympy.abc import x, y
import operator
import random
from collections import deque
import math

In [41]:
operation_type = ['O', 'F', 'V', 'C']
name_opers = ['+', '-', '*', '/', '**']
variance = ['x', 'y']
opers = {'+': operator.add, '-': operator.sub, '*':operator.mul, '/': operator.truediv, '**': operator.pow}
name_functions = ['exp', 'ln', 'sqrt', 'sin', 'cos', 'tan', 'acos', 'asin', 'atan']
functions = {'sin': sin, 'cos': cos, 'tan': tan, 'acos': acos, 'asin': asin, 
             'atan': atan, 'exp': exp, 'ln': log, 'sqrt':sqrt}

In [87]:
def rand_type():
    global var_weight
    var_weight += 1
    weight = [10, 0, var_weight, var_weight]
    return random.choices(operation_type, weight)[0]   

In [88]:
def rand_oper(weights = [40, 20, 30, 0, 0]):
    return random.choices(name_opers, weights)[0]    

In [89]:
def rand_const(min = 1, max = 5):
    return str(random.randint(min, max))

In [90]:
def rand_var(weights = [70, 0]):
    return random.choices(variance, weights)[0] 

In [91]:
def rand_func(weights = [10, 20, 20, 5, 5, 5, 1, 1, 1]):
    return random.choices(name_functions, weights)[0]    

In [110]:
class BinaryTree():
    def __init__(self, value):
        self.parent = None
        self.left = None
        self.right = None
        self.value = value
        
    def __str__(self):
        return str(self.value)
        
    def get_left_child(self):
        if self.left:
            return self.left
        
    def get_right_child(self):
        if self.right:
            return self.right
        
    def set_left_child(self, tree):
        if self.left == None:
            self.left = tree
        else:
            tree.left = self.left
            self.left = tree
        
    def set_right_child(self, tree):
        if self.right == None:
            self.right = tree
        else:
            tree.right = self.right
            self.right = tree
        
    def set_node_value(self, value):
        self.value = value
        
    def get_node_value(self):
        return self.value

    def insert_right(self, new_node):
        if self.right == None:
            self.right = BinaryTree(new_node)
        else:
            self.right.value = new_node

    def insert_left(self, new_node):
        if self.left == None:
            self.left = BinaryTree(new_node)
        else:
            self.left.value = new_node
            
    def print_expression(self):
        if self.value in name_opers:
            print('(', end = '')
            self.left.print_expression()
            print(self.value, end = '')
            self.right.print_expression()
            print(')', end = '')
        elif self.value in name_functions:
            print(self.value + '(', end = '')
            self.left.print_expression()
            print(')', end = '')
        elif self.value in variance:
            print(self.value, end = '')
        elif self.value in ['1', '2', '3', '4', '5']:
            print(self.value, end = '')
            
    def print_tree(self, depth):
        current_level = [self]
        a = ' ' * (2 ** depth)
        while depth:
            a = a[:len(a)//2]
            next_level = list()
            for n in current_level:
                print(a + n.value, end = '') if n != None else print(a, end = '')
                if n == None or n.left == None:
                    next_level.append(None)
                else:
                    next_level.append(n.left)
                if n == None or n.right == None:
                    next_level.append(None)
                else:
                    next_level.append(n.right)
                current_level = next_level
            depth -= 1
            print("\n")
            
    def check_poly(self):
        left_flag = True
        right_flag = True
        center_flag = True
        
        if (self.left != None):
            left_flag = self.left.check_poly()
        if (self.right != None):
            right_flag = self.right.check_poly()
        if self.value not in ['x', '+', '-', '*', '1', '2', '3', '4', '5']:
            center_flag = False
            
        return left_flag & right_flag & center_flag

In [93]:
def generate_tree(depth):
    if depth == 1:
        if random.randint(0, 1):
            return BinaryTree(rand_const()) 
        else:
            return BinaryTree(rand_var())  
    else:
        t = rand_type()
        if t == "O":
            tree = BinaryTree(rand_oper())
            tree.set_left_child(generate_tree(depth - 1))
            tree.set_right_child(generate_tree(depth - 1))
        elif t == "F":
            tree = BinaryTree(rand_func())
            tree.set_left_child(generate_tree(depth - 1))
        elif t == "V":
            tree = BinaryTree(rand_var())
        else:
            tree = BinaryTree(rand_const())
        return tree

In [49]:
def evaluate(tree):
    left_child = tree.get_left_child()
    right_child = tree.get_right_child()

    if left_child and right_child:
        fn = opers[tree.get_node_value()]
        return fn(evaluate(left_child), evaluate(right_child))
    elif left_child:
        return functions[tree.get_node_value()](evaluate(left_child))
    else:
        num = tree.get_node_value()
        if num == 'x' or num == 'y':
            return sp.symbols(num)
        else:
            return int(num)

In [120]:
var_weight = 0
gt = generate_tree(7)
gt.print_tree(7)

                                                                +

                                +                                x

                *                5                                

        2        -                                                

            *    +                                                

          *  3  *  1                                                

         3 x   x x                                                  



In [135]:
gt.check_poly()

True

In [122]:
gt.print_expression()

(((2*(((3*x)*3)-((x*x)+1)))+5)+x)

In [121]:
expr = evaluate(gt)

In [123]:
expr

-2*x**2 + 19*x + 3

In [126]:
roots = real_roots(expr)

In [133]:
roots

[-sqrt(385)/4 + 19/4, 19/4 + sqrt(385)/4]

In [128]:
eps = 1e-4

In [134]:
for r in roots:
    print (N(expr.subs(x, r - eps)))
print (N(expr.subs(x, roots[-1] + eps)))

-0.00196216168702673
0.00196212168703384
-0.00196216168703384
