In [1]:
from sympy import *
from sympy.abc import x, y
import operator
import random
from collections import deque
import math

In [2]:
# Node categories drawn by rand_type(): 'O' operator, 'F' function,
# 'V' variable, 'C' constant (see how generate_tree() branches on them).
operation_type = ['O', 'F', 'V', 'C']

# Binary operator symbols and the callables that implement them.
name_opers = ['+', '-', '*', '/', '**']
opers = {
    '+': operator.add,
    '-': operator.sub,
    '*': operator.mul,
    '/': operator.truediv,
    '**': operator.pow,
}

# Variable names that may appear as leaves.
variance = ['x', 'y']

# Unary function names and the callables that implement them.
name_functions = ['exp', 'ln', 'sqrt', 'sin', 'cos', 'tan', 'acos', 'asin', 'atan']
functions = {
    'sin': sin,
    'cos': cos,
    'tan': tan,
    'acos': acos,
    'asin': asin,
    'atan': atan,
    'exp': exp,
    'ln': log,
    'sqrt': sqrt,
}

In [3]:
def rand_type():
    """Draw a node category from ``operation_type`` at random.

    Each call first increments the module-level ``var_weight`` and then uses
    it as the weight of both 'V' and 'C', so leaves become more likely with
    every call. 'O' has a fixed weight of 10 and 'F' a weight of 0.
    ``var_weight`` must be defined at module level before the first call.
    """
    global var_weight
    var_weight = var_weight + 1
    type_weights = [10, 0, var_weight, var_weight]
    chosen, = random.choices(operation_type, type_weights)
    return chosen

In [4]:
def rand_oper(weights = [40, 20, 30, 0, 0]):
    """Draw one operator symbol from ``name_opers``.

    ``weights`` lines up positionally with ``name_opers``; with the default
    weights '/' and '**' are never selected.
    """
    drawn = random.choices(name_opers, weights)
    return drawn[0]

In [5]:
def rand_const(min = 1, max = 5):
    """Return a random integer from the inclusive range [min, max] as a string."""
    value = random.randint(min, max)
    return f"{value}"

In [6]:
def rand_var(weights = [70, 0]):
    """Draw one variable name from ``variance``.

    ``weights`` lines up positionally with ``variance``; with the default
    weights only the first entry can be selected.
    """
    drawn = random.choices(variance, weights)
    return drawn[0]

In [7]:
def rand_func(weights = [10, 20, 20, 5, 5, 5, 1, 1, 1]):
    """Draw one function name from ``name_functions``.

    ``weights`` lines up positionally with ``name_functions``.
    """
    drawn = random.choices(name_functions, weights)
    return drawn[0]

In [17]:
class BinaryTree():
    """A node of a binary expression tree.

    ``value`` is an operator symbol, a function name, a variable name or a
    constant. Operator nodes use both children; function nodes use only the
    left child (see ``print_expression``).
    """

    def __init__(self, value):
        # parent is initialised here but not updated by any method of this class.
        self.parent = None
        self.left = None
        self.right = None
        self.value = value
        self.domain = S.Reals

    def __str__(self):
        return str(self.value)

    @staticmethod
    def _is_constant(value):
        """Return True for an int or a string that parses as an int (e.g. '3')."""
        if isinstance(value, int):
            return True
        if isinstance(value, str):
            try:
                int(value)
            except ValueError:
                return False
            return True
        return False

    def get_left_child(self):
        """Return the left child, or None when there is none."""
        if self.left:
            return self.left
        return None

    def get_right_child(self):
        """Return the right child, or None when there is none."""
        if self.right:
            return self.right
        return None

    def set_left_child(self, tree):
        """Attach ``tree`` as the left child.

        If a left child already exists it is pushed down to become the left
        child of ``tree``.
        """
        if self.left is not None:
            tree.left = self.left
        self.left = tree

    def set_right_child(self, tree):
        """Attach ``tree`` as the right child.

        If a right child already exists it is pushed down to become the right
        child of ``tree``.
        """
        if self.right is not None:
            tree.right = self.right
        self.right = tree

    def set_node_value(self, value):
        self.value = value

    def get_node_value(self):
        return self.value

    def insert_right(self, new_node):
        """Create a right child holding ``new_node``, or overwrite the existing child's value."""
        if self.right is None:
            self.right = BinaryTree(new_node)
        else:
            self.right.value = new_node

    def insert_left(self, new_node):
        """Create a left child holding ``new_node``, or overwrite the existing child's value."""
        if self.left is None:
            self.left = BinaryTree(new_node)
        else:
            self.left.value = new_node

    def print_expression(self):
        """Print the subtree as a fully parenthesised infix expression (no trailing newline)."""
        if self.value in name_opers:
            print('(', end = '')
            self.left.print_expression()
            print(self.value, end = '')
            self.right.print_expression()
            print(')', end = '')
        elif self.value in name_functions:
            print(self.value + '(', end = '')
            self.left.print_expression()
            print(')', end = '')
        elif self.value in variance or self._is_constant(self.value):
            # Constants may be stored as int or as a numeric string; print both
            # instead of silently dropping the string form.
            print(self.value, end = '')

    def print_tree(self, depth):
        """Print ``depth`` levels of the tree, one level per line, as a rough layout.

        Missing nodes are rendered as padding only so that positions line up.
        """
        current_level = [self]
        pad = ' ' * (2 ** depth)
        while depth:
            # Halve the padding on every level so deeper levels are packed tighter.
            pad = pad[:len(pad)//2]
            next_level = []
            for node in current_level:
                if node is None:
                    print(pad, end = '')
                    next_level.append(None)
                    next_level.append(None)
                else:
                    print(pad + str(node.value), end = '')
                    next_level.append(node.left)
                    next_level.append(node.right)
            # Advance only after the whole level has been printed.
            current_level = next_level
            depth -= 1
            print("\n")

    def check_poly(self):
        """Return True when the subtree uses only 'x', '+', '-', '*' and integer constants.

        Every offending node value is printed as a diagnostic. The whole
        subtree is always visited (no short-circuit), so all offenders are
        reported.
        """
        left_ok = True if self.left is None else self.left.check_poly()
        right_ok = True if self.right is None else self.right.check_poly()

        node_ok = self.value in ['x', '+', '-', '*'] or self._is_constant(self.value)
        if not node_ok:
            print(self.value)

        return left_ok and right_ok and node_ok

In [9]:
def generate_tree(depth):
    """Build a random expression tree of at most ``depth`` levels.

    At ``depth <= 1`` a leaf is returned: a constant or a variable with equal
    probability. Otherwise rand_type() picks the node category; operator nodes
    get two subtrees, function nodes get a left subtree only, and variable or
    constant nodes end the branch early.

    Constants are always stored as ``int`` so that the operator callables in
    ``opers`` do arithmetic on them rather than string concatenation or
    repetition.
    """
    # '<=' rather than '==' so a non-positive depth still terminates with a leaf.
    if depth <= 1:
        if random.randint(0, 1):
            return BinaryTree(int(rand_const()))
        else:
            return BinaryTree(rand_var())

    t = rand_type()
    if t == "O":
        tree = BinaryTree(rand_oper())
        tree.set_left_child(generate_tree(depth - 1))
        tree.set_right_child(generate_tree(depth - 1))
    elif t == "F":
        tree = BinaryTree(rand_func())
        tree.set_left_child(generate_tree(depth - 1))
    elif t == "V":
        tree = BinaryTree(rand_var())
    else:
        tree = BinaryTree(int(rand_const()))
    return tree

In [10]:
def evaluate(tree):
    """Recursively turn an expression tree into a value.

    A node with both children is treated as a binary operator looked up in
    ``opers``; a node with only a left child is treated as a function looked
    up in ``functions``; a leaf holding 'x' or 'y' becomes a SymPy symbol and
    any other leaf value is returned unchanged.
    """
    lhs = tree.get_left_child()
    rhs = tree.get_right_child()
    label = tree.get_node_value()

    if lhs and rhs:
        return opers[label](evaluate(lhs), evaluate(rhs))
    if lhs:
        return functions[label](evaluate(lhs))
    if label == 'x' or label == 'y':
        return symbols(label)
    return label

In [19]:
# Reset the global weight that rand_type() increments on every call, so that
# generation starts from the same leaf bias each time this cell is run.
var_weight = 0
# Generate a random tree limited to 7 levels and print those 7 levels.
# No random seed is set, so the tree differs from run to run.
gt = generate_tree(7)
gt.print_tree(7)

                                                                +

                                +                                1

                x                -                                

                        +        -                                

                    +    x    x    x                                

                  x  x                                            

                                                                



In [20]:
# Check whether the generated tree contains only 'x', '+', '-', '*' and integer constants.
gt.check_poly()

True

In [156]:
# The two expressions that appear under the square roots in `expr` further down.
expr1 = x - 3
expr2 = x ** 2 - 2

In [163]:
# get_domain() is defined in a later cell of this notebook; run that cell first.
# d1 and d2 are the sets where expr1 and expr2 are strictly positive;
# d is the set where both are.
d1 = get_domain(expr1)
d2 = get_domain(expr2)
d = Intersection(d1, d2)

In [191]:
# Show the set computed for expr1.
d1

(3.0, oo)

In [192]:
# Show the set computed for expr2.
d2

(-oo, -1.4142135623731) U (1.4142135623731, oo)

In [193]:
# Show the intersection of d1 and d2.
d

(3.0, oo)

In [189]:
# Logarithm of a sum of two square roots whose arguments are x - 3 and x**2 - 2.
expr = ln(sqrt(x - 3) + sqrt(x ** 2 - 2))

In [190]:
# Solve expr == 0 for x with the search restricted to d. The recorded output
# is an unevaluated ConditionSet, i.e. no closed-form solution was produced.
solveset(expr, x, domain = d)

ConditionSet(x, Eq(log(sqrt(x - 3) + sqrt(x**2 - 2)), 0), (3.0, oo))

In [155]:
def get_domain(expr):
    """Return the set of real ``x`` on which ``expr`` is strictly positive.

    The real zeros of ``expr`` are found with solveset(), converted to
    floating point with N(), and the sign of ``expr`` is probed once inside
    every interval between consecutive zeros. Intervals where the probe is
    positive are united into the result; all intervals are open.

    Parameters
    ----------
    expr : SymPy expression in the module-level symbol ``x``, or an int.

    Returns
    -------
    A SymPy set: S.Reals, S.EmptySet, or a union of open intervals with
    floating-point endpoints.

    Notes
    -----
    One probe per interval is only meaningful when ``expr`` does not change
    sign between consecutive zeros.
    NOTE(review): this assumes solveset() returns a finite set of roots, the
    empty set, or all reals; other result types (e.g. ConditionSet) are not
    handled - confirm against the intended inputs.
    """
    # Named `roots` so the local does not shadow SymPy's solve().
    roots = list(solveset(expr, x, domain = S.Reals).args)

    if not roots:
        # No real zero: probe the sign once, at x = 0.
        if isinstance(expr, int):
            res = expr
        else:
            res = N(expr.subs(x, 0))

        if res > 0:
            return S.Reals
        else:
            return S.EmptySet
    elif roots[0] == -oo and roots[1] == oo:
        # The zero set is the whole real line, so expr is positive nowhere.
        return S.EmptySet
    else:
        roots = sorted(N(r) for r in roots)
        domain = S.EmptySet
        eps = 1e-4  # same value as the original literal 10e-5
        n = len(roots)

        for i in range(n):
            if i == 0:
                # Unbounded interval to the left of the smallest root.
                sample = roots[0] - eps
                piece = Interval.open(-oo, roots[0])
            else:
                # Probe at the midpoint: `roots[i] - eps` could land left of
                # roots[i - 1] when two roots are closer together than eps.
                sample = (roots[i - 1] + roots[i]) / 2
                piece = Interval.open(roots[i - 1], roots[i])

            if N(expr.subs(x, sample)) > 0:
                domain = Union(domain, piece)

        # Unbounded interval to the right of the largest root.
        if N(expr.subs(x, roots[n - 1] + eps)) > 0:
            domain = Union(domain, Interval.open(roots[n - 1], oo))
        return domain