In [12]:
class Expr:  # abstract class
    """Abstract base class for arithmetic expression trees.

    Subclasses implement ``eval(self, env)``: evaluate the expression using
    ``env`` (a dictionary mapping variable names to numbers) and return the
    resulting number.
    """

    def eval(self, env):
        """Evaluate this expression under ``env``; subclasses must override.

        Raises:
            NotImplementedError: always, when called on the abstract base.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement eval(env)"
        )

    def checkEquality(self, other, env):
        """Return True if ``self`` and ``other`` evaluate to the same value in ``env``.

        Only the single environment ``env`` is checked, so two expressions
        that agree here may still differ under other variable bindings.
        """
        return self.eval(env) == other.eval(env)

class Var(Expr):
    """A named variable; its value is looked up in the environment."""

    def __init__(self, name):
        self.name = name

    def eval(self, env):
        """Return the binding for this variable's name (KeyError if unbound)."""
        bound_value = env[self.name]
        return bound_value

    def __str__(self):
        """Render the variable as its bare name."""
        return self.name


class Const(Expr):
    """A literal constant; evaluates to its stored value regardless of env."""

    def __init__(self, value):
        self.value = value

    def eval(self, env):
        """Return the stored constant; ``env`` is ignored."""
        constant = self.value
        return constant

    def __str__(self):
        """Render the constant via ``str`` of its value."""
        return str(self.value)


class BinOp(Expr):  # abstract class
    """Abstract binary-operator node with ``left`` and ``right`` sub-expressions.

    Concrete subclasses provide:
      * ``op(x, y)`` -- combines the two evaluated operand values;
      * ``op_sign``  -- class attribute with the operator text used by ``__str__``.
    """

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def op(self, x, y):
        """Combine two operand values; subclasses must override.

        Raises:
            NotImplementedError: always, when called on the abstract base.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement op(x, y)"
        )

    def eval(self, env):
        """Evaluate both operands under ``env`` (left first) and combine them with ``op``."""
        return self.op(self.left.eval(env), self.right.eval(env))

    def __str__(self):
        """Render as ``<left><op_sign><right>`` with no spaces or parentheses.

        Formatting the operands calls their own ``__str__``, so this recurses
        through the whole tree.
        """
        return f"{self.left}{self.op_sign}{self.right}"


class Plus(BinOp):
    """Addition node; always renders wrapped in parentheses, ``(left+right)``.

    The parentheses keep the printed form unambiguous when a sum is an
    operand of a Times node, whose rendering adds no parentheses of its own.
    """

    op_sign = '+'

    def op(self, x, y):
        """Return the sum of the two evaluated operands."""
        return x + y

    def __str__(self):
        # Reuse the inherited "left+right" rendering via super() rather than
        # naming BinOp explicitly, so the method resolution order is respected.
        return f"({super().__str__()})"


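# Equivalent ways to build the "left+right" string inside Plus.__str__:
#   return str(self.left)+"+"+str(self.right)
#   return self.left.__str__()+"+"+self.right.__str__()
# Either form shows that __str__ is recursive: formatting a node formats its children.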
class Times(BinOp):
    """Multiplication node; renders as ``left*right`` with no added parentheses."""

    op_sign = '*'

    def op(self, x, y):
        """Return the product of the two evaluated operands."""
        product = x * y
        return product


In [16]:
# Sample expression trees (printed form shown on the right).
e1 = Times(Plus(Var("x"), Const(2)), Var("y"))    # (x+2)*y
e2 = Plus(Var("x"), Times(Const(2), Var("y")))    # (x+2*y)
e3 = Plus(
    Times(Var("x"), Var("x")),
    Plus(
        Times(Var("x"), Var("y")),
        Times(Var("y"), Times(Var("y"), Var("y"))),
    ),
)                                                 # (x*x+(x*y+y*y*y))
e4 = Times(Var("x"), Var("y"))                    # x*y
e5 = Times(Var("y"), Var("x"))                    # y*x -- the commuted form of e4

# Variable bindings used when evaluating the expressions above.
env_data = dict(x=2, y=3)

In [17]:
# Show each expression's string form, one per line.
print(e1, e2, e3, e4, e5, sep="\n")

(x+2)*y
(x+2*y)
(x*x+(x*y+y*y*y))
x*y
y*x


In [18]:
e1.checkEquality(other=e2, env=env_data)  # (2+2)*3 = 12 vs 2+2*3 = 8

False

In [19]:
e2.checkEquality(other=e2, env=env_data)  # an expression always equals itself

True

In [20]:
e4.checkEquality(e5, env_data)  # x*y vs y*x: 2*3 == 3*2, multiplication commutes

True