In [634]:
#prefix operators
#infix operators

class SqlExpression:
    def __init__(self, operator, *operands):
        self.operator = operator
        self.operands = operands
    
    @classmethod
    def BuildAtomicExpression(cls, opr, lhs, hrs):
        return SqlVariable(cls(opr, lhs, hrs).to_sql())

    def to_sql(self):
        
        left = SqlVariable(self.operands[0]) if type(self.operands[0]) != SqlVariable else self.operands[0]
        right = SqlVariable(self.operands[1]) if type(self.operands[1]) != SqlVariable else self.operands[1]
        
        if self.operator != 'POWER':
            return f"({' '.join(map(str, (left.to_sql(), self.operator, right.to_sql())))})"
        else:
            return f"{self.operator}( {left.to_sql()}, {right.to_sql()} )"
            

class SqlVariable:
    def __init__(self, name):
        self.name = name
        
    #Comparison
    def __lt__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<', self, rhs)
    
    def __le__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<=', self, rhs)
    
    def __eq__(self, rhs):
        return SqlExpression.BuildAtomicExpression('=', self, rhs)
    
    def __ne__(self, rhs):
        return SqlExpression.BuildAtomicExpression('!=', self, rhs)
    
    def __gt__(self, rhs):
        return SqlExpression.BuildAtomicExpression('>', self, rhs)
    
    def __ge__(self, rhs):
        return SqlExpression.BuildAtomicExpression('<=', self, rhs)
    
    #binary arithmetic operations
    def __add__(self, rhs):
        return SqlExpression.BuildAtomicExpression('+', self, rhs)
    
    def __sub__(self, rhs):
        return SqlExpression.BuildAtomicExpression('-', self, rhs)
    
    def __mul__(self, rhs):
        return SqlExpression.BuildAtomicExpression('*', self, rhs)
    
    def __truediv__(self, rhs):
        return SqlExpression.BuildAtomicExpression('/', self, rhs)
    
    def __mod__(self, rhs):
        return SqlExpression.BuildAtomicExpression('%', self, rhs)
    
    def __pow__(self, rhs):
        return SqlExpression.BuildAtomicExpression('POWER', self, rhs)
    
    #reflected binary arithmetic operations
    def __radd__(self, lhs):
        return SqlExpression.BuildAtomicExpression('+', lhs, self)
    
    def __rsub__(self, lhs):
        return SqlExpression.BuildAtomicExpression('-', lhs, self)
    
    def __rmul__(self, lhs):
        return SqlExpression.BuildAtomicExpression('*', lhs, self)
    
    def __rtruediv__(self, lhs):
        return SqlExpression.BuildAtomicExpression('/', lhs, self)
    
    def __rmod__(self, lhs):
        return SqlExpression.BuildAtomicExpression('%', lhs, self)
    
    def __rpow__(self, lhs):
        return SqlExpression.BuildAtomicExpression('POWER', lhs, self)
    
    def __getattr__(self, attr):
        return SqlVariable(f"{self.name}->'{attr}'")
    
    def to_sql(self):
        return self.name

def test_sql_equality_expression():
    #x is an instance of a class
    expr = (X == 3).to_sql()
    assert expr == '(x = 3)'

def test_sql_gt():
    expr = (X > 15).to_sql()
    assert expr == '(x > 15)'

def test_sql_multiplication():
    expr = ((X * 2) > 12).to_sql()
    assert expr == '((x * 2) > 12)'

def test_sql_json_property_access():
    expr = (X.country == 'Argentina').to_sql()
    assert expr == "x->'country' = 'Argentina'"

def test_nested_property_access():
    expr = (X.country.capitol == 'Bueños Aires').to_sql()
    assert expr == "x->'country'->'capitol' = 'Bueños Aires'"

def test_compare_on_property_access():
    expr = (X.num_windows > 5).to_sql()
    assert expr == "(x->'num_windows' > 5)"

def test_combines_with_other_vars():
    col1 = SqlVariable('col1')
    col2 = SqlVariable('col2')
    expr = ((X + col1) > col2).to_sql()
    assert expr == '(x + col1) > col2'
"""
SELECT * FROM my_table
WHERE (x + col1) > col2
"""
    

'\nSELECT * FROM my_table\nWHERE (x + col1) > col2\n'

In [635]:
col1 = SqlVariable('col1')

In [636]:
col2 = SqlVariable('col2')

In [637]:
X = SqlVariable('x')

In [638]:
((X + 'col1') > 'col2').to_sql()

<__main__.SqlExpression at 0x106b2ef28>

In [614]:
(X + col1 > col2).to_sql()

'((x + col1) > col2)'

In [615]:
col1.to_sql()

'col1'

In [616]:
(col1 + 5)

<__main__.SqlVariable at 0x106b6aac8>

In [618]:
(col1 + 5 + 6 + 7).to_sql()

'(((col1 + 5) + 6) + 7)'

In [619]:
((X * X) > X).to_sql()

'((x * x) > x)'

In [620]:
(X == 3).to_sql()

'(x = 3)'

In [621]:
(X > 3).to_sql()

'(x > 3)'

In [622]:
X.to_sql()

'x'

In [623]:
(3> X).to_sql()

'(x < 3)'

In [624]:
((X *5) + (3 + X)).to_sql()

'((x * 5) + (3 + x))'

In [625]:
(X * 5 + 3 + X).to_sql()

'(((x * 5) + 3) + x)'

In [626]:
(X * 5 + 3 / X).to_sql()

'((x * 5) + (3 / x))'

In [627]:
(X ** 2 + 3).to_sql()

'(POWER( x, 2 ) + 3)'

In [628]:
(X.country == 'Argentina').to_sql()

"(x->'country' = Argentina)"

In [629]:
(X.num_windows > 5).to_sql()

"(x->'num_windows' > 5)"

In [630]:
(X.country.capitol == 'Bueños Aires').to_sql()

"(x->'country'->'capitol' = Bueños Aires)"

In [631]:
(X.hospital.num_windows / 100 + 150 > 1000).to_sql()

"(((x->'hospital'->'num_windows' / 100) + 150) > 1000)"