<a href="https://colab.research.google.com/github/yifdai/PM-520-repo/blob/main/HW/PM520_HW2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Homework 2. Maximum likelihood & Optimization Crash Course

In [None]:
!pip install lineax

In [None]:
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.scipy as jsp
import jax.scipy.linalg as jspla

## 1. Ordinary least squares (i.e. OLS)
OLS is an approach to fit a linear regression model $$y = X \beta + ɛ,$$
such that $\mathbb{E}[ɛ'ɛ]$ is minimized, where $\mathbb{E}[ɛ_i]=0$ and
$\mathbb{V}[ɛ_i] = \sigma^2$, for each $i=1,\dotsc,n$.

1.1 Derive the OLS solution $\hat{\beta}$ under the above objective. Show step by step.

$$
\begin{align*}
\mathbb{E}[\varepsilon^T \varepsilon] &= \mathbb{E}[(y - X \beta)^T (y - X \beta)] \\
&= \mathbb{E}[y^T y - 2 y^T X \beta + \beta^T X^T X \beta] \\
&= \mathbb{E}[(X \beta + \varepsilon)^T (X \beta + \varepsilon)] - \beta^T X^T X \beta \\
&= \beta^T X^T X \beta + n \sigma^2 - \beta^T X^T X \beta = n \sigma^2
\end{align*}
$$
The third line uses $\mathbb{E}[\varepsilon] = 0$, and the fourth uses $\mathbb{E}[\varepsilon_i^2] = \mathbb{V}[\varepsilon_i] = \sigma^2$. Evaluated at the true $\beta$, the expected squared error is therefore the constant $n \sigma^2$: once the expectation is taken, nothing in the expression is free to vary, so there is nothing to optimize. To derive the OLS solution $\hat{\beta}$, we instead minimize the observed residual sum of squares $\varepsilon^T \varepsilon = (y - X \beta)^T (y - X \beta)$ over $\beta$, i.e.

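$$
\begin{align*}
\hat{\beta} &= \operatorname*{arg\,min}_\beta f(\beta) \\
f(\beta) &= y^T y - 2 y^T X \beta + \beta^T X^T X \beta \\
\frac{\partial}{\partial \beta} f(\beta) &= -2 X^T y + 2 X^T X \beta = 0 \implies X^T X \hat{\beta} = X^T y \\
\hat{\beta} &= (X^T X)^{-1} X^T y
\end{align*}
$$
The last step assumes $X^T X$ is invertible (i.e. $X$ has full column rank). Since the Hessian $2 X^T X$ is positive semidefinite, this stationary point is a minimum.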
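As a numerical sanity check on the closed form above, the following is a minimal sketch in JAX. The simulated data, sample sizes, and coefficient values are illustrative assumptions, not part of the assignment. It solves the normal equations, checks that the gradient of $f$ is approximately zero at $\hat{\beta}$, and compares against `jnp.linalg.lstsq`.

```python
import jax
import jax.numpy as jnp

# Simulated data; n, p, and the coefficient values are illustrative assumptions.
key_x, key_e = jax.random.split(jax.random.PRNGKey(0))
n, p = 200, 3
X = jax.random.normal(key_x, (n, p))
beta_true = jnp.array([1.0, -2.0, 0.5])
y = X @ beta_true + 0.1 * jax.random.normal(key_e, (n,))

# Closed form: solve the normal equations (X^T X) beta = X^T y.
beta_hat = jnp.linalg.solve(X.T @ X, X.T @ y)


def rss(beta):
  # f(beta) = (y - X beta)^T (y - X beta)
  resid = y - X @ beta
  return resid @ resid


# The gradient of f should vanish (up to float32 error) at beta_hat.
grad_at_hat = jax.grad(rss)(beta_hat)

# Reference least-squares solution for comparison.
beta_lstsq = jnp.linalg.lstsq(X, y)[0]

print(beta_hat, beta_lstsq, grad_at_hat)
```

Solving the linear system with `jnp.linalg.solve` avoids forming $(X^T X)^{-1}$ explicitly, which is usually the more stable choice.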




1.2 Re-write the objective using a likelihood formulation assuming $ɛ_i \sim N(0, \sigma^2)$, for each $i=1,\dotsc,n$.

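$$
\begin{align*}
L(\varepsilon) &= \prod_{i=1}^{n} \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\left(-\frac{\varepsilon_i^2}{2 \sigma^2}\right) \\
L(\beta) &= \prod_{i=1}^{n} \frac{1}{\sqrt{2 \pi \sigma^2}} \exp\left(-\frac{(y_i - X_i \beta)^2}{2 \sigma^2}\right) \\
&= (2 \pi \sigma^2)^{-n/2} \exp\left(-\frac{1}{2 \sigma^2} \sum_{i=1}^{n} (y_i - X_i \beta)^2\right)
\end{align*}
$$
where $\varepsilon_i = y_i - X_i \beta$ and $X_i$ is the $i$-th row of $X$. Maximizing $L(\beta)$ is equivalent to minimizing $\sum_{i=1}^{n} (y_i - X_i \beta)^2$, which is the OLS objective.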

1.3 Derive the OLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

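$$
\begin{align*}
\log L(\beta) &= -\frac{n}{2} \log(2 \pi \sigma^2) - \frac{1}{2 \sigma^2} \sum_{i=1}^{n} (y_i - X_i \beta)^2 \\
\frac{\partial}{\partial \beta} \log L(\beta) &= \frac{1}{\sigma^2} \sum_{i=1}^{n} X_i^T (y_i - X_i \beta) = 0 \implies \\
\hat{\beta}_{MLE} &= \left(\sum_{i=1}^{n} X_i^T X_i\right)^{-1} \left(\sum_{i=1}^{n} X_i^T y_i\right) = (X^T X)^{-1} X^T y
\end{align*}
$$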
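The equivalence between the Gaussian MLE and OLS can also be checked numerically. The sketch below assumes a known $\sigma$ and uses simulated data with illustrative values. It evaluates the score (the gradient of the log-likelihood) at the OLS estimate with `jax.grad` and compares the log-likelihood there against a perturbed coefficient vector.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Simulated data; sigma and the coefficients are illustrative assumptions.
key_x, key_e = jax.random.split(jax.random.PRNGKey(1))
n, p, sigma = 200, 3, 0.5
X = jax.random.normal(key_x, (n, p))
y = X @ jnp.array([0.5, 1.5, -1.0]) + sigma * jax.random.normal(key_e, (n,))


def log_lik(beta):
  # Sum of log N(y_i | X_i beta, sigma^2) over observations.
  return jnp.sum(norm.logpdf(y, loc=X @ beta, scale=sigma))


beta_ols = jnp.linalg.solve(X.T @ X, X.T @ y)

# The score should be approximately zero at the OLS estimate.
score_at_ols = jax.grad(log_lik)(beta_ols)

# Any perturbation away from beta_ols should lower the log-likelihood.
ll_at_ols = log_lik(beta_ols)
ll_perturbed = log_lik(beta_ols + 0.05)

print(score_at_ols, ll_at_ols, ll_perturbed)
```

Note that $\sigma$ only rescales the score, so the maximizer does not depend on its value.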

1.4 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for OLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_ols(y: ArrayLike, X: ArrayLike) -> Array:
  r"""
  Solves ordinary least squares using lineax.

  y: ArrayLike of observations, shape (n,)
  X: ArrayLike of covariates, shape (n, p)

  returns: $\hat{\beta}$ for OLS, shape (p,)
  """

  X_op = lx.MatrixLinearOperator(X)
  solver = lx.NormalCG(rtol=1e-6, atol=1e-8)

  return lx.linear_solve(X_op, y, solver=solver).value


## 2. Weighted least squares (i.e. WLS)
WLS is an approach to fit a slightly more general linear model, $$y = X \beta + ɛ,$$ where $\mathbb{E}[ɛ_i] = 0$ and $\mathbb{V}[ɛ_i] = \sigma^2_i$, so each observation may have its own error variance. We can model all variances jointly as $\mathbb{V}[ɛ] = D$, where $D$ is a diagonal matrix such that $D_{ii} = \sigma^2_i$.

2.1 Write the WLS objective.

Objective:
$$
\hat{\beta} = \operatorname*{arg\,min}_\beta \ g(\beta), \qquad g(\beta) = (y - X\beta)^T D^{-1} (y - X \beta)
$$

2.2. Derive the WLS solution $\hat{\beta}$ under the above objective. Show step by step.

For convenience, we first transform the original data. By definition, $\mathbb{V}[\varepsilon] = D$ with $D_{ii} = \sigma^2_i$. Define $\tilde{\varepsilon} = D^{-\frac{1}{2}} \varepsilon$, $\tilde{y} = D^{-\frac{1}{2}} y$, and $\tilde{X} = D^{-\frac{1}{2}} X$. Multiplying the model by $D^{-\frac{1}{2}}$ gives $\tilde{y} = \tilde{X} \beta + \tilde{\varepsilon}$ with $\mathbb{E}[\tilde{\varepsilon}] = 0$ and $\mathbb{V}[\tilde{\varepsilon}] = D^{-\frac{1}{2}} D D^{-\frac{1}{2}} = I$, so the transformed errors have equal variance and $g(\beta) = (\tilde{y} - \tilde{X} \beta)^T (\tilde{y} - \tilde{X} \beta)$. This lets us borrow the OLS result directly, i.e.

$$
\begin{align*}
\hat{\beta} &= (\tilde{X}^T \tilde{X})^{-1} \tilde{X}^T \tilde{y} \\
&= (X^T D^{-1} X)^{-1} X^T D^{-1} y
\end{align*}
$$
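The whitening argument above can be checked numerically. The following is a minimal sketch in JAX on simulated heteroscedastic data. The variable names and the variance range are illustrative assumptions. It compares OLS on the transformed data with the closed form $(X^T D^{-1} X)^{-1} X^T D^{-1} y$.

```python
import jax
import jax.numpy as jnp

# Simulated heteroscedastic data; all values are illustrative assumptions.
key_x, key_v, key_e = jax.random.split(jax.random.PRNGKey(2), 3)
n, p = 300, 2
X = jax.random.normal(key_x, (n, p))
var = jax.random.uniform(key_v, (n,), minval=0.5, maxval=2.0)  # diagonal of D
y = X @ jnp.array([2.0, -1.0]) + jnp.sqrt(var) * jax.random.normal(key_e, (n,))

# Route 1: whiten with D^{-1/2}, then run OLS on the transformed data.
w = 1.0 / jnp.sqrt(var)
X_tilde = w[:, None] * X
y_tilde = w * y
beta_whitened = jnp.linalg.lstsq(X_tilde, y_tilde)[0]

# Route 2: closed form (X^T D^{-1} X)^{-1} X^T D^{-1} y.
# Broadcasting divides column j of X^T by var[j], giving X^T D^{-1}.
XtDinv = X.T / var
beta_closed = jnp.linalg.solve(XtDinv @ X, XtDinv @ y)

print(beta_whitened, beta_closed)
```

Because $D$ is diagonal, storing only its diagonal as a vector avoids building an $n \times n$ matrix.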

2.3. Re-write the objective using a likelihood formulation assuming $ɛ \sim N(0, D)$.

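$$
\begin{align*}
\log L (\beta) &= -\frac{1}{2} \sum_{i=1}^{n} \log(2 \pi \sigma^2_i) - \frac{1}{2} \sum_{i=1}^{n} \frac{(y_i - X_i \beta)^2}{\sigma^2_i} \\
&= -\frac{1}{2} \sum_{i=1}^{n} \frac{- 2 y_i X_i \beta + \beta^T X_i^T X_i \beta}{\sigma^2_i} + \text{const}
\end{align*}
$$
where the constant collects all terms that do not depend on $\beta$.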

2.4 Derive the WLS solution $\hat{\beta}_{MLE}$ using the above objective. Show step by step.

$$
\begin{align*}
\frac{\partial }{\partial \beta} \log L (\beta) &= \sum_{i=1}^{n} \frac{X_i^T y_i - X_i^T X_i \beta}{\sigma^2_i} = 0 \implies \\
\hat{\beta}_{MLE} &= \left(\sum_{i=1}^{n} \frac{X_i^T X_i}{\sigma_i^2}\right)^{-1} \left(\sum_{i=1}^{n} \frac{X_i^T y_i}{\sigma^2_i}\right) = (X^T D^{-1} X)^{-1} X^T D^{-1} y
\end{align*}
$$

2.5 Using [lineax](https://docs.kidger.site/lineax/), implement a solver for WLS.

In [None]:
import lineax as lx

from jax import Array
from jax.typing import ArrayLike


def solve_wls(y: ArrayLike, X: ArrayLike, D: ArrayLike) -> Array:
  r"""
  Solves weighted least squares using lineax.

  y: ArrayLike of observations, shape (n,)
  X: ArrayLike of covariates, shape (n, p)
  D: ArrayLike of per-observation variances $\sigma^2_i$ (the diagonal of the
    matrix $D$ defined above), shape (n,)

  returns: $\hat{\beta}$ for WLS
  """

  # Whiten with D^{-1/2}: each observation is scaled by 1 / sigma_i
  inv_sqrtD = 1.0 / jnp.sqrt(D)
  X_weighted = inv_sqrtD[:, None] * X
  y_weighted = inv_sqrtD * y
  X_op = lx.MatrixLinearOperator(X_weighted)

  # NormalCG solves the normal equations of the whitened problem, i.e. WLS
  solver = lx.NormalCG(atol=1e-6, rtol=1e-6)
  solution = lx.linear_solve(X_op, y_weighted, solver=solver)

  return solution.value

## 3. MLE for scalar Poisson observations
Given $x_1, \dotsc, x_n$, assume that $x_i \sim \text{Poi}(\lambda)$ for $i=1,\dotsc,n$ where $\text{Poi}(\lambda)$ is the Poisson distribution with rate $\lambda$.

3.1 Write a likelihood-based formulation of the objective.

$$
\begin{align*}
\log L (\lambda) &= \sum_{i=1}^{n} \left( x_i \log \lambda - \lambda - \log x_i! \right) \\
&= \log \lambda \sum_{i=1}^{n} x_i - n \lambda + \text{const}
\end{align*}
$$

3.2 Derive the MLE for the above objective. Show step by step.

$$
\begin{align*}
\frac{\partial}{\partial \lambda} \log L (\lambda) &= \frac{1}{\lambda} \sum_{i=1}^{n} x_i - n = 0 \implies \\
\hat{\lambda} &= \bar{x} = \frac{1}{n} \sum_{i=1}^{n} x_i
\end{align*}
$$
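The stationarity condition above can be verified with automatic differentiation. The sketch below simulates Poisson data with an illustrative rate of 4.0 (an assumption for the example). It then checks that the score is approximately zero at $\bar{x}$ and that the second derivative is negative there, so $\bar{x}$ is a maximum.

```python
import jax
import jax.numpy as jnp

# Simulated counts; the rate 4.0 and sample size 500 are illustrative assumptions.
x = jax.random.poisson(jax.random.PRNGKey(3), 4.0, shape=(500,)).astype(jnp.float32)


def log_lik(lam):
  # Poisson log-likelihood up to the additive constant -sum(log(x_i!)).
  return jnp.sum(x) * jnp.log(lam) - x.shape[0] * lam


lam_hat = jnp.mean(x)

# First derivative (score) should be approximately zero at the sample mean.
score_at_hat = jax.grad(log_lik)(lam_hat)

# Second derivative is -sum(x) / lam^2, which is negative at a maximum.
curvature_at_hat = jax.grad(jax.grad(log_lik))(lam_hat)

print(lam_hat, score_at_hat, curvature_at_hat)
```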

3.3 Implement a function that simulates Poisson distributed data with rate $\lambda$ using JAX.

3.4 Implement a function that computes the MLE $\hat{\lambda}$ given observations $x_1, \dotsc, x_n$.

In [None]:
import lineax as lx
import jax.random as rdm

from jax import Array
from jax.typing import ArrayLike


def simulate_poisson(key, rate: ArrayLike, n: int) -> Array:
  r"""
  Simulates Poisson distributed data.

  key: PRNG key used to generate the samples
  rate: rate specifying the Poisson distribution; can be either a scalar, or
    ArrayLike of shape (n,) (i.e. unique to each observation)
  n: the number of samples to generate

  returns: $x_i \sim \text{Poi}(\lambda_i)$, shape (n,)
  """
  return rdm.poisson(key, rate, shape=(n,))


def fit_poisson(x: ArrayLike) -> Array:
  r"""
  Fits Poisson distributed data.

  x: ArrayLike observations

  returns: the MLE of $\lambda$, i.e. the sample mean of x (a scalar Array).
  """
  return jnp.mean(x)