In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
import jaxopt
import dataclasses
from typing import Any
import textwrap
import easier as ezr

jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_platform_name', 'METAL')

# If you don't have jax-metal properly installed just comment
# out the gpu_device and set device to cpu
# gpu_device = jax.devices('METAL')[0]
# cpu_device = jax.devices('cpu')[0]
# device = gpu_device
# device = cpu_device
import holoviews as hv
hv.extension('bokeh')
import hvplot.pandas

In [2]:

@dataclasses.dataclass
class JaxMLE:
    """Least-squares / Gaussian maximum-likelihood fit of an arbitrary JAX model.

    ``model(w, **kwargs)`` must be a JAX-traceable function of a 1-d parameter
    vector ``w`` whose output has the same shape as the target passed to
    :meth:`fit`. Extra keyword arguments (e.g. a design matrix ``X=...``) are
    forwarded to ``model`` unchanged by every method.

    After :meth:`fit` these attributes are populated:

    * ``w``           -- fitted parameter vector (always 1-d).
    * ``cov``         -- parameter covariance, the inverse of the observed
      information of the Gaussian likelihood; equals ``s**2 * inv(X^T X)``
      for a linear model ``X @ w``.
    * ``sigma``       -- 1-sigma parameter uncertainties, ``sqrt(diag(cov))``.
    * ``noise_sigma`` -- estimated observation-noise std, ``sqrt(SSE / (N - p))``.
    * ``N``           -- number of observations (``target.size``).
    * ``result``      -- object returned by ``solver.run`` (exposes ``.params``).
    """
    model: Any
    # Populated by fit(); not constructor arguments.
    w: Any = dataclasses.field(init=False)
    cov: Any = dataclasses.field(init=False)
    sigma: Any = dataclasses.field(init=False)
    N: Any = dataclasses.field(init=False)
    result: Any = dataclasses.field(init=False)
    noise_sigma: Any = dataclasses.field(init=False)
    # Maximum number of gradient-descent iterations used by fit().
    max_iter: int = 500

    example = textwrap.dedent("""
    Here is the examples you want to see
    """)

    def _loss(self, w, target, **kwargs):
        """Sum of squared residuals between ``model(w, **kwargs)`` and ``target``.

        Raises:
            ValueError: if the model output shape differs from ``target.shape``.
                Shapes are static under JAX, so this fires at trace time.
        """
        y = self.model(w, **kwargs)
        if y.shape != target.shape:
            raise ValueError('The shape of the target variable must match model output shape')
        return jnp.sum((y - target) ** 2)

    def fit(self, w0, target, **kwargs):
        """Minimise the squared-error loss from ``w0`` and estimate uncertainties.

        Args:
            w0: initial parameter guess; flattened to a 1-d float32 vector.
            target: observed values; must match the model output shape.
            **kwargs: forwarded to ``self.model`` (e.g. ``X=...``).

        Runs ``jaxopt.GradientDescent`` for at most ``max_iter`` iterations;
        convergence is not checked here.
        """
        # Flatten up front so the weights are genuinely 1-d everywhere and the
        # Hessian below is always a (p, p) matrix, including single-parameter models.
        w0 = jnp.ravel(jnp.array(w0).astype(jnp.float32))
        target = jnp.array(target)

        solver = jaxopt.GradientDescent(self._loss, maxiter=self.max_iter, verbose=False)
        self.result = solver.run(w0, target, **kwargs)
        self.w = self.result.params
        self._estimate_uncertainty(target, **kwargs)

    def _estimate_uncertainty(self, target, **kwargs):
        """Populate ``N``, ``noise_sigma``, ``cov`` and ``sigma`` around the fitted ``self.w``."""
        self.N = target.size
        # Computed first so a model/target shape mismatch raises the ValueError
        # from _loss before any residual broadcasting can hide it.
        hessian = jax.hessian(self._loss, argnums=0)(self.w, target, **kwargs)

        residuals = target - self.model(self.w, **kwargs)
        # Unbiased residual variance SSE / (N - p) (the scaling scipy's curve_fit
        # uses); dof is clamped so an exactly-determined fit cannot divide by zero.
        dof = max(self.N - self.w.size, 1)
        self.noise_sigma = jnp.sqrt(jnp.sum(residuals ** 2) / dof)

        # _loss is the plain SSE, so its Hessian is 2 * J^T J (plus a residual-
        # curvature term for nonlinear models). The Gaussian negative
        # log-likelihood is SSE / (2 s^2), hence observed information is
        # hessian / (2 s^2) and the covariance is 2 s^2 * inv(hessian).
        self.cov = 2.0 * self.noise_sigma ** 2 * jnp.linalg.inv(hessian)
        self.sigma = jnp.sqrt(jnp.diag(self.cov))

    def _compute_error(self, **kwargs):
        """1-sigma predictive standard deviation at the inputs given by ``kwargs``.

        Adds the parameter uncertainty propagated through the model Jacobian
        (full covariance: ``diag(J @ cov @ J.T)``) to the fitted observation
        noise ``noise_sigma**2``. The result has the model output's shape.
        """
        def predict_from_params(w):
            return self.model(w, **kwargs)

        y = predict_from_params(self.w)
        # Linearise around the fitted parameters (not around their uncertainties).
        jac = jax.jacfwd(predict_from_params)(self.w).reshape(y.size, self.w.size)
        # Row-wise quadratic form diag(J cov J^T) without building an M x M matrix.
        param_var = jnp.einsum('ij,jk,ik->i', jac, self.cov, jac)
        return jnp.sqrt(self.noise_sigma ** 2 + param_var).reshape(y.shape)

    def predict(self, **kwargs):
        """Evaluate the model at the fitted weights ``self.w``."""
        return self.model(self.w, **kwargs)

    def predict_and_error(self, **kwargs):
        """Return ``(prediction, 1-sigma predictive std)`` at the given inputs."""
        y = self.predict(**kwargs)
        sigma = self._compute_error(**kwargs)
        return y, sigma

def model(w, X):
    """Linear model: project the rows of ``X`` onto the weight vector ``w``.

    ``w`` is reshaped to a column, so both flat ``(p,)`` and column ``(p, 1)``
    weights work. Singleton axes of the product are squeezed away, so an
    ``(n, p)`` design matrix yields an ``(n,)`` prediction.
    """
    w_column = w.reshape(-1, 1)
    projected = X @ w_column
    return projected.squeeze()


def X_from_t(t):
    """Build the harmonic design matrix ``[sin(t), cos(t)]``.

    For a 1-d ``t`` of length ``n`` the result has shape ``(n, 2)``:
    column 0 holds ``sin(t)`` and column 1 holds ``cos(t)``.
    """
    basis = (np.sin(t), np.cos(t))
    return np.array(basis).T

# Synthetic data: y = 2 sin(t) + 3 cos(t) + Gaussian noise.
# Seeded so that Restart & Run All reproduces the same figure.
SEED = 0
NOISE_STD = 0.30
rng = np.random.default_rng(SEED)

t = np.linspace(0, 2 * np.pi, 100)
X = X_from_t(t)
w_true = np.array([2, 3])
y_true = (X @ w_true.reshape(-1, 1)).squeeze()
y = y_true + NOISE_STD * rng.standard_normal(y_true.shape)

# Fit from a deliberately poor starting guess.
w0 = np.array([-1., 1.])
fitter = JaxMLE(model)
fitter.fit(w0, y, X=X)
# Sanity check: at this noise level the tolerance is several standard errors wide.
assert np.allclose(fitter.w, w_true, atol=0.25), fitter.w

# Evaluate the fit and its 1-sigma band on a fine grid.
tf = np.linspace(0, 2 * np.pi, 1000)
Xf = X_from_t(tf)
yf, sigma = fitter.predict_and_error(X=Xf)
assert sigma.shape == yf.shape

ol = hv.Overlay([
    hv.Scatter((t, y), kdims=['t'], vdims=['y']).opts(color='black', size=5),
    hv.Area((tf, yf - sigma, yf + sigma), kdims=['t'], vdims=['y', 'y2']).opts(color='grey', alpha=.2),
    hv.Curve((tf, yf), kdims=['t'], vdims=['y']).opts(color='red'),
]).opts(title='JaxMLE fit with 1-sigma band')
ol

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1718823116.938123 3953857 service.cc:145] XLA service 0x60000003b300 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1718823116.938133 3953857 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1718823116.939556 3953857 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1718823116.939568 3953857 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M1 Pro
