In [667]:
#prefix operators
#infix operators

class SqlExpression:
    def __init__(self, operator, *operands):
        self.operator = operator
        self.operands = operands
    
    @classmethod
    def BuildAtomicExpression(cls, opr, lhs, hrs):
        return SqlVariable(cls(opr, lhs, hrs).to_sql())

    def to_sql(self):
        
        left = SqlVariable(self.operands[0]) if type(self.operands[0]) != SqlVariable else self.operands[0]
        right = SqlVariable(self.operands[1]) if type(self.operands[1]) != SqlVariable else self.operands[1]
        
        if self.operator != 'POWER':
            return f"({' '.join(map(str, (left.to_sql(), self.operator, right.to_sql())))})"
        else:
            return f"{self.operator}( {left.to_sql()}, {right.to_sql()} )"
            

class SqlVariable:
    def __init__(self, name):
        self.name = name
        
    #Comparison
    def __lt__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<', self, rhs)
    
    def __le__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<=', self, rhs)
    
    def __eq__(self, rhs):
        return SqlExpression.BuildAtomicExpression('=', self, rhs)
    
    def __ne__(self, rhs):
        return SqlExpression.BuildAtomicExpression('!=', self, rhs)
    
    def __gt__(self, rhs):
        return SqlExpression.BuildAtomicExpression('>', self, rhs)
    
    def __ge__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<=', self, rhs)
    
    #binary arithmetic operations
    def __add__(self, rhs):
        return SqlExpression.BuildAtomicExpression('+', self, rhs)
    
    def __sub__(self, rhs):
        return SqlExpression.BuildAtomicExpression('-', self, rhs)
    
    def __mul__(self, rhs):
        return SqlExpression.BuildAtomicExpression('*', self, rhs)
    
    def __truediv__(self, rhs):
        return SqlExpression.BuildAtomicExpression('/', self, rhs)
    
    def __mod__(self, rhs):
        return SqlExpression.BuildAtomicExpression('%', self, rhs)
    
    def __pow__(self, rhs):
        return SqlExpression.BuildAtomicExpression('POWER', self, rhs)
    
    #reflected binary arithmetic operations
    def __radd__(self, lhs):
        return SqlExpression.BuildAtomicExpression('+', lhs, self)
    
    def __rsub__(self, lhs):
        return SqlExpression.BuildAtomicExpression('-', lhs, self)
    
    def __rmul__(self, lhs):
        return SqlExpression.BuildAtomicExpression('*', lhs, self)
    
    def __rtruediv__(self, lhs):
        return SqlExpression.BuildAtomicExpression('/', lhs, self)
    
    def __rmod__(self, lhs):
        return SqlExpression.BuildAtomicExpression('%', lhs, self)
    
    def __rpow__(self, lhs):
        return SqlExpression.BuildAtomicExpression('POWER', lhs, self)
    
    def __getattr__(self, attr):
        return SqlVariable(f"{self.name}->'{attr}'")
    
    def to_sql(self):
        return self.name    

def test_sql_equality_expression():
    #x is an instance of a class
    expr = (X == 3).to_sql()
    assert expr == '(x = 3)'

def test_sql_gt():
    expr = (X > 15).to_sql()
    assert expr == '(x > 15)'

def test_sql_multiplication():
    expr = ((X * 2) > 12).to_sql()
    assert expr == '((x * 2) > 12)'

def test_sql_json_property_access():
    expr = (X.country == 'Argentina').to_sql()
    assert expr == "x->'country' = 'Argentina'"

def test_nested_property_access():
    expr = (X.country.capitol == 'Bueños Aires').to_sql()
    assert expr == "x->'country'->'capitol' = 'Bueños Aires'"

def test_compare_on_property_access():
    expr = (X.num_windows > 5).to_sql()
    assert expr == "(x->'num_windows' > 5)"

def test_combines_with_other_vars():
    col1 = SqlVariable('col1')
    col2 = SqlVariable('col2')
    expr = ((X + col1) > col2).to_sql()
    assert expr == '(x + col1) > col2'
    
def test_type_errors_are_caught():
    name_col = SqlVariable('name') # TODO: need some way to specify type of column
    num_floors = SqlVariable('num_floors') # TODO: need some way to speciy ...
    
    try:
        name_col * num_floors
        assert False, "Expected to throw because can't multiply string by number"
    except:
        pass # good.

def test_user_defined_functions():
    class Lower(SqlFunction):
        # Lower accepts a single operand (because only one entry in the types list) of type SqlString
        types = [SqlString]
        return_type = SqlString
        
        def to_sql(operand):
            # TODO 
            return 'lower({})'.format(operand)
        # TODO: how to let user define
    
    name_col = SqlVariable('name')
    expr = Lower(name_col).to_sql()
    assert expr == 'lower(name)'
    
    try:
        expr + 5
        assert False
    except:
        pass # Should raise an error because Lower() returns a String, which can't
        # be added to a number
"""
SELECT * FROM my_table
WHERE (x + col1) > col2
"""
    

'\nSELECT * FROM my_table\nWHERE (x + col1) > col2\n'

In [668]:
col1 = SqlVariable('col1')

In [669]:
col2 = SqlVariable('col2')

In [670]:
X = SqlVariable('x')

In [671]:
((X + 'col1') > 'col2').to_sql()

'((x + col1) > col2)'

In [672]:
(X + col1 > col2).to_sql()

'((x + col1) > col2)'

In [673]:
col1.to_sql()

'col1'

In [674]:
(col1 + 5)

<__main__.SqlVariable at 0x1071116d8>

In [675]:
(col1 + 5 + 6 + 7).to_sql()

'(((col1 + 5) + 6) + 7)'

In [676]:
((X * X) > X).to_sql()

'((x * x) > x)'

In [677]:
(X == 3).to_sql()

'(x = 3)'

In [678]:
(X > 3).to_sql()

'(x > 3)'

In [679]:
X.to_sql()

'x'

In [680]:
(3> X).to_sql()

'(x < 3)'

In [681]:
((X *5) + (3 + X)).to_sql()

'((x * 5) + (3 + x))'

In [682]:
(X * 5 + 3 + X).to_sql()

'(((x * 5) + 3) + x)'

In [683]:
(X * 5 + 3 / X).to_sql()

'((x * 5) + (3 / x))'

In [684]:
(X ** 2 + 3).to_sql()

'(POWER( x, 2 ) + 3)'

In [685]:
(X.country == 'Argentina').to_sql()

"(x->'country' = Argentina)"

In [686]:
(X.num_windows > 5).to_sql()

"(x->'num_windows' > 5)"

In [687]:
(X.country.capitol == 'Bueños Aires').to_sql()

"(x->'country'->'capitol' = Bueños Aires)"

In [688]:
(X.hospital.num_windows / 100 + 150 > 1000).to_sql()

"(((x->'hospital'->'num_windows' / 100) + 150) > 1000)"