In [26]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import jaxopt
import dataclasses
from typing import Any
import textwrap
import easier as ezr
import optax

jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_platform_name', 'METAL')

# If you don't have jax-metal properly installed just comment
# out the gpu_device and set device to cpu
# gpu_device = jax.devices('METAL')[0]
# cpu_device = jax.devices('cpu')[0]
# device = gpu_device
# device = cpu_device
import holoviews as hv
hv.extension('bokeh')
# hv.extension('matplotlib')
import hvplot.pandas

In [27]:
# Remove Jupyter's autosave checkpoints. This is done in-process with shutil rather
# than with a `!rm -rf` shell escape, so no child process is forked from the
# (multithreaded) JAX kernel — the shell escape is what emitted the os.forkpty()
# warning in this cell's output. ignore_errors=True mirrors `rm -f` when the
# directory does not exist.
import shutil
shutil.rmtree('.ipynb_checkpoints', ignore_errors=True)

  pid, fd = os.forkpty()


In [19]:

@dataclasses.dataclass
class JaxMLE:
    """Least-squares (Gaussian maximum-likelihood) fitter for a JAX model.

    ``model(w, **kwargs)`` must return an array with the same shape as the
    ``target`` passed to :meth:`fit`. Parameters are optimised with Adam (optax).
    After the fit, a parameter covariance is estimated from the Hessian of the
    sum-of-squares loss and the residual spread. That covariance can then be
    propagated to prediction error bars with :meth:`predict_and_error`.
    """

    # Callable ``model(w, **kwargs) -> array`` whose parameters ``w`` are fitted.
    model: Any
    # Adam step size.
    learning_rate: float = .1
    # Optional object with an ``update(x, y)`` method, called once per iteration.
    tracker: Any = None
    # If set, the tracker receives (kwargs[kwarg_to_track], current prediction);
    # otherwise it receives (steps, losses) accumulated so far.
    kwarg_to_track: Any = None
    # Fitted parameters, set by ``fit``.
    w: Any = dataclasses.field(init=False)
    # Estimated parameter covariance matrix, set by ``fit``.
    cov: Any = dataclasses.field(init=False)
    # Per-parameter standard errors (sqrt of the covariance diagonal), set by ``fit``.
    sigma: Any = dataclasses.field(init=False)
    # Number of target values used in the fit (``len(target)``), set by ``fit``.
    N: Any = dataclasses.field(init=False)
    # Declared but not assigned anywhere in this class.
    result: Any = dataclasses.field(init=False)
    # Number of Adam iterations performed by ``fit``.
    max_iter: int = 500

    example = textwrap.dedent("""
    Here is the examples you want to see
    """)

    def _loss(self, w, target, **kwargs):
        """Return the residual sum of squares between ``self.model(w, **kwargs)`` and ``target``.

        Raises:
            ValueError: if the model output shape differs from ``target.shape``.
        """
        y = self.model(w, **kwargs)
        if y.shape != target.shape:
            raise ValueError('The shape of the target variable must match model output shape')
        return jnp.sum((y - target) ** 2)

    def fit(self, w0, target, **kwargs):
        """Fit ``w`` to ``target`` with Adam, then estimate the parameter covariance.

        Args:
            w0: initial parameter vector (cast to float32).
            target: observed values; must have the same shape as the model output.
            **kwargs: keyword arguments forwarded to ``self.model`` on every call.

        Sets ``self.w``, ``self.N``, ``self.cov`` and ``self.sigma``.
        """
        w = jnp.array(w0).astype(jnp.float32)
        target = jnp.array(target)

        optimizer = optax.adam(self.learning_rate)
        state = optimizer.init(w)
        # Build the gradient function once rather than re-wrapping it every step.
        loss_grad = jax.grad(self._loss)
        steps = []
        losses = []
        for nn in range(self.max_iter):
            # Fit against the ``target`` argument, never a same-named global.
            grads = loss_grad(w, target, **kwargs)
            updates, state = optimizer.update(grads, state)
            w = optax.apply_updates(w, updates)
            steps.append(nn)
            losses.append(self._loss(w, target, **kwargs))
            if self.tracker is not None:
                if self.kwarg_to_track is not None:
                    yf = self.model(w, **kwargs)
                    self.tracker.update(kwargs[self.kwarg_to_track], yf)
                else:
                    self.tracker.update(steps, losses)
        self.w = w
        self.N = len(target)

        # Residual standard deviation (population std, ddof=0).
        sigma = jnp.std(target - self.model(self.w, **kwargs))
        # ``_loss`` is the residual sum of squares S(w). For Gaussian noise with std
        # sigma, the negative log-likelihood is S / (2 sigma^2) + const. The Fisher
        # information is therefore H / (2 sigma^2), where H is the Hessian of S, and
        # the MLE covariance is its inverse: 2 sigma^2 H^-1.
        hessian = jax.hessian(self._loss, argnums=0)(self.w, target, **kwargs)
        self.cov = 2 * sigma ** 2 * jnp.linalg.inv(hessian)

        self.sigma = jnp.sqrt(jnp.diag(self.cov).squeeze())

    def _compute_error(self, **kwargs):
        """Propagate the parameter covariance to the model prediction (delta method).

        Computes ``sqrt(N * diag(J @ cov @ J.T))``, where ``J`` is the Jacobian of
        ``self.model(w, **kwargs)`` with respect to ``w`` at the fitted ``self.w``.
        """
        def model_with_only_params(w):
            return self.model(w, **kwargs)

        # Linearise the model around the fitted parameters, not around their errors.
        jac = jax.jacfwd(model_with_only_params)(self.w)
        # Full covariance, so correlations between parameters are included.
        var = jnp.einsum('...i,ij,...j->...', jac, self.cov, jac)
        # NOTE(review): the factor N is kept from the original implementation. It scales
        # the standard error of the fitted curve up by sqrt(N), presumably to approximate
        # the spread of individual observations — confirm that this is the intended band.
        var = self.N * var
        return jnp.sqrt(var).squeeze()

    def predict(self, **kwargs):
        """Evaluate the model at the fitted parameters ``self.w``."""
        return self.model(self.w, **kwargs)

    def predict_and_error(self, **kwargs):
        """Return ``(prediction, error)`` at the fitted parameters for the given model kwargs."""
        y = self.predict(**kwargs)
        sigma = self._compute_error(**kwargs)
        return y, sigma




In [40]:
def model(w, t):
    """Sinusoid ``a * sin(2*pi*f*t)`` with parameter vector ``w = [a, f]``."""
    amplitude, frequency = w[0], w[1]
    phase = 2 * np.pi * frequency * t
    return amplitude * jnp.sin(phase)

# Synthetic data: a noisy sinusoid with known amplitude and frequency.
# Fixed seed so that Restart & Run All reproduces the same data and figure.
rng = np.random.default_rng(0)
t = np.linspace(0, 2 * np.pi, 600)
a_true = 2
f_true = .5
w = np.array([a_true, f_true])
y_true = model(w, t)
noise_std = 4.2
y = y_true + noise_std * rng.standard_normal(y_true.shape)

# The loss tracker is created here so this cell does not depend on a cell that
# appears later in the notebook (tracker1 was previously used before it was defined).
# NOTE(review): if Tracker.init() returns a displayable object, it is not shown here.
tracker1 = ezr.Tracker(label='true', logy=True)
tracker1.init()

w0 = jnp.array([.1, .4])
fitter = JaxMLE(model, learning_rate=.01, tracker=tracker1, max_iter=527, kwarg_to_track=None)
fitter.fit(w0, y, t=t)
yf, sigma = fitter.predict_and_error(t=t)
(
    hv.Curve((t, yf), label='fit')
    * hv.Curve((t, yf + sigma), label='fit + error')
    * hv.Curve((t, yf - sigma), label='fit - error')
    * hv.Scatter((t, y), label='data')
    * hv.Curve((t, y_true), label='truth')
)

In [36]:

# Live loss tracker for the optimiser; passed to JaxMLE as ``tracker=tracker1``.
# It must exist before any cell that references tracker1 is run.
tracker1 = ezr.Tracker(logy=True, label='true')
tracker1.init()