<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_4_Optimization_PtI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Move On Up, or: Maximum Likelihood Estimation & Optimization, Pt. I

TBD: move the notes from the slides into this notebook.


In [7]:
import jax
import jax.numpy as jnp
import jax.random as rdm

## MLE for iid Normal data

In [8]:
def norm_rv(key, n: int, mu: float, sigma_sq: float):
  r"""
  Samples $n$ observations from $x_i \sim N(\mu, \sigma^2)$.

  key: a JAX PRNG key used to draw the standard-normal noise
  n: the number of observations
  mu: the mean parameter
  sigma_sq: the variance parameter (expected to be non-negative; jnp.sqrt of
    a negative value yields NaN)

  returns: x, Array of shape (n,) holding the observations
  """
  # Location-scale transform of standard normals: x = mu + sigma * z.
  x = mu + jnp.sqrt(sigma_sq) * rdm.normal(key, shape=(n,))

  return x


def norm_mle(x):
  r"""
  Computes $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$ for iid normal data.

  x: Array of observations

  returns: Tuple of $\hat{\mu}_{MLE}$ and $\hat{\sigma^2}_{MLE}$.
  """
  mu_hat = jnp.mean(x)
  # The variance MLE divides by n (not n - 1), i.e. it equals jnp.var(x) with
  # its default ddof=0, and is biased downward by a factor of (n - 1) / n.
  ssq_hat = jnp.mean((x - mu_hat) ** 2)

  return mu_hat, ssq_hat

seed = 0
key = rdm.PRNGKey(seed)
# Split off a single-use subkey for this draw; `key` stays fresh for later cells.
key, x_key = rdm.split(key)

N = 100

mu = 58.
sigma_sq = 10.
x = norm_rv(x_key, N, mu, sigma_sq)
print(f"x = {x}")
mu_hat, ssq_hat = norm_mle(x)
# Raw f-string: "\m" and "\s" are invalid escape sequences in a regular string
# literal (SyntaxWarning on Python 3.12+); the printed text is unchanged.
print(rf"MLE[\mu, \sigma^2] = {mu_hat}, {ssq_hat}")

x = [54.748363 61.49207  58.068737 62.63578  63.038616 57.023346 62.89525
 67.758804 53.49503  64.813095 54.75522  56.203518 58.778862 59.815628
 59.685024 52.485527 54.44686  58.31798  59.458866 54.291496 62.230354
 61.92895  57.596672 54.478622 62.95002  61.07032  58.497906 61.5559
 57.91157  48.89393  58.041267 57.05834  54.320095 60.814167 58.863586
 52.72084  57.37599  58.616085 61.504147 53.660152 53.529938 56.553493
 59.48462  57.536884 58.604736 60.350792 59.59793  58.373173 56.29943
 59.44     60.093395 58.018135 60.842957 58.10686  62.98405  56.63515
 60.17766  62.55351  61.50157  63.707592 55.19841  63.116325 57.727917
 57.98455  62.07337  59.92226  55.25571  55.92521  60.866608 58.606293
 58.526638 62.975803 60.693626 63.814217 62.8752   57.88707  61.743687
 62.84971  61.9028   60.084473 61.1948   56.62765  59.065964 56.75269
 57.404243 53.602177 56.415016 54.36041  60.074253 60.310368 59.518055
 62.36892  57.469944 62.59477  61.664345 62.104057 55.20062  62.079853
 56.3116

In [None]:
def sq_diff(param, estimate):
  """
  Squared difference between a true parameter and its estimate,
  (param - estimate)^2, i.e. the squared error of a single estimate.
  """
  gap = param - estimate
  return gap ** 2

# Repeat the normal-MLE experiment at increasing sample sizes and report how
# far each estimate lands from the true parameter.
mu = 58.
sigma_sq = 10.
for N in [50, 100, 1000, 10000]:
  # fresh single-use subkey per sample size; `key` is advanced in place
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = norm_rv(x_key, N, mu, sigma_sq)
  # estimate mu and sigma_sq
  mu_hat, ssq_hat = norm_mle(x_n)
  # compute the sq-diff for both and report
  # NOTE(review): each value is the squared error from ONE simulated dataset,
  # not a mean over replicates, so the "MSE" label overstates it and the
  # numbers need not shrink monotonically as N grows.
  mu_err = sq_diff(mu, mu_hat)
  ssq_err = sq_diff(sigma_sq, ssq_hat)
  print(f"MSE[{N} | mu, sigma^2] = {mu_err}, {ssq_err}")

MSE[50 | mu, sigma^2] = 4.6261833631433547e-05, 11.868013381958008
MSE[100 | mu, sigma^2] = 0.11534085124731064, 0.0939243957400322
MSE[1000 | mu, sigma^2] = 0.004835018888115883, 8.171940862666816e-05
MSE[10000 | mu, sigma^2] = 7.22354743629694e-05, 0.12703219056129456


## MLE for iid Exponential data

In [10]:
def exp_rv(key, n: int, rate: float):
  r"""
  Samples $n$ observations from $x_i \sim Exp(\lambda)$.

  key: a JAX PRNG key used to draw the unit-rate exponentials
  n: the number of observations
  rate: the $\lambda$ (rate) parameter, so that $E[x_i] = 1 / \lambda$

  returns: x, Array of shape (n,) holding the observations
  """
  # rdm.exponential draws Exp(1); scaling by the mean 1 / rate gives Exp(rate).
  mean = 1 / rate
  x = mean * rdm.exponential(key, shape=(n,))
  return x


def exp_mle(x):
  r"""
  Computes $\hat{\lambda}_{MLE}$ for iid exponential data.

  x: Array of observations

  returns: $\hat{\lambda}_{MLE} = 1 / \bar{x}$, the reciprocal of the sample mean.
  """
  rate_hat = 1 / jnp.mean(x)
  return rate_hat

# Fresh single-use subkey for this draw; `key` is advanced in place.
key, x_key = rdm.split(key)
N = 100
rate = 1 / 500.
x = exp_rv(x_key, N, rate)
print(f"x = {x}")
rate_hat = exp_mle(x)
# Raw f-string: "\l" is an invalid escape sequence in a regular string literal
# (SyntaxWarning on Python 3.12+); the printed text is unchanged.
print(rf"MLE[\lambda = {rate}] = {rate_hat}")

x = [  17.312157   709.8151       6.562796    35.84027    470.85352
  998.4208      81.79583      9.341232   376.30804    340.34982
  157.47165     35.8379     261.62537   1138.6255     207.14104
  136.99928   1194.3866     833.9906     259.2464     385.55402
  133.51808     14.5040455  721.3384     763.35876    373.88965
  673.0625    1454.55      1262.22       267.10886    388.59683
 1102.788      168.1625     760.39923    193.80412     19.935608
  613.356      193.36668   1924.9954     554.87555    513.06714
  140.60371    100.00257    342.7109     262.97552    306.68918
  128.06355    212.35666   1253.6094     139.26414    522.2092
  180.08957    365.10867    431.31024    482.0427      16.010775
  222.55356    310.96436    249.05305    548.1318     678.6601
  221.14578   1445.9286     291.73547    324.59033     89.26435
  381.03754     17.550144   322.22336    662.46857     78.28897
   86.23128    513.8632     482.30316     77.64375     43.57906
  204.64058    401.58325    997.7560

In [11]:
# Repeat the exponential-MLE experiment at increasing sample sizes.
rate = 1 / 50.
for N in [50, 100, 1000, 10000]:
  # fresh single-use subkey per sample size; `key` is advanced in place
  key, x_key = rdm.split(key)
  # generate N observations
  x_n = exp_rv(x_key, N, rate)
  # estimate rate
  rate_hat = exp_mle(x_n)
  # squared error of the estimate from ONE simulated dataset (not an average
  # over replicates, despite the "MSE" label)
  rate_err = sq_diff(rate, rate_hat)
  # Raw f-string: "\l" is an invalid escape in a regular string literal
  # (SyntaxWarning on Python 3.12+); the printed text is unchanged.
  print(rf"MSE[{N} | \lambda = {rate}] = {rate_err}")

MSE[50 | \lambda = 0.02] = 3.6685960935756157e-07
MSE[100 | \lambda = 0.02] = 5.662052899424452e-06
MSE[1000 | \lambda = 0.02] = 7.783638693581452e-07
MSE[10000 | \lambda = 0.02] = 4.279328180700759e-08


# Gradient descent

In [None]:
def sim_linear_reg(key, N, P, r2=0.5):
  """
  Simulates data from the linear model y = X @ beta + e.

  key: a JAX PRNG key (split internally into single-use subkeys)
  N: number of observations (rows of X)
  P: number of predictors (columns of X)
  r2: target proportion of variance in y explained by g = X @ beta; must lie
    in (0, 1]. The noise variance s2e is chosen so that
    s2g / (s2g + s2e) == r2, where s2g is the sample variance of g.

  returns: Tuple (y, X, beta) with shapes (N,), (N, P) and (P,).

  raises: ValueError if r2 is outside (0, 1]. r2 == 0 would divide by zero,
    and r2 > 1 would require a negative noise variance (NaN after sqrt).
  """
  if not 0 < r2 <= 1:
    raise ValueError(f"r2 must be in (0, 1], got {r2}")

  key, x_key = rdm.split(key)
  X = rdm.normal(x_key, shape=(N, P))

  key, b_key = rdm.split(key)
  beta = rdm.normal(b_key, shape=(P,))

  # signal component
  g = X @ beta
  s2g = jnp.var(g)

  # back out s2e such that s2g / (s2g + s2e) == r2
  s2e = (1 - r2) / r2 * s2g
  key, y_key = rdm.split(key)

  # add noise to g, scaled so that var(e) == s2e in expectation
  y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
  return y, X, beta

# Simulate a regression dataset with N observations and P predictors, using
# sim_linear_reg's default r2; `key` is advanced in place.
key, sim_key = rdm.split(key)

N = 1000
P = 5
y, X, beta = sim_linear_reg(sim_key, N, P)

def linreg_loss(y, X, beta_hat):
  """
  Residual sum of squares, sum_i (y_i - x_i . beta_hat)^2.

  y: Array of responses
  X: design matrix
  beta_hat: candidate coefficient vector

  returns: scalar RSS
  """
  residuals = y - X @ beta_hat
  return jnp.sum(residuals ** 2)

def gradient(y, X, beta_hat):
  Xty = X.T @ y
  XtX = X.T @ X
  return (XtX @ beta_hat) - Xty

# Batch gradient descent on the least-squares objective.

step_size = 1 / N
diff = 10.
idx = 0
# Safety cap: the loop stops on a small change in loss, which never happens
# if the iterates diverge (e.g. when the step size is too large).
max_iter = 10_000
beta_hat = jnp.zeros((P,))
# Seed with the loss at the initial guess (rather than an arbitrary constant)
# so that the first delta measures real progress.
last_loss = linreg_loss(y, X, beta_hat)
# while delta in loss is large (and we have iterations left), continue
while jnp.fabs(diff) > 1e-3 and idx < max_iter:
  # take a step along the NEGATIVE gradient (downhill) using step_size
  beta_hat = beta_hat - step_size * gradient(y, X, beta_hat)
  # update our current loss and compute delta
  cur_loss = linreg_loss(y, X, beta_hat)
  diff = last_loss - cur_loss
  last_loss = cur_loss
  # wave to the crowd
  print(f"Loss[{idx} | {beta}] = {last_loss} @ {beta_hat}")
  idx += 1

if jnp.fabs(diff) > 1e-3:
  print(f"warning: stopped after {max_iter} iterations without converging")

Loss[0 | [-1.0257084   1.8594663   1.1235882   0.23114179 -0.34901348]] = 5639.55810546875 @ [-1.079702    1.8402483   1.1758611   0.23298487 -0.52076465]
Loss[1 | [-1.0257084   1.8594663   1.1235882   0.23114179 -0.34901348]] = 5623.48291015625 @ [-1.1577635   1.8372455   1.081856    0.24313869 -0.47761846]
Loss[2 | [-1.0257084   1.8594663   1.1235882   0.23114179 -0.34901348]] = 5623.23828125 @ [-1.1617126  1.8369826  1.0963572  0.2397221 -0.4716044]
Loss[3 | [-1.0257084   1.8594663   1.1235882   0.23114179 -0.34901348]] = 5623.23193359375 @ [-1.1620235   1.8380398   1.094032    0.23982307 -0.47083738]
Loss[4 | [-1.0257084   1.8594663   1.1235882   0.23114179 -0.34901348]] = 5623.2314453125 @ [-1.1620889   1.8379307   1.0944436   0.23976085 -0.4707122 ]


In [None]:
# Fresh single-use subkey for the dataset simulated in this cell; `key` is advanced in place.
key, sim_key = rdm.split(key)

def linreg_loss_forjax(beta_hat, y, X):
  """
  Residual sum of squares, with beta_hat as the FIRST argument.

  jax.grad differentiates with respect to argument 0 by default, so this
  ordering makes jax.grad(linreg_loss_forjax) return d(loss)/d(beta_hat).
  Numerically the same quantity as linreg_loss(y, X, beta_hat).
  """
  resid = y - X @ beta_hat
  return jnp.sum(resid ** 2)
# Problem size and gradient-descent settings for this cell.
N=1000
P=5
diff = 10  # any starting value above the 1e-3 tolerance enters the loop
y, X, beta = sim_linear_reg(sim_key, N, P)
step_size = 1/N
idx = 0
beta_hat = jnp.zeros((P,))  # start the descent from the origin

#while the delta is large:
while jnp.fabs(diff) > 1e-3:







