Before you turn this problem in, make sure everything runs as expected. First, **restart the kernel** (in the menubar, select Kernel$\rightarrow$Restart) and then **run all cells** (in the menubar, select Cell$\rightarrow$Run All).

Make sure you fill in any place that says `YOUR CODE HERE` or "YOUR ANSWER HERE", as well as your name and collaborators below:

In [None]:
NAME = "Ruchit Patel"
COLLABORATORS = ""

---

# Homework 7 Resubmission: Expression Trees

Copyright Luca de Alfaro, 2019-20. 
License: [CC-BY-NC-ND](https://creativecommons.org/licenses/by-nc-nd/4.0/).

## Submission

[Please submit to this Google Form](https://docs.google.com/forms/d/e/1FAIpQLSe99_Vjy0uHT_PeKcl4dWu1A4KKkAUS1JBhXb679resr0JpNA/viewform?usp=sf_link).

Deadline: Monday November 16, 11pm (check on Canvas for updated information).

## The assignment

There are three questions in this assignment, each labeled `Question n:`, and each followed by one or more test cells:

* A question on printing out expressions in LaTeX, 
* A question on derivatives,
* A question on checking expression equality.

We will develop a data structure to represent arithmetic expressions containing variables, such as $3 + 4$ or $2 + x * (1 - y)$.  

What is an expression?  An expression consists of one of these: 


1. A number
2. A variable
3. If $e_1$ and $e_2$ are expressions, then $e_1 + e_2$, $e_1 - e_2$, $e_1 * e_2$, and $e_1 / e_2$ are also expressions. 

Formally, the set of expressions is the _least_ set constructed according to the rules above. 

Thus, an expression is either a leaf (a number or a variable) or a composite expression, consisting of an operator, a left expression, and a right expression.  


There are (at least) two ways of representing expressions. The simplest way is to represent expressions as trees, and define operations on them. 
The more sophisticated way consists in representing expressions via classes: there will be one class for variables and constants, and one class representing composite expressions; both of these classes will be subclasses of a generic "expression" class. 

In this chapter, we will represent expressions as trees, to gain experience with writing recursive functions on trees; in the next chapter, we will show how to represent them more elegantly as classes.

We will represent expressions as trees.  A number will be represented via a number; a variable via a string, and the expression $e_1 \odot e_2$ via the tuple $(\odot, e_1, e_2)$, for $\odot \in \{+, -, *, / \}$.

For example, we will represent $2 * (x + 1)$ via:

    ('*', 2, ('+', 'x', 1))

In [None]:
e = ('*', 2, ('+', 'x', 1))
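
The inductive definition above translates almost word for word into a recursive check on the tuple representation. The following is only a sketch: the names `is_expression` and `OPERATORS` are hypothetical helpers introduced here for illustration, and are not part of the assignment.

```python
from numbers import Number

# Hypothetical helper names, introduced only for this illustration.
OPERATORS = {"+", "-", "*", "/"}

def is_expression(e):
    """Return True if e is a well-formed expression tree."""
    # Rules 1 and 2: a number or a variable name is an expression.
    if isinstance(e, (Number, str)):
        return True
    # Rule 3: a triple (op, left, right) whose parts are expressions.
    return (isinstance(e, tuple) and len(e) == 3
            and isinstance(e[0], str) and e[0] in OPERATORS
            and is_expression(e[1]) and is_expression(e[2]))

print(is_expression(('*', 2, ('+', 'x', 1))))  # True
print(is_expression(('^', 2, 'x')))            # False: unknown operator
```

The same pattern (handle the leaves first, then recurse on the left and right subtrees) is the one used by the functions in the rest of this notebook.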


Let us define a check function, in preparation for our first question.

In [None]:
def check_equal(x, y, msg=None):
    if x != y:
        if msg is None:
            print("Error:")
        else:
            print("Error in", msg, ":")
        print("    Your answer was:", x)
        print("    Correct answer: ", y)
    assert x == y, "%r and %r are different" % (x, y)
    print("Success")

## Question 1: Printing an expression in LaTeX format

Our first question asks you to write a function `expr_to_latex` so that, for an expression `e`, `expr_to_latex(e)` will produce a string which is the LaTeX representation of the expression `e`. 
This sounds complicated, but the rules are simple: 

* Numbers and variable names are simply converted to strings (you can use str()). 
* If e = (op, e1, e2), then you compute the strings s1 and s2 representing e1 and e2 via a recursive call.
* If op is one of `"+"`, `"-"`, `"*"`, you proceed as follows: 
    
    * Enclose s1 in parentheses if: 
        
        e1 is not a leaf, and its operator is not `"/"`
        
      To enclose s1 in parentheses, do: `s1 = "(" + s1 + ")"`;
    * Same for e2;
    * Finally, you output s1 + op + s2, where op is one of `+`, `-`, `*`. 
* If you have e1 / e2, you output `\frac{s1}{s2}`.

To test whether parentheses are needed, for an expression e1, you can use this code: 

    if isinstance(e1, tuple) and e1[0] != '/':

Some examples will help: 

If the expression is

    ("+", 3, "x")
    
the LaTeX output is simply `"3+x"`.  If the expression is 

    ("+", 3, ("*", 2, "x"))

you should output 

    "3+(2*x)"
    
where the parentheses have been added because the expression `("*", 2, "x")` is not a leaf. 
If the expression is 

    ("/", ("+", 3, ("*", 2, "x")), ("+", 1, "x"))
    
you should output: 

    \frac{3+(2*x)}{1+x}
    

One word of caution about Python strings.  To include a `\` in a string, you have to escape it, as in `"\\"`. 
So to generate the string `\frac`, you have to actually write `"\\frac"` in your code. 
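
As a quick illustration of the point above, the snippet below compares the escaped form with a raw string literal, which is another standard Python way to write the same string. The variable names are arbitrary.

```python
escaped = "\\frac"   # the backslash is escaped
raw = r"\frac"       # raw string literal: backslashes are taken literally

print(escaped == raw)  # True
print(len(escaped))    # 5
print(escaped)         # \frac
```

Either form works; the tests in this notebook use the escaped form.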

In [None]:
### Question 1: define expr_to_latex

def expr_to_latex(e):
    # YOUR CODE HERE
    # Leaves (numbers and variable names) are converted to strings directly.
    if not isinstance(e, tuple):
        return str(e)
    op, e1, e2 = e
    s1, s2 = expr_to_latex(e1), expr_to_latex(e2)
    if op == "/":
        return "\\frac{" + s1 + "}{" + s2 + "}"
    # Parenthesize composite operands, except fractions, which are
    # already delimited by the braces of \frac.
    if isinstance(e1, tuple) and e1[0] != "/":
        s1 = "(" + s1 + ")"
    if isinstance(e2, tuple) and e2[0] != "/":
        s2 = "(" + s2 + ")"
    return s1 + op + s2

In [None]:
# You may want to write in this cell some tests that help you debug your 
# implementation, beyond the tests that are given already.

# YOUR CODE HERE

In [None]:
### Question 1, part 1: 5 points: tests for leaves. 

check_equal(expr_to_latex(3), "3")
check_equal(expr_to_latex("x"), "x")
check_equal(expr_to_latex(3.1), "3.1")
# Hey, I never said that a variable can't be called \foo! 
check_equal(expr_to_latex("\\foo"), "\\foo")


In [None]:
### Question 1, part 2: 5 points: simple expressions. 
e = ("+", 3, "x")
check_equal(expr_to_latex(e), "3+x")
e = ("-", "y", "x")
check_equal(expr_to_latex(e), "y-x")
e = ("*", "y", "2")
check_equal(expr_to_latex(e), "y*2")
e = ("/", "y", "2")
check_equal(expr_to_latex(e), "\\frac{y}{2}")


In [None]:
### Question 1, part 3: 5 points: when to use parentheses.
e = ("*", ("+", 1, "x"), ("-", 3, "y"))
check_equal(expr_to_latex(e), "(1+x)*(3-y)")
e = ("*", ("+", 1, "x"), "z")
check_equal(expr_to_latex(e), "(1+x)*z")
e = ("*", ("-", "\\foo", "hello"), ("+", "hey", "\\dude"))
check_equal(expr_to_latex(e), "(\\foo-hello)*(hey+\\dude)")


In [None]:
### Question 1, part 4, 10 points: general expressions.

e = ("*", ("+", 1, "x"), ("-", ("*", "x", "y"), "y"))
check_equal(expr_to_latex(e), "(1+x)*((x*y)-y)")
e = ("*", ("/", ("*", 3, "y"), "x"), ("-", ("*", "x", "y"), "y"))
check_equal(expr_to_latex(e), "\\frac{3*y}{x}*((x*y)-y)")
e = ("/", ("/", ("*", 3, "y"), "x"), ("-", ("*", "x", "y"), "y"))
check_equal(expr_to_latex(e), "\\frac{\\frac{3*y}{x}}{(x*y)-y}")
e = ("+", "bread", "water")
check_equal(expr_to_latex(e), "bread+water")
e = ("+", "bread\\sour", "water")
check_equal(expr_to_latex(e), "bread\\sour+water")
# Actually, nothing says that there can't be a variable called \frac! 
e = ("*", ("+", "\\frac", "\\frog"), ("+", 3, 2))
check_equal(expr_to_latex(e), "(\\frac+\\frog)*(3+2)")


Note that with a tiny bit of work, we can really use `expr_to_latex` to output expressions in LaTeX for us: 

In [None]:
from IPython.display import display, Math

def niceprint(e):
    display(Math(expr_to_latex(e)))

In [None]:
niceprint(("/", ("+", 3, ("*", 2, "x")), ("+", 1, "x")))

In [None]:
e = ("*", ("/", ("*", 3, "y"), "x"), ("-", ("*", "x", "y"), "y"))
niceprint(e)

Of course, optimizing the beauty of the output, so that the above is rendered as: 

$$
    \frac{3 + 2x}{1 + x}
$$

would require a tiny bit more work.  For instance, we could decide to include the `*` operation only when both left and right subexpressions are numbers.  We should also then convert $x2$ into $2x$... many optimizations are possible.  But we got the basics done. 
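
To give an idea of what such optimizations could look like, here is a rough sketch. The names `pretty_latex` and `needs_parens` are hypothetical, and the rules are assumptions made for this illustration: parentheses are added only where precedence requires them, `*` is printed only between two numbers, and a number is moved in front of a non-number in a product. The sketch does not handle negative numbers or chains of numeric factors.

```python
from numbers import Number

def needs_parens(child, parent_op, is_right):
    """Decide whether child must be parenthesized under parent_op."""
    # Leaves, products, and fractions never need parentheses here.
    if not isinstance(child, tuple) or child[0] in ("*", "/"):
        return False
    # A sum or difference needs them inside a product, or on the
    # right-hand side of a subtraction.
    return parent_op == "*" or (parent_op == "-" and is_right)

def pretty_latex(e):
    """Hypothetical prettier variant of expr_to_latex (a sketch only)."""
    if not isinstance(e, tuple):
        return str(e)
    op, l, r = e
    if op == "/":
        return "\\frac{" + pretty_latex(l) + "}{" + pretty_latex(r) + "}"
    if op == "*" and isinstance(r, Number) and not isinstance(l, Number):
        l, r = r, l  # write 2x rather than x2
    s1, s2 = pretty_latex(l), pretty_latex(r)
    if needs_parens(l, op, False):
        s1 = "(" + s1 + ")"
    if needs_parens(r, op, True):
        s2 = "(" + s2 + ")"
    if op == "*" and not (isinstance(l, Number) and isinstance(r, Number)):
        # Implicit multiplication; LaTeX math mode ignores the space.
        return s1 + " " + s2
    return s1 + op + s2

print(pretty_latex(("/", ("+", 3, ("*", 2, "x")), ("+", 1, "x"))))
```

Passing the resulting string to `Math(...)` would render it in the same way as the output of `expr_to_latex`.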

### Evaluating expressions with respect to a variable valuation.

Let us define a compute function that evaluates an expression with the help of a _variable valuation_.
The idea is that if we specify values for variables, we can then use those values in computing the value of an expression. 
A _variable valuation_ is a mapping from variables to their values; we can represent it simply as a dictionary associating to each variable a number:

In [None]:
varval = {'x': 3, 'y': 8}



We can compute the value of expressions given a variable valuation as follows: 

In [None]:
from numbers import Number

def compute(e, varval={}):
    if isinstance(e, Number):
        return e
    elif isinstance(e, str):
        v = varval.get(e)
        # If we find a value for e, we return it; otherwise we return e.
        return e if v is None else v
    else:
        op, l, r = e
        # We simplify the left and right subexpressions first.
        ll = compute(l, varval=varval)
        rr = compute(r, varval=varval)
        # And we carry out the operation if we can.
        if isinstance(ll, Number) and isinstance(rr, Number):
            if op == '+':
                return ll + rr
            elif op == '-':
                return ll - rr
            elif op == '*':
                return ll * rr
            elif op == '/' and rr != 0:
                return ll / rr
        # Not simplifiable.
        return (op, ll, rr)


In [None]:
e = ('*', 2, ('+', 'x', ('-', 3, 2)))
print(compute(e))
print(compute(e, varval={'x': 6}))


If we provide the values for only some of the variables, the `compute` function defined above will plug in the values for those variables and perform all computations possible.  Of course, if the expression contains variables for which the valuation does not specify a value, the resulting expression will still contain those variables: it will not be simply a number.  In computer science, evaluating an expression as far as possible using the values for a subset of the variables is known as _partial evaluation_.

In [None]:
e = ('+', ('-', 'y', 3), ('*', 'x', 4))
print(compute(e, varval={'x': 2}))
print(compute(e, varval={'y': 3}))
print(compute(e, varval={'x': 2, 'y': 3}))


## Question 2: Compute symbolic derivatives

As we have symbolic expressions, we can compute their (partial) derivative with respect to any variable.  Given an expression $e$ and a variable $x$, we denote by $\partial e / \partial x$ the partial derivative of $e$ with respect to $x$.  To compute it, we can simply rely on the definition of derivative. 
For leaf nodes in the expression tree:

* For a constant $c$, $\partial c / \partial x = 0$.
* For a variable $y \neq x$,  $\partial y / \partial x = 0$.
* $\partial x / \partial x = 1$.

For operators, we can use:

$$
 \begin{align*}
 \frac{\partial}{\partial x}(f \pm g) & = \frac{\partial f}{\partial x} \pm \frac  {\partial g}{\partial x}, \\[1ex]
 \frac{\partial}{\partial x}(f \cdot g) & = g \cdot \frac{\partial f}{\partial x}  + f \cdot \frac{\partial g}{\partial x}, \\[1ex]
 \frac{\partial}{\partial x}\left(\frac{f}{g}\right) & = \frac{g \cdot \frac
  {\partial f}{\partial x} - f \cdot \frac{\partial g}{\partial x}}{g^2}. 
\end{align*}
$$

This directly suggests how to implement the symbolic computation of derivatives.

Write a function `derivate` that, given an expression $e$ and a variable $x$, returns an expression for $\partial e / \partial x$.  Please, write it according to the above rules, including order of terms in products.  For instance, use

$$ 
\frac{\partial}{\partial x}(f \cdot g) = g \cdot \frac{\partial f}{\partial x}  + f \cdot \frac{\partial g}{\partial x}
$$
rather than 
$$ 
\frac{\partial}{\partial x}(f \cdot g) = \frac{\partial f}{\partial x} \cdot g + f \cdot \frac{\partial g}{\partial x}
$$

While the two expressions are equivalent, our tests (so far!) can only check for _identical_, not _equivalent_, expressions.
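
As a small worked example of the required term order, take $f = 2$ and $g = x$ in the product rule:

$$
\frac{\partial}{\partial x}(2 \cdot x) = x \cdot \frac{\partial 2}{\partial x} + 2 \cdot \frac{\partial x}{\partial x} = x \cdot 0 + 2 \cdot 1
$$

In the tuple representation, the right-hand side is the tree `('+', ('*', 'x', 0), ('*', 2, 1))`. No simplification is performed: the zeros and ones produced by the leaf rules stay in the result.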

In [None]:
### Implement `derivate`

def derivate(e, x):
    """Returns the derivative of e wrt x.
    It can be done in less than 15 lines of code."""
    # YOUR CODE HERE
    if e == x:
        return 1
    elif isinstance(e, str) or isinstance(e, Number):
        return 0
    else:
        op, s1, s2 = e[0], derivate(e[1], x), derivate(e[2], x)
        if op == "+" or op == "-":
            return (op, s1, s2)
        if op == "*":
            return ("+", (op, e[2], s1), (op, e[1], s2))
        if op == "/":
            return ("/", ("-", ("*", e[2], s1), ("*", e[1], s2)), ("*", e[2], e[2]))


In [None]:
# You may want to write in this cell some tests that help you debug your 
# implementation, beyond the tests that are given already.

# YOUR CODE HERE

In [None]:
### Question 2, part 1: 5 points. 
# Base case tests for `derivate`

# First, the basics.
check_equal(derivate(3, 'x'), 0)
check_equal(derivate('y', 'x'), 0)
check_equal(derivate('x', 'x'), 1)
check_equal(derivate("x", "a"), 0)
check_equal(derivate("a", "a"), 1)



In [None]:
### Question 2, part 2: 5 points. 
### Tests for `derivate` for single-operator expressions

check_equal(derivate(('+', 'x', 'x'), 'x'), ('+', 1, 1))
check_equal(derivate(('-', 4, 'x'), 'x'), ('-', 0, 1))
check_equal(derivate(('*', 2, 'x'), 'x'),
            ('+', ('*', 'x', 0), ('*', 2, 1)))
check_equal(derivate(('/', 2, 'x'), 'x'),
            ('/', ('-', ('*', 'x', 0), ('*', 2, 1)), ('*', 'x', 'x')))
check_equal(derivate(('/', 'a', 'b'), 'a'),
            ('/', ('-', ('*', 'b', 1), ('*', 'a', 0)), ('*', 'b', 'b')))
check_equal(derivate(('/', 'a', 'b'), 'b'),
            ('/', ('-', ('*', 'b', 0), ('*', 'a', 1)), ('*', 'b', 'b')))



In [None]:
### Question 2, part 3: 5 points. 
### Tests for `derivate` for composite expressions

e1 = ('*', 'x', 'x')
e2 = ('*', 3, 'x')
num = ('-', e1, e2)
e3 = ('*', 'a', 'x')
den = ('+', e1, e3)
e = ('/', num, den)

f = ('/',
 ('-',
  ('*',
   ('+', ('*', 'x', 'x'), ('*', 'a', 'x')),
   ('-',
    ('+', ('*', 'x', 1), ('*', 'x', 1)),
    ('+', ('*', 'x', 0), ('*', 3, 1)))),
  ('*',
   ('-', ('*', 'x', 'x'), ('*', 3, 'x')),
   ('+',
    ('+', ('*', 'x', 1), ('*', 'x', 1)),
    ('+', ('*', 'x', 0), ('*', 'a', 1))))),
 ('*',
  ('+', ('*', 'x', 'x'), ('*', 'a', 'x')),
  ('+', ('*', 'x', 'x'), ('*', 'a', 'x'))))

check_equal(derivate(e, 'x'), f)


## Question 3: expression equality

We now consider the following problem: given two expressions $e$ and $f$, how can we decide whether they are equal in value, that is, whether they always yield the same value for all values of the variables? 

This _"value equality"_ is a different notion from structural equality (identical trees), which is what our tests have checked so far.  For instance, the two expressions $x + 1$ and $2 * x + 1 - x$ are not structurally equal, but they are equal in value.  

How can we test for value equality of expressions?  There are two ways: the high road one, and the pirate one.  

The high-road approach consists in trying to demonstrate, in some way, that the two expressions are equal.  One way of doing so would be to define a set of [rewriting rules](https://en.wikipedia.org/wiki/Rewriting) for expressions, that try to transform one expression into the other; this would mimic the process often done by hand to show that two expressions are equal.  Another way would be to use theorem provers that can reason about expressions and real numbers, such as [PVS](https://pvs.csl.sri.com).  The problem is that these approaches are a lot of work.  Is there a way to be lazy, and still get the job done? 

There is, it turns out.  Suppose you have two expressions $f, g$ containing variable $x$ only.  The idea is that if $f$ and $g$ are built with the usual operators of algebra, it is exceedingly unlikely for $f$ and $g$ to give the same value for many values of $x$, and yet not be always equal.  This would not be true if our expressions could contain if-then-else statements, but for the operators we defined so far, it holds.  Indeed, one could be more precise, and try to come up with a theorem of the form: 

> If $f$ and $g$ have "zerosity" $n$, and are equal for $n+1$ values of $x$, then they are equal for all values of $x$. 

We could then try to define the "zerosity" of an expression to make this hold: for example, for two polynomials of degree at most $d$, once you show that they are equal for $d+1$ points, they must be equal everywhere ([why?](https://en.wikipedia.org/wiki/Fundamental_theorem_of_algebra)).  But this again would be a smart approach, and we are trying to see if we can solve the problem while being as stupid as possible.  So our idea will simply be: pick 1000 values of $x$ at random; if the two expressions are equal for all the values, then they must be equal everywhere.  This is a somewhat special case of a [Monte Carlo method](https://en.wikipedia.org/wiki/Monte_Carlo_method), a method used to estimate the probability of complex phenomena (where expression equality is our phenomenon).
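
The polynomial claim can be checked on a small case. In the sketch below, `f` and `g` are two arbitrary, different polynomials of degree at most 2, chosen only for illustration; they can agree on at most 2 points.

```python
def f(x):
    return x * x

def g(x):
    return 3 * x - 2

# Integer points in [-5, 5] where the two polynomials agree.
agree = [x for x in range(-5, 6) if f(x) == g(x)]
print(agree)  # [1, 2]
```

The two functions coincide exactly at the roots of $x^2 - 3x + 2 = (x - 1)(x - 2)$, so a randomly chosen sample point would almost surely tell them apart.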

There are only two wrinkles with this.  The first is that an expression can contain many variables, and we have to try value assignments for all of the variables.  This is easy to overcome; we just need some helper function that gives us the set of variables in an expression.  The second wrinkle is: how do we generate the possible value assignments?  How big do these values need to be on average?  According to what probability distribution?  We could dive into a lot of theory and reasoning about how to compute appropriate probability distributions, but since our goal is to be stupid, we will use one of the simplest distributions with infinite domain: the Gaussian one. 
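
A random value assignment of this kind could be generated as sketched below. The helper name `random_valuation` and the standard deviation of 10 are assumptions made for this illustration, not something the assignment prescribes; `random.gauss(mu, sigma)` is the standard-library Gaussian sampler.

```python
import random

def random_valuation(var_names, sigma=10.0):
    """Map each variable name to an independent Gaussian sample.

    Hypothetical helper; the mean 0 and sigma 10 are arbitrary choices.
    """
    return {v: random.gauss(0.0, sigma) for v in var_names}

d = random_valuation({'x', 'y'})
print(sorted(d))  # ['x', 'y']
```

Each variable gets its own sample; giving every variable the same value would make, for instance, $x - y$ look identical to $0$.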

Let us start by writing the function `variables` such that, if `e` is an expression, `variables(e)` is the set of variables that appear in it.

In [None]:
### Exercise: define `variables`

# YOUR CODE HERE
def variables(e):
    """Returns the set of variables appearing in the expression e."""
    if isinstance(e, str):
        # A variable leaf contributes itself, as a one-element set.
        return {e}
    elif isinstance(e, Number):
        return set()
    else:
        # Composite expression: union of the variables of both subtrees.
        return variables(e[1]) | variables(e[2])

In [None]:
# You may want to write in this cell some tests that help you debug your 
# implementation, beyond the tests that are given already.

# YOUR CODE HERE
e7 = ("+", "aa", "aa")
e8 = ("+", "aa", "ab")
e9 = ("*", 2, "aa")

print("e9_vars: ", variables(e9))

In [None]:
### Question 3, part 1: 5 points
### Tests for `variables`

e = ('*', ('+', 'x', 2), ('/', 'x', 'y'))
check_equal(variables(e), {'x', 'y'})
e = 5
check_equal(variables(e), set())



Now write the `value_equality` function for expressions.  The idea is to perform an equality test `num_samples` times.  Each time, you produce a variable assignment (a dictionary) `d` mapping variables to random values, and then you use the `compute(e1, varval=d)` function to evaluate `e1` under that assignment, and similarly for `e2`.  You can then compare the resulting values, up to the tolerance `tolerance`.  All of this can be done in six lines of code (it's ok if you use a few more). 
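
Comparing up to a tolerance matters because floating-point arithmetic is inexact. The snippet below is a small illustration; the helper name `close` is hypothetical.

```python
def close(a, b, tolerance=1e-6):
    """True if a and b differ by at most tolerance."""
    # The absolute value is taken of the difference, and only then
    # compared with the tolerance.
    return abs(a - b) <= tolerance

print(0.1 + 0.2 == 0.3)       # False
print(close(0.1 + 0.2, 0.3))  # True
```

Note where the parentheses go: writing `abs(a - b <= tolerance)` would instead take the absolute value of a boolean.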

In [None]:
### Exercise: implementation of value equality

import random

def value_equality(e1, e2, num_samples=1000, tolerance=1e-6):
    """Return True if the two expressions e1 and e2 are numerically
    equivalent.  Equivalence is tested by generating
    num_samples assignments, and checking that equality holds
    for all of them.  Equality is checked up to tolerance, that is,
    the values of the two expressions have to be closer than tolerance.
    It can be done in less than 10 lines of code."""
    # YOUR CODE HERE
    # Collect the variables of both expressions.  A bare variable name is
    # wrapped in a set, in case `variables` returns it as a string.
    all_vars = set()
    for e in (e1, e2):
        v = variables(e)
        all_vars |= v if isinstance(v, set) else {v}
    for _ in range(num_samples):
        # Every variable gets its own independent Gaussian sample, and both
        # expressions are evaluated under the same assignment.
        d = {x: random.gauss(0, 10) for x in all_vars}
        v1, v2 = compute(e1, varval=d), compute(e2, varval=d)
        # A non-numeric result means the expression could not be fully
        # evaluated (e.g., a division by zero); skip that sample.
        if not (isinstance(v1, Number) and isinstance(v2, Number)):
            continue
        if abs(v1 - v2) > tolerance:
            return False
    return True


In [None]:
# You may want to write in this cell some tests that help you debug your 
# implementation, beyond the tests that are given already.

# YOUR CODE HERE
e7 = ("+", "aa", "aa")
e8 = ("+", "aa", "ab")
e9 = ("*", 2, "aa")
# check_equal(value_equality(e7, e8), False)
check_equal(value_equality(e8, e7), False)

In [None]:
### Tests for value equality

e1 = ('+', ('*', 'x', 1), ('*', 'y', 0))
e2 = 'x'
check_equal(value_equality(e1, e2), True)

e3 = ('/', ('*', 'x', 'x'), ('*', 'x', 1))
check_equal(value_equality(e1, e3), True)

e4 = ('/', 'y', 2)
check_equal(value_equality(e1, e4), False)
check_equal(value_equality(e3, e4), False)

e5 = "4"
e6 = "5"
check_equal(value_equality(e5, e5), True)
check_equal(value_equality(e5, e6), False)

e7 = ("+", "aa", "aa")
e8 = ("+", "aa", "ab")
e9 = ("*", 2, "aa")
check_equal(value_equality(e7, e8), False)
check_equal(value_equality(e8, e7), False)
check_equal(value_equality(e7, e9), True)
check_equal(value_equality(e9, e7), True)

