In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import jaxopt
import dataclasses
from typing import Any
import textwrap
import easier as ezr

# Select the 'cpu' platform for JAX. The commented-out line below is the
# alternative 'METAL' setting, kept for switching by hand.
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_platform_name', 'METAL')

import holoviews as hv
hv.extension('bokeh')
import hvplot.pandas

In [39]:

@dataclasses.dataclass
class JaxMLE:
    """Fit a parametric JAX model by gradient descent and estimate uncertainties.

    Parameters
    ----------
    model:
        Callable invoked as ``model(w, **kwargs)``. Its output is compared
        against ``target`` by the loss.
    loss:
        Optional callable invoked as ``loss(w, target, **kwargs)``. When left
        as ``None``, ``fit`` installs the sum-of-squared-residuals loss
        ``_loss`` (and stores it on the instance).
    max_iter:
        Forwarded to ``jaxopt.GradientDescent`` as ``maxiter``.

    Attributes populated by ``fit`` (all ``None`` beforehand)
    ---------------------------------------------------------
    w:
        Fitted parameters, squeezed.
    cov:
        Parameter covariance estimate derived from the inverse Hessian of the
        loss at ``w``.
    sigma:
        Square root of the diagonal of ``cov``.
    N:
        ``len`` of the squeezed target.
    result:
        The object returned by the jaxopt solver's ``run``.
    """

    model: Any
    loss: Any = dataclasses.field(default=None)
    # The fields below are not constructor arguments; __post_init__ resets
    # them to None and fit() fills them in.
    w: Any = dataclasses.field(init=False, repr=False)
    cov: Any = dataclasses.field(init=False, repr=False)
    sigma: Any = dataclasses.field(init=False, repr=False)
    N: Any = dataclasses.field(init=False, repr=False)
    result: Any = dataclasses.field(init=False, repr=False)
    max_iter: Any = dataclasses.field(default=10_000, repr=False)

    def __post_init__(self):
        """Mark the instance as not-yet-fitted."""
        self.w = None
        self.cov = None
        self.sigma = None
        self.N = None
        self.result = None

    example = textwrap.dedent("""
    Here is the examples you want to see
    """)

    def _require_fit(self, *attrs):
        """Raise RuntimeError when any of the named fit results is still None."""
        missing = [name for name in attrs if getattr(self, name, None) is None]
        if missing:
            raise RuntimeError(
                'JaxMLE has not been fitted yet; call fit() first '
                f'(missing: {", ".join(missing)})'
            )

    def _loss(self, w, target,  **kwargs):
        """Default loss: sum of squared residuals between model output and target.

        Raises
        ------
        ValueError
            If the model output shape differs from ``target.shape``.
        """
        y = self.model(w, **kwargs)
        if y.shape != target.shape:
            raise ValueError('The shape of the target variable must match model output shape')
        return jnp.sum((y - target) ** 2)

    def fit(self, w0, target, **kwargs):
        """Minimise the loss from ``w0`` and store parameters and uncertainties.

        Parameters
        ----------
        w0:
            Initial parameter guess; converted to a float32 JAX array.
        target:
            Observations the model output is fitted to.
        **kwargs:
            Forwarded unchanged to the loss (and therefore to the model).

        Side effects
        ------------
        Sets ``self.result``, ``self.w``, ``self.N``, ``self.cov`` and
        ``self.sigma``; also sets ``self.loss`` to ``self._loss`` when no
        custom loss was supplied.
        """
        w0 = jnp.array(w0).astype(jnp.float32)
        target = jnp.array(target)
        if self.loss is None:
            self.loss = self._loss

        solver = jaxopt.GradientDescent(self.loss, maxiter=self.max_iter, verbose=False)
        self.result = solver.run(w0, target, **kwargs)
        # The weights are always 1-d
        self.w = self.result.params.squeeze()
        self.N = len(target.squeeze())

        # Hessian of the loss with respect to the parameters, at the optimum.
        information_matrix = jax.hessian(self.loss, argnums=0)(self.w, target, **kwargs)
        residual_std = jnp.std(target - self.model(self.w, **kwargs))

        # NOTE(review): for the default sum-of-squares loss the Hessian is
        # roughly 2 * J^T J, so this is about half of the textbook
        # sigma^2 * (J^T J)^-1 -- confirm the intended scaling.
        self.cov = residual_std ** 2 * jnp.linalg.inv(information_matrix)

        self.sigma = jnp.sqrt(jnp.diag(self.cov).squeeze())

    def _compute_error(self, **kwargs):
        """Propagate the parameter standard deviations to the model output.

        Uses the Jacobian of ``self.model`` with respect to the parameters,
        evaluated at the fitted parameters ``self.w``, and combines the
        squared Jacobian entries with ``self.sigma ** 2`` (off-diagonal
        covariance terms are not used).

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        self._require_fit('w', 'sigma', 'N')
        # Bind the instance's own model; the previous code reached for a
        # global named ``model`` and broke for any other model.
        fitted_model = self.model

        @jax.jit
        def model_with_only_params(w):
            return fitted_model(w, **kwargs)

        # The Jacobian must be taken at the fitted parameters, not at their
        # standard deviations.
        jacobian = jax.jacfwd(model_with_only_params)(self.w)

        # NOTE(review): the factor N is kept from the original code; plain
        # first-order error propagation has no such factor -- confirm intent.
        J2 = self.N * jacobian ** 2

        sigma2 = J2 @ (self.sigma.reshape([-1, 1]) ** 2)
        sigma = jnp.sqrt(sigma2).squeeze()
        return sigma

    def predict(self, **kwargs):
        """Evaluate the model at the fitted parameters.

        Raises
        ------
        RuntimeError
            If ``fit`` has not been called yet.
        """
        self._require_fit('w')
        return self.model(self.w, **kwargs)

    def predict_and_error(self, **kwargs):
        """Return ``(prediction, error)`` for the given model keyword inputs."""
        y = self.predict(**kwargs)
        sigma = self._compute_error(**kwargs)
        return y, sigma



In [43]:
def model(w, t):
    """Sinusoid ``w[0] * sin(2 * pi * w[1] * t)``.

    ``w[0]`` is used as the amplitude and ``w[1]`` as the frequency; ``t`` is
    the point (or array of points) at which the sinusoid is evaluated.
    """
    amplitude, frequency = w[0], w[1]
    phase = 2 * np.pi * frequency * t
    return amplitude * jnp.sin(phase)

# Synthetic experiment: sample a known sinusoid, add Gaussian noise, fit it
# with JaxMLE and plot the fit with its propagated error band.

# Seed the noise so the cell gives the same data on every Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

t = np.linspace(0, 2 * np.pi, 600)
a_true = 2
f_true = .5
w = np.array([a_true, f_true])
y_true = model(w, t)
y = y_true + .1 * rng.standard_normal(y_true.shape)

# Starting point for the optimiser, deliberately offset from the true values.
w0 = np.array([2.1, .56])

fitter = JaxMLE(model, max_iter=10000)
fitter.fit(w0, y, t=t)
yf, sigma = fitter.predict_and_error(t=t)
print(yf.shape, sigma.shape)
hv.Overlay([
    hv.Curve((t, y_true), label='true'),
    hv.Scatter((t, y)),
    hv.Curve((t, yf)),
    hv.Area((t, yf - sigma, yf + sigma), vdims=['a', 'b']).options(alpha=.1)
])
# fitter.sigma
# hv.Curve((t, sigma))

(600,) (600,)


In [11]:
# NOTE(review): the commented-out lines below are a disabled linear-model
# experiment. They reference `X_from_t`, which is not defined in the visible
# part of this notebook -- restore that helper before re-enabling, or delete
# the block.
# X = X_from_t(t)
# w_true = np.array([2, 3])
# y_true = (X @ w_true.reshape(-1, 1)).squeeze()
# y = y_true + .30 * np.random.randn(*y_true.shape)

# w0 = np.array([-1., 1.])
# fitter = JaxMLE(model)
# fitter.fit(w0, y, X=X)

# tf = np.linspace(0, 2 * np.pi, 1000)
# Xf = X_from_t(tf)
# yf, sigma = fitter.predict_and_error(X=Xf)

# ol = hv.Overlay([
#     hv.Scatter((t, y)).options(color='black', size=5),
#     hv.Area((tf, yf-sigma, yf+sigma), vdims=['y', 'y3']).options(color='grey', alpha=.2),
#     hv.Curve((tf, yf)).options(color='red')
# ])

# # ol = hv.Scatter((t, y)).options(color=ezr.cc[0], size=5) * hv.Curve((tf, yf)).options(color=ezr.cc[1])
# ol

# Display the fitted estimator's repr as the cell output.
fitter

JaxMLE(model=<function model at 0x3045a2440>, loss=<bound method JaxMLE._loss of ...>)