In [None]:
from time import perf_counter

import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
import optax
import equinox as eqx
import diffrax as dfx

from flex import FuzzyVariable
from flex.fiss import TSK
from flex.utils import count_parameters
from flex.utils.types import Array

In [None]:
# Adapted from https://docs.kidger.site/diffrax/examples/neural_ode/
def _get_data(ts, *, key):
    """Simulate a single toy nonlinear oscillator from a random initial state.

    Args:
        ts: 1-D time grid. The solve runs from ``ts[0]`` to ``ts[-1]`` and the
            solution is saved at every entry of ``ts``.
        key: JAX PRNG key used to draw the 2-component initial state
            uniformly from [-0.6, 1).

    Returns:
        The saved solution values ``sol.ys`` — presumably of shape
        ``(len(ts), 2)`` since the state has two components.
    """
    initial_state = jax.random.uniform(key, (2,), minval=-0.6, maxval=1)

    def vector_field(t, y, args):
        # Saturating nonlinearity y / (1 + y), then a rotation-like coupling.
        z = y / (1 + y)
        return jnp.stack([z[1], -z[0]], axis=-1)

    solution = dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        dfx.Tsit5(),
        ts[0],
        ts[-1],
        0.1,  # fixed initial step size dt0
        initial_state,
        saveat=dfx.SaveAt(ts=ts),
    )
    return solution.ys


def get_data(dataset_size, *, key):
    """Build a dataset of independent oscillator trajectories.

    Args:
        dataset_size: number of trajectories to simulate.
        key: JAX PRNG key. It is split into one sub-key per trajectory.

    Returns:
        ``(ts, ys)``. ``ts`` is the shared time grid (100 points on [0, 10]).
        ``ys`` stacks one ``_get_data`` solution per trajectory along axis 0.
    """
    ts = jnp.linspace(0, 10, 100)
    trajectory_keys = jax.random.split(key, dataset_size)

    def simulate(trajectory_key):
        return _get_data(ts, key=trajectory_key)

    return ts, jax.vmap(simulate)(trajectory_keys)

def dataloader(arrays, batch_size, *, key):
    """Yield shuffled mini-batches from ``arrays`` indefinitely.

    Each epoch draws a fresh random permutation of the rows. It then yields
    consecutive full batches of ``batch_size`` rows, taking the same rows from
    every array. A trailing partial batch (when ``dataset_size`` is not a
    multiple of ``batch_size``) is dropped.

    Args:
        arrays: sequence of arrays that all share the same leading dimension.
        batch_size: number of rows per batch; must satisfy
            ``0 < batch_size <= dataset_size``.
        key: JAX PRNG key driving the per-epoch shuffles.

    Yields:
        A tuple with one ``array[batch_indices]`` slice per input array.

    Raises:
        AssertionError: if the arrays disagree on their leading dimension.
        ValueError: if ``batch_size`` is not in ``(0, dataset_size]``. Without
            this check the generator would loop forever without yielding.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    if not 0 < batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be in (0, {dataset_size}], got {batch_size!r}"
        )

    indices = jnp.arange(dataset_size)
    while True:
        # Split before use so no key is consumed twice (JAX key-reuse rule).
        key, perm_key = jax.random.split(key)
        perm = jax.random.permutation(perm_key, indices)
        # Last valid start is dataset_size - batch_size, so the final full
        # batch is included (the old `end < dataset_size` test skipped it).
        for start in range(0, dataset_size - batch_size + 1, batch_size):
            batch_perm = perm[start:start + batch_size]
            yield tuple(array[batch_perm] for array in arrays)

In [None]:
class Func(eqx.Module):
    """ODE vector field parameterised by a TSK fuzzy inference system.

    The system takes two fuzzy input variables, ``"x"`` and ``"xdot"``. Each
    is built with ``FuzzyVariable.ruspini`` over ``[-1, 1]``.
    """

    fis: TSK

    def __init__(
        self,
        n_mfs: int,
        kind: str,
        order: int,
        init_scale: float,
        *,
        key: Array,
        **kwargs,
    ) -> None:
        """Build the fuzzy vector field.

        Args:
            n_mfs: number of membership functions per input variable.
            kind: membership-function kind forwarded to ``FuzzyVariable.ruspini``.
            order: TSK consequent order forwarded to ``TSK.init``.
            init_scale: initialisation scale forwarded to ``TSK.init``.
            key: PRNG key for parameter initialisation.
            **kwargs: forwarded to ``eqx.Module.__init__``.
        """
        super().__init__(**kwargs)

        # One fuzzy variable per state component, in state order (x, xdot).
        input_vars = (
            FuzzyVariable.ruspini(n_mfs, kind=kind, minval=-1.0, maxval=1.0, name="x"),
            FuzzyVariable.ruspini(n_mfs, kind=kind, minval=-1.0, maxval=1.0, name="xdot"),
        )

        self.fis = TSK.init(
            input_vars=input_vars,
            order=order,
            init_scale=init_scale,
            key=key,
            name="FODE",
        )

    def __call__(self, t, y, args=None):
        """Evaluate the vector field dy/dt at state ``y``.

        diffrax's ``ODETerm`` calls its vector field as ``f(t, y, args)``, so
        ``args`` must be accepted. Like ``t``, it is ignored here.
        """
        # NOTE(review): the original comment said "expecting one output", but
        # the state has two components (x, xdot). diffrax needs the output to
        # match y's shape. Confirm that TSK.init produces as many outputs as
        # there are state components.
        return self.fis(y).squeeze()
        

In [None]:
class FuzzyODE(eqx.Module):
    """Neural-ODE-style model whose vector field is a fuzzy ``Func``.

    Calling the model integrates the ODE from ``y0`` across ``ts``. It uses
    Tsit5 with a PID step-size controller and returns the solution at every
    time in ``ts``.
    """

    func: Func

    def __init__(
        self,
        n_mfs: int,
        kind: str,
        order: int,
        init_scale: float,
        *,
        key: Array,
        **kwargs,
    ):
        """Create the model. All fuzzy-system arguments are forwarded to ``Func``."""
        super().__init__(**kwargs)
        self.func = Func(
            n_mfs=n_mfs,
            kind=kind,
            order=order,
            init_scale=init_scale,
            key=key,
        )

    def __call__(self, ts, y0):
        """Solve from ``ts[0]`` to ``ts[-1]`` starting at ``y0``; return ``sol.ys``.

        Requires ``len(ts) >= 2``, because the initial step is the first grid spacing.
        """
        term = dfx.ODETerm(self.func)
        solver = dfx.Tsit5()
        # Initial step guess; the PID controller adapts the step from there.
        first_step = ts[1] - ts[0]
        controller = dfx.PIDController(rtol=1e-3, atol=1e-6)
        save_points = dfx.SaveAt(ts=ts)

        solution = dfx.diffeqsolve(
            term,
            solver,
            t0=ts[0],
            t1=ts[-1],
            dt0=first_step,
            y0=y0,
            stepsize_controller=controller,
            saveat=save_points,
        )
        return solution.ys

In [None]:
dataset_size = 256  # number of simulated trajectories
batch_size = 32  # no trailing comma: `32,` would make this the tuple (32,)
lr = 3e-3  # learning rate

# Fail fast on bad settings; the dataloader needs 0 < batch_size <= dataset_size.
assert isinstance(batch_size, int) and 0 < batch_size <= dataset_size