# ECON622: Assignment 5

# Packages

Add whatever packages you wish here.

```python
import numpy as np
import matplotlib.pyplot as plt
import scipy
import pandas as pd
import torch
import jax
import jax.numpy as jnp
from jax import grad, hessian
import torch.nn as nn
import torch.optim as optim
import equinox as eqx
```

# Q1

The trace of the Hessian matrix is useful in a variety of applications
in statistics, econometrics, and stochastic processes. It can also be
used to regularize a loss function.

For a function $f:\mathbb{R}^N\to\mathbb{R}$, denote its Hessian at $x$ by
$\nabla^2 f(x) \in \mathbb{R}^{N\times N}$.

It can be shown that for any random vector $v\in\mathbb{R}^N$ with mean
zero and identity covariance, i.e. $\mathbb{E}(v) = 0$ and
$\mathbb{E}(v v^{\top}) = I$, the trace of the Hessian satisfies

$$
\mathrm{Tr}(\nabla^2 f(x)) = \mathbb{E}\left[v^{\top} \nabla^2 f(x)\, v\right]
$$
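One way to see why the identity holds: $v^{\top} \nabla^2 f(x)\, v$ is a scalar, so it equals its own trace, and the cyclic property of the trace together with linearity of expectation gives

$$
\mathbb{E}\left[v^{\top} \nabla^2 f(x)\, v\right] = \mathbb{E}\left[\mathrm{Tr}\left(\nabla^2 f(x)\, v v^{\top}\right)\right] = \mathrm{Tr}\left(\nabla^2 f(x)\, \mathbb{E}\left[v v^{\top}\right]\right) = \mathrm{Tr}(\nabla^2 f(x))
$$

Only $\mathbb{E}(v v^{\top}) = I$ is used in this step.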

This leads to a randomized algorithm, called the [Hutchinson Trace
Estimator](https://www.tandfonline.com/doi/abs/10.1080/03610918908812806):
sample $M$ vectors $v_1,\ldots,v_M$ and use the Monte Carlo
approximation of the expectation

$$
\mathrm{Tr}(\nabla^2 f(x)) \approx \frac{1}{M} \sum_{m=1}^M v_m^{\top} \nabla^2 f(x)\, v_m
$$
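To see the estimator in isolation, here is a minimal sketch that applies it to an explicitly stored matrix, using Rademacher probe vectors (entries $\pm 1$ with equal probability, which satisfy $\mathbb{E}(v) = 0$ and $\mathbb{E}(v v^{\top}) = I$). The function name `hutchinson_trace_dense` is hypothetical, and storing the matrix explicitly is exactly what a Hessian-vector product version avoids.

```python
import jax
import jax.numpy as jnp

def hutchinson_trace_dense(key, A, M):
    """Estimate Tr(A) from M Rademacher probe vectors (A stored explicitly)."""
    N = A.shape[0]
    # Rademacher entries (+1 or -1, equally likely): E[v] = 0 and E[v v^T] = I
    V = 2.0 * jax.random.bernoulli(key, 0.5, (M, N)).astype(A.dtype) - 1.0
    # Row m of V is v_m; einsum computes v_m^T A v_m for every m at once
    quad_forms = jnp.einsum("mi,ij,mj->m", V, A, V)
    return jnp.mean(quad_forms)

key_A, key_v = jax.random.split(jax.random.PRNGKey(1))
A = jax.random.normal(key_A, (10, 10))
P = A.T @ A  # symmetric test matrix
print(hutchinson_trace_dense(key_v, P, 20_000), jnp.trace(P))
```

For a diagonal matrix, Rademacher probes return the exact trace for any $M$ because $v_i^2 = 1$; all of the Monte Carlo error comes from the off-diagonal entries.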

## Q1.1

Now, let's take the function $f(x) = \frac{1}{2}x^{\top} P x$, a
quadratic form for which we know that $\nabla^2 f(x) = P$ when $P$ is
symmetric.

The following code computes the trace of the Hessian, which for this
simple function is just the sum of the diagonal of $P$.

```python
key = jax.random.PRNGKey(0)

N = 100  # Dimension of the matrix
A = jax.random.normal(key, (N, N))
# Create a symmetric positive-definite matrix P by forming A^T A
P = jnp.dot(A.T, A)

def f(x):
    return 0.5 * jnp.dot(x.T, jnp.dot(P, x))

# Draw x with a fresh subkey instead of reusing `key`
# (the Hessian of this quadratic does not depend on x)
x = jax.random.normal(jax.random.split(key)[1], (N,))
print(jnp.trace(jax.hessian(f)(x)))
print(jnp.diag(P).sum())
```

```
10223.29
10223.289
```

Now, instead of calculating the whole Hessian, use a [Hessian-vector
product in
JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#hessian-vector-products-using-both-forward-and-reverse-mode)
and the approximation above with $M$ draws of random vectors to
calculate an approximation of the trace of the Hessian. Increase $M$ to
see how the variance of the estimator changes, comparing against the
closed-form solution above for this quadratic.
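For reference when comparing draws, the per-probe variance has a standard closed form for symmetric $A = \nabla^2 f(x)$ under two common probe distributions (Gaussian $v \sim \mathcal{N}(0, I)$, and Rademacher $v$ with independent entries $\pm 1$):

$$
\mathrm{Var}\left(v^{\top} A v\right) = 2\,\|A\|_F^2 \quad \text{(Gaussian)}, \qquad
\mathrm{Var}\left(v^{\top} A v\right) = 2\left(\|A\|_F^2 - \sum_{i=1}^N A_{ii}^2\right) \quad \text{(Rademacher)}
$$

The variance of the $M$-sample average is these quantities divided by $M$, so its standard error shrinks like $1/\sqrt{M}$.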

Hint: you will want to use forward-over-reverse mode differentiation for
this. The `vjp` gives a pullback function for the first derivative; then
differentiate that new function. Since it maps
$\mathbb{R}^N \to \mathbb{R}^N$, it makes sense to use forward mode with
a `jvp`.
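A minimal sketch of the forward-over-reverse pattern from the hint, assuming a scalar-valued `f`: `jax.vjp` supplies the pullback, seeding it with $1$ gives the gradient, and `jax.jvp` of that gradient map in direction $v$ gives $\nabla^2 f(x)\, v$ without forming the Hessian. The helper names `grad_via_vjp` and `hvp` and the toy function $f(x) = \sum_i \sin(x_i)$ are illustrative choices, not part of the assignment.

```python
import jax
import jax.numpy as jnp

def grad_via_vjp(f):
    """Return x -> grad f(x), built from the reverse-mode pullback."""
    def g(x):
        y, pullback = jax.vjp(f, x)
        # Seeding the pullback with 1.0 (f is scalar-valued) yields the gradient
        (grad_x,) = pullback(jnp.ones_like(y))
        return grad_x
    return g

def hvp(f, x, v):
    """Forward-over-reverse Hessian-vector product H(x) v."""
    # jvp of the R^N -> R^N gradient map in direction v
    _, Hv = jax.jvp(grad_via_vjp(f), (x,), (v,))
    return Hv

# Toy check: f(x) = sum(sin(x)) has Hessian diag(-sin(x))
f = lambda x: jnp.sum(jnp.sin(x))
x = jnp.linspace(0.0, 1.0, 5)
v = jnp.arange(1.0, 6.0)
print(hvp(f, x, v))
```

The inner step is equivalent to `jax.grad(f)`; the autodiff cookbook linked above writes the same product as `jax.jvp(jax.grad(f), (x,), (v,))[1]`.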

```python
# ADD CODE HERE
```

## Q1.2 BONUS

If you wish, you can play around with radically increasing the size of
`N` and changing the function itself. One suggestion is to move towards
a sparse or even matrix-free $f(x)$ calculation so that $P$ itself never
needs to be materialized.
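One possible matrix-free setup, sketched under assumptions: the hypothetical quadratic $f(x) = \frac{1}{2}\sum_{i=1}^{N-1}(x_{i+1} - x_i)^2 + \frac{\alpha}{2}\|x\|^2$ has a tridiagonal Hessian with diagonal entries $1 + \alpha$ at the two ends and $2 + \alpha$ in the interior, so its trace $2(N-1) + \alpha N$ is known in closed form even though the Hessian is never built. The value $\alpha = 0.5$, the size $N = 100{,}000$, and all function names are illustrative.

```python
import jax
import jax.numpy as jnp

ALPHA = 0.5  # hypothetical ridge weight, for illustration only

def f_matrix_free(x):
    # 0.5 * sum of squared neighbour differences + 0.5 * ALPHA * ||x||^2;
    # its Hessian is tridiagonal but is never materialized
    return 0.5 * jnp.sum(jnp.diff(x) ** 2) + 0.5 * ALPHA * jnp.sum(x**2)

def exact_trace(N):
    # Hessian diagonal: 1 + ALPHA at both ends, 2 + ALPHA in the interior
    return 2.0 * (N - 1) + ALPHA * N

def hvp(f, x, v):
    # Forward-over-reverse Hessian-vector product
    return jax.jvp(jax.grad(f), (x,), (v,))[1]

def hutchinson_trace(f, x, key, M):
    # Rademacher probes: entries +/-1, so E[v] = 0 and E[v v^T] = I
    V = 2.0 * jax.random.bernoulli(key, 0.5, (M, x.shape[0])).astype(x.dtype) - 1.0
    HV = jax.vmap(lambda v: hvp(f, x, v))(V)  # M Hessian-vector products
    return jnp.mean(jnp.sum(V * HV, axis=1))  # average of v_m^T H v_m

N = 100_000
x = jnp.zeros(N)  # the Hessian of a quadratic does not depend on x
est = hutchinson_trace(f_matrix_free, x, jax.random.PRNGKey(0), M=20)
print(float(est), exact_trace(N))
```

Because each probe costs one Hessian-vector product, memory use grows like $O(MN)$ rather than the $O(N^2)$ needed to store the Hessian.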

```python
# ADD CODE HERE
```

# Q2

This section gives some hints on how to set up a differentiable
likelihood function that depends on implicitly defined functions.

## Q2.1

The following code uses SciPy to find the equilibrium price and quantity
for some simple supply and demand functions with embedded parameters.

```python
from scipy.optimize import root_scalar

# Define the demand function with power c_d
def demand(P, c_d):
    return 100 - 2 * P**c_d

# Define the supply function with exponent parameter c_s
def supply(P, c_s):
    return 5 * 3**(c_s * P)

# Excess demand, whose root in P is the equilibrium price given c_d and c_s
def equilibrium(P, c_d, c_s):
    return demand(P, c_d) - supply(P, c_s)

# Use root_scalar (Brent's method on the bracket [0, 100]) to find the equilibrium price
def find_equilibrium(c_d, c_s):
    result = root_scalar(equilibrium, args=(c_d, c_s), bracket=[0, 100], method='brentq')
    return result.root, demand(result.root, c_d)

# Example usage
c_d = 0.5
c_s = 0.15
equilibrium_price, equilibrium_quantity = find_equilibrium(c_d, c_s)
print(f"Equilibrium Price: {equilibrium_price:.2f}")
print(f"Equilibrium Quantity: {equilibrium_quantity:.2f}")
```

```
Equilibrium Price: 17.65
Equilibrium Quantity: 91.60
```

First, convert this to use JAX and one of the JAX packages for finding
the root (e.g., in [JAXopt](https://jaxopt.github.io/stable/)). Make
sure you can jit the whole `find_equilibrium` function.
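A minimal sketch of one way to do this, assuming only core JAX: instead of a JAXopt solver, it swaps in `jax.lax.custom_root` with a fixed-iteration bisection on the same bracket $[0, 100]$ as the SciPy code. `custom_root` differentiates the root implicitly with respect to the closed-over `c_d` and `c_s`, so gradients do not flow through the bisection steps. The helper `bisection_solve` and the iteration count are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

def demand(P, c_d):
    return 100 - 2 * P**c_d

def supply(P, c_s):
    return 5 * 3**(c_s * P)

def bisection_solve(g, P0, lower=0.0, upper=100.0, num_iter=60):
    # Bracketing solver on [lower, upper]; P0 is unused.
    # Assumes g(lower) > 0 > g(upper), i.e. excess demand falls as P rises.
    del P0
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        root_above = g(mid) > 0
        return jnp.where(root_above, mid, lo), jnp.where(root_above, hi, mid)
    lo, hi = jax.lax.fori_loop(0, num_iter, body, (jnp.float32(lower), jnp.float32(upper)))
    return 0.5 * (lo + hi)

@jax.jit
def find_equilibrium(c_d, c_s):
    # Excess demand; c_d and c_s are closed over, and custom_root differentiates
    # the root with respect to them via the implicit function theorem
    excess = lambda P: demand(P, c_d) - supply(P, c_s)
    P_star = jax.lax.custom_root(
        excess,
        jnp.float32(0.0),             # initial guess (ignored by bisection)
        bisection_solve,
        lambda lin, y: y / lin(1.0),  # scalar tangent solve: y / (d excess / dP)
    )
    return P_star, demand(P_star, c_d)

equilibrium_price, equilibrium_quantity = find_equilibrium(0.5, 0.15)
print(f"Equilibrium Price: {float(equilibrium_price):.2f}")
print(f"Equilibrium Quantity: {float(equilibrium_quantity):.2f}")
```

A fixed-length `fori_loop` keeps the whole function jit-compatible. The bracket assumes excess demand is positive at $P = 0$ and negative at $P = 100$, which holds for these parameter values but should be checked for others.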

```python
# ADD CODE HERE
```

## Q2.2

Now, assume that you observe a noisy signal of the equilibrium price
that clears this supply and demand system:

$$
\hat{p} \sim \mathcal{N}(p, \sigma^2)
$$

In that case, the log likelihood for the Gaussian is

$$
\log \mathcal{L}(\hat{p}\,|\,c_d, c_s, p) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p)^2
$$

Or, if $p$ was implicitly defined by the equilibrium conditions as some
$p(c_d, c_s)$ from above,

$$
\log \mathcal{L}(\hat{p}\,|\,c_d, c_s) = -\frac{1}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} (\hat{p} - p(c_d, c_s))^2
$$
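To see what `grad` has to compute here, write the equilibrium condition as $g(p, c_d, c_s) \equiv 100 - 2p^{c_d} - 5\cdot 3^{c_s p} = 0$. The chain rule and the implicit function theorem then give, for $j \in \{d, s\}$,

$$
\frac{\partial \log \mathcal{L}}{\partial c_j} = \frac{\hat{p} - p}{\sigma^2}\,\frac{\partial p}{\partial c_j},
\qquad
\frac{\partial p}{\partial c_j} = -\frac{\partial g / \partial c_j}{\partial g / \partial p}
$$

with

$$
\frac{\partial g}{\partial p} = -2 c_d p^{c_d - 1} - 5 c_s \log(3)\, 3^{c_s p},
\qquad
\frac{\partial g}{\partial c_d} = -2 p^{c_d} \log p,
\qquad
\frac{\partial g}{\partial c_s} = -5 \log(3)\, p\, 3^{c_s p}
$$

all evaluated at the equilibrium $p = p(c_d, c_s)$. Root finders that support implicit differentiation use this relationship instead of differentiating through the solver's iterations.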

Then, for $\sigma = 0.01$, we can calculate this log likelihood using
`find_equilibrium` from above:

```python
def log_likelihood(p_hat, c_d, c_s, sigma):
    p, x = find_equilibrium(c_d, c_s)
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * (p_hat - p)**2 / sigma**2

c_d = 0.5
c_s = 0.15
sigma = 0.01
p, x = find_equilibrium(c_d, c_s)  # get the true value for simulation
p_hat = p + np.random.normal(0, sigma)  # simulate a noisy signal
log_likelihood(p_hat, c_d, c_s, sigma)
```

```
3.46939195899786
```

Now, take this code for the likelihood, convert it to JAX, and jit it.
Use your function from Q2.1.

```python
# ADD CODE HERE
```

## Q2.3

Use the function from the previous part to calculate the gradient with
respect to `c_d` and `c_s` using JAX's `grad`. You will probably want to
put `c_d` and `c_s` into a vector as the first argument, or play around
with passing in a dictionary and using PyTrees.
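A self-contained sketch covering the likelihood and its gradient together. It assumes the equilibrium price comes from `jax.lax.custom_root` with a bisection on $[0, 100]$ (a swapped-in technique, not the JAXopt solver the question suggests), and passes `c_d`, `c_s` as a dictionary so that `grad` returns a dictionary with the same keys. The fixed offset used for `p_hat` is a hypothetical stand-in for a random draw, chosen so the result is reproducible.

```python
import jax
import jax.numpy as jnp

def demand(P, c_d):
    return 100 - 2 * P**c_d

def supply(P, c_s):
    return 5 * 3**(c_s * P)

def bisection_solve(g, P0, lower=0.0, upper=100.0, num_iter=60):
    # Bracketing solver; P0 is unused. Assumes g(lower) > 0 > g(upper).
    del P0
    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        root_above = g(mid) > 0  # excess demand falls as P rises
        return jnp.where(root_above, mid, lo), jnp.where(root_above, hi, mid)
    lo, hi = jax.lax.fori_loop(0, num_iter, body, (jnp.float32(lower), jnp.float32(upper)))
    return 0.5 * (lo + hi)

def equilibrium_price(c_d, c_s):
    excess = lambda P: demand(P, c_d) - supply(P, c_s)
    # Implicit differentiation with respect to the closed-over c_d, c_s
    return jax.lax.custom_root(excess, jnp.float32(0.0), bisection_solve,
                               lambda lin, y: y / lin(1.0))

@jax.jit
def log_likelihood(params, p_hat, sigma):
    # params is a dict (a PyTree), so grad returns a dict with the same keys
    p = equilibrium_price(params["c_d"], params["c_s"])
    return -0.5 * jnp.log(2 * jnp.pi * sigma**2) - 0.5 * (p_hat - p) ** 2 / sigma**2

params = {"c_d": 0.5, "c_s": 0.15}
sigma = 0.01
p_true = equilibrium_price(params["c_d"], params["c_s"])
p_hat = p_true + 0.01  # hypothetical fixed noise draw instead of a random one
print(log_likelihood(params, p_hat, sigma))
grads = jax.grad(log_likelihood)(params, p_hat, sigma)
print(grads)
```

One way to check the result is to compare each entry with $\frac{\hat{p} - p}{\sigma^2}\left(-\frac{\partial g/\partial c_j}{\partial g/\partial p}\right)$ computed by hand, where $g$ is excess demand.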

```python
# ADD CODE HERE
```

## Q2.4 BONUS

You could try to run maximum likelihood estimation by using a
gradient-based optimizer in JAX (e.g., ). Typically you will want to use
[JAXopt](https://jaxopt.github.io/stable/) for this instead of the more
ML-centric optimizers.

If you attempt this:

-   Consider starting your optimization at the “pseudo-true” values,
    i.e. the `c_s, c_d, sigma` you used to simulate the data, and even
    start with `p_hat = p`.
-   You may find that it is a little too noisy with only one
    observation. If so, you could adapt your likelihood to take a vector
    of $\hat{p}$ instead. The likelihood of IID Gaussians is a simple
    variation on the above.
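For $n$ IID observations $\hat{p}_1, \ldots, \hat{p}_n$ of the same equilibrium price, the log likelihood is the sum of the individual Gaussian terms:

$$
\log \mathcal{L}(\hat{p}_1, \ldots, \hat{p}_n\,|\,c_d, c_s) = -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n \left(\hat{p}_i - p(c_d, c_s)\right)^2
$$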

```python
# ADD CODE HERE
```

# Q3

For the LLS examples with PyTorch, we added logging in
[linear_regression_pytorch_logging.py](https://github.com/ubcecon/ECON622/blob/master/lectures/lectures/examples/linear_regression_pytorch_logging.py)
and a CLI interface, which came for free with PyTorch Lightning.

In this question, you will add some of those features to the
[linear_regression_jax_equinox.py](https://github.com/ubcecon/ECON622/blob/master/lectures/lectures/examples/linear_regression_jax_equinox.py)
example.

## Q3.1

Take the `linear_regression_jax_equinox.py` code, copied below for your
convenience, and:

1.  Set up W&B properly
2.  Add logging of the `train_loss` at every step of the optimizer
3.  Remove the other epoch printing, or try to log an epoch-specific
    `||theta - theta_hat||` if you wish
4.  Log the final `||theta - theta_hat||` at the end of training

```python
# MODIFY CODE HERE
import jax
import jax.numpy as jnp
from jax import grad, jit, value_and_grad, vmap
from jax import random
import optax
import equinox as eqx

N = 500  # samples
M = 2
sigma = 0.001
key = random.PRNGKey(42)
# Pattern: split before using key, replace name "key"
key, *subkey = random.split(key, num=4)
theta = random.normal(subkey[0], (M,))
X = random.normal(subkey[1], (N, M))
Y = X @ theta + sigma * random.normal(subkey[2], (N,))  # Adding noise

# Creates a generator over shuffled mini-batches
def data_loader(key, X, Y, batch_size):
    num_samples = X.shape[0]
    assert num_samples == Y.shape[0]
    indices = jnp.arange(num_samples)
    indices = random.permutation(key, indices)
    # Loop over batches and yield
    for i in range(0, num_samples, batch_size):
        batch_indices = indices[i:i + batch_size]
        yield X[batch_indices], Y[batch_indices]


# Draw a random theta_0 for reference (the model below has its own random initialization)
key, subkey = random.split(key)
theta_0 = random.normal(subkey, (M,))
print(f"theta_0 = {theta_0}, theta = {theta}")

# Probably a way to use `vmap` or `eqx.filter_vmap` here as well

def residual(model, x, y):
    y_hat = model(x)
    return (y_hat - y) ** 2

def residuals(model, X, Y):
    batched_residuals = vmap(residual, in_axes=(None, 0, 0))
    return jnp.mean(batched_residuals(model, X, Y))

# Alternatively could do something like
def residuals_2(model, X, Y):
    Y_hat = vmap(model)(X).squeeze()
    return jnp.mean((Y - Y_hat) ** 2)

# Hypothesis class: a linear function without bias, randomly initialized.
# `model` holds all parameters and supports model(x) calls
key, subkey = random.split(key)
model = eqx.nn.Linear(M, 1, use_bias=False, key=subkey)

# Optimizer setup
lr = 0.001
optimizer = optax.sgd(lr)
# Filter out the non-differentiable parts of the "model" object
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_jit
def make_step(model, opt_state, X, Y):
    loss_value, grads = eqx.filter_value_and_grad(residuals)(model, X, Y)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value

num_epochs = 300
batch_size = 64
key, subkey = random.split(key)  # subkey unused here; a fresh one is drawn each epoch below
for epoch in range(num_epochs):
    key, subkey = random.split(key)  # changing key for shuffling each epoch
    train_loader = data_loader(subkey, X, Y, batch_size)
    for X_batch, Y_batch in train_loader:
        model, opt_state, train_loss = make_step(model, opt_state, X_batch, Y_batch)
        # TODO ADD IN LOGGING OF THE train_loss

    # TODO CAN REMOVE THIS ENTIRELY AFTER LOGGING IS WORKING
    if epoch % 100 == 0:
        print(f"Epoch {epoch},||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")

# TODO: LOG THE FINAL VALUE HERE
print(f"||theta - theta_hat|| = {jnp.linalg.norm(theta - model.weight)}")
```

```
theta_0 = [ 1.7535115  -0.07298409], theta = [0.1378821  0.79073715]
Epoch 0,||theta - theta_hat|| = 0.4085429608821869
Epoch 100,||theta - theta_hat|| = 0.0743650421500206
Epoch 200,||theta - theta_hat|| = 0.01363404467701912
||theta - theta_hat|| = 0.002546713687479496
```

## Q3.2

Now, take the above code and copy it into a file named
`linear_regression_jax_cli.py`.

Feel free to use the builtin
[Argparse](https://docs.python.org/3/library/argparse.html) or any other
[CLI
framework](https://github.com/shadawck/awesome-cli-frameworks#python).

Regardless of how you do it, here are some suggested steps:

1.  Create a function called
    `main_fn(lr: float = 0.001, N: int = 100, ...)` with whatever
    parameters you want to change as arguments. The type annotations are
    optional but useful for some CLI packages.
2.  Move all of your code inside that function, and remove the
    module-level initialization of those values.
3.  You can test it out by adding the following code and then running
    the file

<!-- -->

    if __name__ == '__main__':
      main_fn()
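If you would rather stay with the builtin argparse, a minimal sketch might look like this. The body of `main_fn` is a placeholder for the Q3.1 training code, and the `--num_epochs` flag and `parse_args` helper are illustrative additions.

```python
import argparse

def main_fn(lr: float = 0.001, N: int = 100, num_epochs: int = 300):
    # Placeholder body: the Q3.1 training code would go here,
    # using these arguments instead of module-level constants
    print(f"lr = {lr}, N = {N}, num_epochs = {num_epochs}")
    return {"lr": lr, "N": N, "num_epochs": num_epochs}

def parse_args(argv=None):
    parser = argparse.ArgumentParser(description="Linear regression with JAX")
    parser.add_argument("--lr", type=float, default=0.001, help="SGD learning rate")
    parser.add_argument("--N", type=int, default=100, help="number of samples")
    parser.add_argument("--num_epochs", type=int, default=300, help="training epochs")
    return parser.parse_args(argv)  # argv=None reads sys.argv[1:]
```

In the script itself you would call `main_fn(**vars(parse_args()))` inside the `if __name__ == '__main__':` guard; argparse accepts both `--N=200` and `--N 200`, and converts the values using the `type=` given for each argument.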

To make this CLI-ready, here are two suggested packages:

### jsonargparse

A package with many features, most of which you wouldn’t use directly,
is [jsonargparse](https://jsonargparse.readthedocs.io/). Besides the
more advanced features like [configuration
files](https://jsonargparse.readthedocs.io/en/v4.26.2/#writing-configuration-files)
and the instantiation of classes as arguments, the main difference is
that it checks the types of arguments and converts them for you using
Python type hints. In that case, you can adapt the following code for
your `linear_regression_jax_cli.py`:

    import jsonargparse
    def main_fn(lr: float = 0.001, N: int = 100):
        print(f"lr = {lr}, N = {N}")

    if __name__ == "__main__":
        jsonargparse.CLI(main_fn)

If you get this working, you may want to consider trying out the
configuration file feature, which is a nice way to save your
hyperparameters for later use and reproducibility.

### Python Fire

[python-fire](https://github.com/google/python-fire) is already in your
`requirements.txt`. See the
[documentation](https://google.github.io/python-fire/guide/) for more.
It is a very lightweight package compared to some of the alternatives.

You can adapt the following code:

    import fire
    def main_fn(lr: float = 0.001, N: int = 100):
        print(f"lr = {lr}, N = {N}")

    if __name__ == '__main__':
      fire.Fire(main_fn)

In this case, however, the `float` and `int` type hints do not seem to
be enforced by Python Fire, so you may need to cast the arguments
directly in your code if things aren’t working correctly.

### Using your CLI

In either case, at that point you should be able to call this with
`python linear_regression_jax_cli.py` and have it use all of the default
values, `python linear_regression_jax_cli.py --N=200` to change them,
etc.

Either submit the file as part of the assignment or paste the code into
the notebook.

## Q3.3 BONUS

Given the CLI you can now run a hyperparameter search. For this bonus
problem, do a hyperparameter search over the `--lr` argument by
following the [W&B documentation](https://docs.wandb.ai/guides/sweeps).

To get you started, your sweep YAML might look something like this:

``` {yaml}
program: linear_regression_jax_cli.py
name: JAX Example
project: linear_regression_pytorch
description: JAX Sweep
method: random
parameters:
  lr:
    min: 0.0001
    max: 0.01
```

Here I changed the `method` from bayes to
[`random`](https://docs.wandb.ai/guides/sweeps/define-sweep-configuration#method)
because otherwise we would need to provide a `metric` to optimize over.
Feel free to adapt any of these settings.

If you successfully run a sweep then paste in your own yaml file here,
and a screenshot of the W&B dashboard showing something about the sweep
results.