In [23]:
from sympy import *
from sympy.abc import x
import operator
import random
from collections import deque
import math
from sympy import Interval, S, Union, solveset

In [248]:
class BinaryTree():
    def __init__(self, value):
        self.left = None
        self.right = None
        self.value = value
#         self.domain = None
        
    def __str__(self):
        return str(self.value)
        
    def get_left_child(self):
        if self.left:
            return self.left
        
    def get_right_child(self):
        if self.right:
            return self.right
        
    def set_left_child(self, tree):
        if self.left == None:
            self.left = tree
        else:
            tree.left = self.left
            self.left = tree
        
    def set_right_child(self, tree):
        if self.right == None:
            self.right = tree
        else:
            tree.right = self.right
            self.right = tree
        
    def set_node_value(self, value):
        self.value = value
        
    def get_node_value(self):
        return self.value

    def insert_right(self, new_node):
        if self.right == None:
            self.right = BinaryTree(new_node)
        else:
            self.right.value = new_node

    def insert_left(self, new_node):
        if self.left == None:
            self.left = BinaryTree(new_node)
        else:
            self.left.value = new_node            

    def print_expression(self):
        if self.value in name_opers:
            print('(', end = '')
            self.left.print_expression()
            print(self.value, end = '')
            self.right.print_expression()
            print(')', end = '')
        elif self.value in name_functions:
            print(self.value + '(', end = '')
            self.left.print_expression()
            print(')', end = '')
        elif self.value in variance:
            print(self.value, end = '')
        elif isinstance(self.value, int):
            print(self.value, end = '')

    def print_tree(self, depth):
        current_level = [self]
        a = ' ' * (2 ** depth)
        while depth:
            a = a[:len(a)//2]
            next_level = list()
            for n in current_level:
                print(a + str(n.value), end = '') if n != None else print(a, end = '')
                if n == None or n.left == None:
                    next_level.append(None)
                else:
                    next_level.append(n.left)
                if n == None or n.right == None:
                    next_level.append(None)
                else:
                    next_level.append(n.right)
                current_level = next_level
            depth -= 1
            print("\n")
            

    def check_poly(self):
        left_flag = True
        right_flag = True
        center_flag = True
        
        if (self.left != None):
            left_flag = self.left.check_poly()
        if (self.right != None):
            right_flag = self.right.check_poly()
        if self.value not in ['x', '+', '-', '*'] and not isinstance(self.value, int):
            center_flag = False
            
        return left_flag & right_flag & center_flag
    
    def check_x(self):
        left_flag = False
        right_flag = False
        center_flag = False
        
        if (self.left != None):
            left_flag = self.left.check_x()
        if (self.right != None):
            right_flag = self.right.check_x()
        if self.value == 'x':
            center_flag = True
            
        return left_flag | right_flag | center_flag


    def get_poly_domain(self):
        expr = self.evaluate()
        solve = list(solveset(expr, x, domain = S.Reals).args)


        if not solve:
            if isinstance(expr, int):
                res = expr
            else:
                res = N(expr.subs(x, 0))

            if res > 0:
                return S.Reals
            else:
                return S.EmptySet
        elif solve[0] == -oo and solve[1] == oo:
            return S.EmptySet
        else:
            solve = [N(_) for _ in solve]
            solve.sort()
            domain = S.EmptySet
            eps = 10e-5
            n = len(solve)

            for i in range(0, n):
                res = N(expr.subs(x, solve[i] - eps))
                if (res > 0):
                    if (i == 0):
                        domain = Union(domain, Interval.open(-oo, solve[i]))
                    else:
                        domain = Union(domain, Interval.open(solve[i - 1], solve[i]))

            if ( N(expr.subs(x, solve[n - 1] + eps)) > 0):
                domain = Union(domain, Interval.open(solve[n - 1], oo))
            return domain.intersect(Interval(-max_constant, max_constant))

    def evaluate(self):
        left_child = self.get_left_child()
        right_child = self.get_right_child()

        if left_child and right_child:
            fn = opers[self.get_node_value()]
            return fn(evaluate(left_child), evaluate(right_child))
        elif left_child:
            return functions[self.get_node_value()](evaluate(left_child))
        else:
            num = self.get_node_value()
            if num == 'x':
                return symbols(num)
            else:
                return int(num)
            
    def get_domain(self):
#         print(self.value)
        if (self.check_poly()):
            return self.get_poly_domain()
        elif (self.left == None and self.right == None):
            return Interval.open(-max_constant, max_constant)
        else:
            if (self.left != None):
                left_domain = self.left.get_domain()
            if (self.right != None):
                right_domain = self.right.get_domain()
            else:
                right_domain = Interval.open(-max_constant, max_constant)

            domain = left_domain.intersect(right_domain)
            
            if (self.value not in ['sqrt', 'ln']):
                return domain
            else:
                if (self.left.check_x()):
                    expr = self.left.evaluate()
                    new_domain = S.EmptySet

                    positive_point = []
                    borders =  list(domain.boundary)
                    n = len(borders) // 2
                    new_domain = S.EmptySet


                    for i in range(n):
                        l = borders[i * 2]
                        r = borders[i * 2 + 1]
                        points = linspace(l, r, 1e-1)
                        prev_point = points[0]
                        interval = False
                        for p in points:
                            val = N(expr.subs(x, p))
                            if val > 0 and interval == False:
                                positive_point.append(p)
                                interval = True
                            if val < 0 and interval == True:
                                positive_point.append(prev_point)
                                interval = False
                            prev_point = p  

                        if len(positive_point) % 2 == 1:
                            positive_point.append(points[len(points) - 1]) 

                    m = len(positive_point)
                    m = m // 2
                    for i in range(m):
                        l = positive_point[i * 2]
                        r = positive_point[i * 2 + 1]
                        if (l != r):
                            new_domain = new_domain.union(Interval(l, r))
                    return new_domain
                else:
                    if N(self.left.evaluate()) > 0:
                        return Interval(1e-1, max_constant)
                    else:
                        return S.EmptySet

In [32]:
def linspace(l, r, eps):
    ans = []
    m = l
    while m <= r:
        ans.append(m)
        m += eps
    return ans

In [33]:
global max_constant
max_constant = 10

In [34]:
operation_type = ['O', 'F', 'V', 'C']
name_opers = ['+', '-', '*', '/', '**']
opers = {'+': operator.add, '-': operator.sub, '*':operator.mul, '/': operator.truediv, '**': operator.pow}
name_functions = ['exp', 'ln', 'sqrt', 'sin', 'cos']
functions = {'sin': sin, 'cos': cos, 'exp': exp, 'ln': log, 'sqrt':sqrt}

In [35]:
def rand_type():
    global var_weight
    weight = [20, 10, var_weight, var_weight]
    ans = random.choices(operation_type, weight)[0]   
    var_weight += 1
    return ans

def rand_oper(weights = [4, 3, 2, 2, 1]):
    return random.choices(name_opers, weights)[0]  

def rand_const(min = 1, max = 5):
    return str(random.randint(min, max))

def rand_func():
    return random.choices(name_functions)[0]

In [36]:
def generate_tree(depth):
    if depth == 1:
        if random.randint(0, 1):
            return BinaryTree(rand_const()) 
        else:
            return BinaryTree('x')  
    else:
        t = rand_type()
        if t == "O":
            tree = BinaryTree(rand_oper())
            tree.set_left_child(generate_tree(depth - 1))
            tree.set_right_child(generate_tree(depth - 1))
        elif t == "F":
            tree = BinaryTree(rand_func())
            tree.set_left_child(generate_tree(depth - 1))
        elif t == "V":
            tree = BinaryTree('x')
        else:
            tree = BinaryTree(int(rand_const()))
        return tree

In [311]:
var_weight = 0
gt = generate_tree(4)

gt.print_tree(4)

        +

    ln    **

  sin    /  -

 x    5 x x 4



In [312]:
gt.get_domain()

[0.1, 3.1] U [6.29999999999999, 9.39999999999998]