In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install flax --upgrade


In [None]:
!pip install jax --upgrade


In [None]:
#check if version >= 0.4.35
!pip show flax

In [None]:
#check if version >= 0.4.35
!pip show jax

In [None]:
import tensorflow_datasets as tfds  # TFDS to download MNIST.
import tensorflow as tf  # TensorFlow / `tf.data` operations.
from flax import nnx  # The Flax NNX API.
import jax.numpy as jnp  # JAX NumPy
import jax
import optax
import copy
import numpy as np
import random


In [None]:
#Change name of the parameters for layer-wise parameterization of the learning rate
from __future__ import annotations

import typing as tp

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
import opt_einsum

from flax.core.frozen_dict import FrozenDict
from flax import nnx
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
    Dtype,
    Shape,
    Initializer,
    PrecisionLike,
    DotGeneralT,
    ConvGeneralDilatedT,
    PaddingLike,
    LaxPadding,
)

# Short type aliases used in the annotations of the layer classes below.
Array = jax.Array
Axis = int
Size = int


# Default initializers used by the linear layers below when the caller does
# not pass `kernel_init` / `bias_init` explicitly.
default_kernel_init = initializers.lecun_normal()
default_bias_init = initializers.zeros_init()


class _PrefixedLinear(Module):
    """Linear transformation over the last input dimension with prefixed parameter names.

    The computation is that of a plain dense layer (``inputs @ kernel + bias``).
    The only difference is the attribute names under which the parameters are
    stored: ``<param_prefix>_kernel`` and ``<param_prefix>_bias``.  Distinct
    names per layer role make it possible to select a different optimizer /
    learning rate per role by looking at the parameter name alone.

    Concrete subclasses only need to set ``param_prefix``.

    Attributes:
      in_features: the number of input features.
      out_features: the number of output features.
      use_bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: infer from input and params).
      param_dtype: the dtype passed to parameter initializers (default: float32).
      precision: numerical precision of the computation see ``jax.lax.Precision``
        for details.
      kernel_init: initializer function for the weight matrix.
      bias_init: initializer function for the bias.
      dot_general: dot product function.
      rngs: rng key.
    """

    # Overridden by every subclass; decides the parameter attribute names.
    param_prefix = None

    def __init__(
        self,
        in_features: int,
        out_features: int,
        *,
        use_bias: bool = True,
        dtype: tp.Optional[Dtype] = None,
        param_dtype: Dtype = jnp.float32,
        precision: PrecisionLike = None,
        kernel_init: Initializer = default_kernel_init,
        bias_init: Initializer = default_bias_init,
        dot_general: DotGeneralT = lax.dot_general,
        rngs: rnglib.Rngs,
    ):
        if not self.param_prefix:
            raise TypeError(
                f"{type(self).__name__} must define a non-empty 'param_prefix'"
            )

        # The kernel key is drawn before the bias key so that the sequence of
        # keys taken from `rngs` is the same as in a hand-written layer.
        kernel_key = rngs.params()
        setattr(
            self,
            f"{self.param_prefix}_kernel",
            nnx.Param(
                kernel_init(kernel_key, (in_features, out_features), param_dtype)
            ),
        )
        if use_bias:
            bias_key = rngs.params()
            bias_param = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
        else:
            # Keep the attribute present even without a bias so that
            # `__call__` can read it unconditionally.
            bias_param = nnx.Param(None)
        setattr(self, f"{self.param_prefix}_bias", bias_param)

        self.in_features = in_features
        self.out_features = out_features
        self.use_bias = use_bias
        self.dtype = dtype
        self.param_dtype = param_dtype
        self.precision = precision
        self.kernel_init = kernel_init
        self.bias_init = bias_init
        self.dot_general = dot_general

    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along the last dimension.

        Args:
          inputs: The nd-array to be transformed.

        Returns:
          The transformed input.
        """
        kernel = getattr(self, f"{self.param_prefix}_kernel").value
        bias = getattr(self, f"{self.param_prefix}_bias").value

        inputs, kernel, bias = dtypes.promote_dtype(
            (inputs, kernel, bias), dtype=self.dtype
        )
        # Contract the last axis of `inputs` with the first axis of `kernel`.
        y = self.dot_general(
            inputs,
            kernel,
            (((inputs.ndim - 1,), (0,)), ((), ())),
            precision=self.precision,
        )
        assert self.use_bias == (bias is not None)
        if bias is not None:
            # Broadcast the bias over every leading axis of the output.
            y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
        return y


class Linear_encoder(_PrefixedLinear):
    """Linear layer storing its parameters as ``lin_encoder_kernel`` / ``lin_encoder_bias``."""

    param_prefix = "lin_encoder"


class Linear_MLP1(_PrefixedLinear):
    """Linear layer storing its parameters as ``MLP1_kernel`` / ``MLP1_bias``."""

    param_prefix = "MLP1"


class Linear_MLP2(_PrefixedLinear):
    """Linear layer storing its parameters as ``MLP2_kernel`` / ``MLP2_bias``."""

    param_prefix = "MLP2"


class Linear_out(_PrefixedLinear):
    """Linear layer storing its parameters as ``out_kernel`` / ``out_bias``."""

    param_prefix = "out"


In [None]:

# Alias for JAX's associative scan; used below to evaluate the LRU recurrence
# for all time steps at once.
parallel_scan = jax.lax.associative_scan

# From Orvieto et al., 2023, (https://arxiv.org/abs/2303.06349)
def compute_lr_sigma(mode: str, d, m, k, L):
    """Return the learning-rate multiplier and init sigma for one layer role.

    Args:
      mode: layer role, one of ``"input"``, ``"hidden"`` or ``"output"``.
      d: first width argument; used by the ``"input"`` and ``"hidden"`` formulas.
      m: second width argument; used by all three formulas.
      k: third width argument; used only by the ``"output"`` formula.
      L: depth-like factor; every learning rate is divided by ``L ** 1.5``.

    Returns:
      ``(lr, sigma)`` as plain Python floats.

    Raises:
      ValueError: if ``mode`` is not one of the three supported roles.
    """
    if mode not in ("input", "hidden", "output"):
        raise ValueError(
            f"Unknown mode {mode!r}; expected 'input', 'hidden' or 'output'"
        )

    # Shared by all three learning-rate formulas.
    depth_scale = jnp.power(L, 3 / 2)

    if mode == "input":
        lr = m / (depth_scale * d)
        sigma = 1 / jnp.sqrt(d)
    elif mode == "hidden":
        lr = 1 / depth_scale
        sigma = 2 / jnp.sqrt((m + d) / 2)
    else:  # mode == "output"
        lr = k / (depth_scale * m)
        sigma = jnp.sqrt(k) / m
    return float(lr), float(sigma)

def forward(lru_parameters, input_sequence):
    """Run one LRU layer over a single sequence.

    ``lru_parameters`` is the 8-tuple
    ``(nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log)``.
    ``input_sequence`` and the returned array are both of shape (L, H).
    """
    nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = lru_parameters

    # Diagonal of the recurrence matrix and the complex in/out projections.
    eigenvalues = jnp.exp(-jnp.exp(nu_log) + 1j * jnp.exp(theta_log))
    in_proj = (B_re + 1j * B_im) * jnp.expand_dims(jnp.exp(gamma_log), axis=-1)
    out_proj = C_re + 1j * C_im

    # One (eigenvalues, B u_k) pair per time step, combined by the scan.
    # For details on parallel scan, check discussion in Smith et al (2022).
    seq_len = input_sequence.shape[0]
    scan_elements = (
        jnp.repeat(eigenvalues[None, ...], seq_len, axis=0),
        jax.vmap(lambda u: in_proj @ u)(input_sequence),
    )
    hidden_states = parallel_scan(binary_operator_diag, scan_elements)[1]  # all x_k

    def readout(state, u):
        # Real part of the output projection plus the skip term D * u.
        return (out_proj @ state).real + D * u

    return jax.vmap(readout)(hidden_states, input_sequence)


def init_lru_parameters(N, H, r_min=0, r_max=1, max_phase=0.314, rng=None):
    """Initialize parameters of the LRU layer.

    Args:
      N: state dimension.
      H: model dimension.
      r_min: lower bound of the modulus of the recurrence eigenvalues.
      r_max: upper bound of the modulus of the recurrence eigenvalues.
      max_phase: upper bound of the phase of the recurrence eigenvalues.
      rng: optional random source exposing ``uniform(size=...)`` and
        ``normal(size=...)`` (e.g. ``np.random.default_rng(seed)``).  When
        ``None`` the global ``np.random`` state is used, as before.

    Returns:
      The tuple ``(nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log)``
      of NumPy arrays with shapes ``(N,), (N,), (N, H), (N, H), (H, N),
      (H, N), (H,), (N,)``.
    """
    # `np.random` (the module) offers the same `uniform` / `normal` calls as a
    # Generator, so one code path serves both the legacy and the seeded case.
    sampler = np.random if rng is None else rng

    # Initialization of Lambda is complex valued distributed uniformly on ring
    # between r_min and r_max, with phase in [0, max_phase].
    u1 = sampler.uniform(size=(N,))
    u2 = sampler.uniform(size=(N,))
    nu_log = np.log(-0.5 * np.log(u1 * (r_max**2 - r_min**2) + r_min**2))
    theta_log = np.log(max_phase * u2)

    # Glorot initialized Input/Output projection matrices
    B_re = sampler.normal(size=(N, H)) / np.sqrt(2 * H)
    B_im = sampler.normal(size=(N, H)) / np.sqrt(2 * H)
    C_re = sampler.normal(size=(H, N)) / np.sqrt(N)
    C_im = sampler.normal(size=(H, N)) / np.sqrt(N)
    D = sampler.normal(size=(H,))

    # Normalization factor: log of sqrt(1 - |lambda|^2) per state dimension.
    diag_lambda = np.exp(-np.exp(nu_log) + 1j * np.exp(theta_log))
    gamma_log = np.log(np.sqrt(1 - np.abs(diag_lambda) ** 2))

    return nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log


def binary_operator_diag(element_i, element_j):
    """Combine two ``(a, bu)`` scan elements of a diagonal linear recurrence.

    Used as the binary operator of the parallel scan; ``element_i`` is the
    earlier element and ``element_j`` the later one.
    """
    decay_first, input_first = element_i
    decay_second, input_second = element_j
    combined_decay = decay_second * decay_first
    combined_input = decay_second * input_first + input_second
    return combined_decay, combined_input


Array = jax.Array


class LRU(nnx.Module):
    """Linear Recurrent Unit layer with trainable parameters.

    The eight parameter arrays come from ``init_lru_parameters`` (called with
    ``N=hidden_features`` and ``H=in_features``) and are stored as
    ``nnx.Param`` attributes named ``nu_log``, ``theta_log``, ``B_re``,
    ``B_im``, ``C_re``, ``C_im``, ``D`` and ``gamma_log``.

    Args:
      in_features: model dimension ``H`` of the inputs and outputs.
      hidden_features: state dimension ``N`` of the recurrence.
      r_min: lower bound passed to the eigenvalue initialisation.
      r_max: upper bound passed to the eigenvalue initialisation.
      max_phase: phase bound passed to the eigenvalue initialisation.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,  # not inferred from carry for now
        *,
        r_min=0,
        r_max=1,
        max_phase=6.28,
    ):
        self.in_features = in_features
        self.hidden_features = hidden_features
        nu_log, theta_log, B_re, B_im, C_re, C_im, D, gamma_log = init_lru_parameters(
            hidden_features, in_features, r_min=r_min, r_max=r_max, max_phase=max_phase
        )

        self.nu_log = nnx.Param(nu_log)
        self.theta_log = nnx.Param(theta_log)
        self.B_re = nnx.Param(B_re)
        self.B_im = nnx.Param(B_im)
        self.C_re = nnx.Param(C_re)
        self.C_im = nnx.Param(C_im)
        self.D = nnx.Param(D)
        self.gamma_log = nnx.Param(gamma_log)

    def __call__(self, inputs: Array):  # type: ignore[override]
        """Apply the layer to one sequence.

        Args:
          inputs: array of shape (L, H) holding one sequence.

        Returns:
          Array of shape (L, H), as produced by ``forward``.
        """
        # Read every parameter through `.value` and reuse `forward`, so the
        # recurrence is implemented in exactly one place.
        lru_parameters = (
            self.nu_log.value,
            self.theta_log.value,
            self.B_re.value,
            self.B_im.value,
            self.C_re.value,
            self.C_im.value,
            self.D.value,
            self.gamma_log.value,
        )
        return forward(lru_parameters, inputs)


In [None]:
# --- Experiment switches -----------------------------------------------------
rnn=1 #rnn=0:transformation of the inputs with fixed RNN weights, rnn=1: adding the RNN module on the model to learn the weight matrices
pool=0 #pooling layer after MLP is taking the average over the numbers
transformation=0 #transformation of the data from integers between 0 and 255 to binary 8 bit numbers
leave_data=1 #download csv data of the results

# --- Model sizes -------------------------------------------------------------
hidden_neuron=384 #no details in the 2023 paper => 2024 paper fixed to 512 with LV system
encoded_size=512
hidden_size=384

# --- Optimisation ------------------------------------------------------------
learning_rate = 1e-3
momentum = 0.9
train_steps=30000
eval_every = 50
batch_size=50

# --- LRU initialisation ------------------------------------------------------
r_min = 0.9
r_max = 0.999
max_phase = 6.28

depth=6
lr_factor=0.25
dropout=0.1

# --- Bookkeeping -------------------------------------------------------------
method_name="LRUMLP6"
dataset_name="CIFAR10"
folder_name="step30klayer"

# Draw one seed per run and print it so the run can be repeated.
rand=random.randint(0,10000)
# The LRU parameters are initialised through the global NumPy RNG, which
# `nnx.Rngs` does not control, so seed it with the same value.
np.random.seed(rand)
rngs1=nnx.Rngs(rand)
print(rand)


In [None]:
def vec_bin_array(arr, m): #https://stackoverflow.com/questions/22227595/convert-integer-to-binary-array-with-suitable-padding
    """
    Arguments:
    arr: Numpy array of non-negative integers, each representable in m bits
    m: Number of bits used to encode each integer

    Returns a new array of shape ``arr.shape + (m,)`` in which every element is
    replaced with its bit vector, most significant bit first.
    Bits encoded as int8's.

    Raises TypeError for non-integer input and ValueError when a value is
    negative or does not fit in m bits.
    """
    values = np.asarray(arr)
    if not np.issubdtype(values.dtype, np.integer):
        raise TypeError(f"vec_bin_array expects an integer array, got dtype {values.dtype}")

    if values.size:
        lowest, highest = int(values.min()), int(values.max())
        if lowest < 0 or highest >= (1 << m):
            raise ValueError(
                f"values must lie in [0, {(1 << m) - 1}] to be encoded with {m} bits; "
                f"got range [{lowest}, {highest}]"
            )

    bits = np.zeros(values.shape + (m,), dtype=np.int8)
    dtype_width = values.dtype.itemsize * 8
    for bit_ix in range(m):
        shift = m - 1 - bit_ix
        # Positions above the dtype's width are always 0 for the validated
        # (non-negative) values, so they keep the zero they were created with.
        if shift < dtype_width:
            bits[..., bit_ix] = (values >> shift) & 1

    return bits


#Import data

def _load_raw_dataset(name):
    """Return ``((train_images, train_labels), (test_images, test_labels))`` for `name`.

    Raises ValueError for a dataset name this notebook does not handle.
    """
    loaders = {
        "MNIST": tf.keras.datasets.mnist.load_data,
        "CIFAR10": tf.keras.datasets.cifar10.load_data,
    }
    if name not in loaders:
        raise ValueError(
            f"Unsupported dataset_name {name!r}; expected one of {sorted(loaders)}"
        )
    return loaders[name]()


def _images_to_token_sequences(images, binarize):
    """Flatten a batch of images into sequences of per-pixel tokens.

    Arguments:
    images: integer array of shape (N, H, W) or (N, H, W, C)
    binarize: when truthy every channel value becomes its 8-bit binary code,
        otherwise values are divided by 255

    Returns ``(tokens, n_sequences, sequence_length, token_size)`` where
    ``tokens`` has shape ``(n_sequences, sequence_length, token_size)``.
    """
    images = np.asarray(images)
    if images.ndim == 3:
        # Single-channel images carry no channel axis; add one so that both
        # datasets go through the same code path.
        images = images[..., None]

    n_sequences = images.shape[0]
    sequence_length = int(np.prod(images.shape[1:-1]))
    channels = int(images.shape[-1])

    if binarize:
        token_size = channels * 8
        tokens = vec_bin_array(images, 8).reshape(
            (n_sequences, sequence_length, token_size)
        )
    else:
        token_size = channels
        tokens = images.reshape((n_sequences, sequence_length, token_size)) / 255
    return tokens, n_sequences, sequence_length, token_size


dataset = _load_raw_dataset(dataset_name)
train = dataset[0]
test = dataset[1]

# Every split is described by its own shape, so train and test values cannot
# be mixed up.
train_x, train_x_seq, train_x_len, train_x_size = _images_to_token_sequences(
    train[0], transformation
)
test_x, test_x_seq, test_x_len, test_x_size = _images_to_token_sequences(
    test[0], transformation
)

# Labels become flat vectors with one entry per sequence.
train_y = train[1].reshape(train_x_seq)
test_y = test[1].reshape(test_x_seq)
train_y_class = len(np.unique(train_y))

print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

In [None]:
# Build the tf.data pipelines from the in-memory arrays.
# NOTE(review): `jnp.real` / `jnp.array` create JAX arrays first, which looks
# like an avoidable device round-trip for the full dataset -- confirm before
# switching to plain NumPy conversions.
train_ds=tf.data.Dataset.from_tensor_slices((jnp.real(train_x),jnp.array(train_y,dtype=int)))
test_ds=tf.data.Dataset.from_tensor_slices((jnp.real(test_x),jnp.array(test_y,dtype=int)))

# Seed the shuffle with the run seed so that the batch order can be reproduced
# from the printed seed.
train_ds = train_ds.repeat().shuffle(100, seed=rand)

# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
# `take(train_steps)` bounds the otherwise endless repeated stream.
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
# Learning-rate multiplier and init sigma for every layer role (all with L=1).
lin_encoder_lr, lin_encoder_sigma = compute_lr_sigma(
    mode="input", d=train_x_size, m=encoded_size, k=0, L=1
)
MLP1_lr, MLP1_sigma = compute_lr_sigma(
    mode="input", d=encoded_size, m=hidden_neuron // 2, k=0, L=1
)
MLP2_lr, MLP2_sigma = compute_lr_sigma(
    mode="output", d=0, m=hidden_neuron // 2, k=encoded_size, L=1
)
out_lr, out_sigma = compute_lr_sigma(
    mode="output", d=encoded_size, m=encoded_size, k=train_y_class, L=1
)

In [None]:
class MLP(nnx.Module):
    """Encoder, six LRU+MLP blocks and a final projection to class scores.

    The model reads several notebook-level settings at construction time:
    ``lin_encoder_sigma``, ``MLP1_sigma``, ``MLP2_sigma``, ``out_sigma``,
    ``r_min``, ``r_max``, ``max_phase``, ``pool`` and ``dropout``.

    Args:
      token_size: number of features of one input token.
      token_len: number of tokens per input sequence.
      encoded_dim: width of the encoded tokens.
      hidden_dim: state dimension of every LRU.
      layer_dim: output width of the first linear layer of every block; the
        GLU that follows halves it.
      out_dim: number of output classes.
      rngs: ``nnx.Rngs`` used for parameter initialisation and dropout.
    """

    # DON'T FORGET TO CHANGE THE MODEL NAME BEFORE RUNNING
    # According to the scheme of the paper (Figure 1), input_size=M, encoded_size=H,layer_dim=number of neurons in MLP, out_dim=number of classes

    # Number of LRU+MLP blocks; attribute and method names end in 1..NUM_BLOCKS.
    NUM_BLOCKS = 6

    def __init__(
        self,
        token_size,
        token_len,
        encoded_dim,
        hidden_dim,
        layer_dim,
        out_dim,
        rngs: nnx.Rngs,
    ):

        def scaled_init(sigma):
            # NOTE(review): `variance_scaling` treats `scale` as a variance
            # factor (std = sqrt(scale / fan_in)); here the value produced by
            # `compute_lr_sigma` is passed through unchanged -- confirm that
            # this is the intended initial standard deviation.
            return jax.nn.initializers.variance_scaling(
                scale=sigma, mode="fan_in", distribution="truncated_normal"
            )

        self.lin_encoder = Linear_encoder(
            in_features=token_size,
            out_features=encoded_dim,
            rngs=rngs,
            kernel_init=scaled_init(lin_encoder_sigma),
        )

        # The sub-modules keep their flat, numbered attribute names
        # (rnn1, linear1_1, batchnorm1, ...), so parameter paths stay
        # two levels deep.  The three loops follow the original construction
        # order: all LRUs, then all linear pairs, then all batch norms.
        block_ids = range(1, self.NUM_BLOCKS + 1)
        for index in block_ids:
            setattr(
                self,
                f"rnn{index}",
                LRU(
                    in_features=encoded_dim,
                    hidden_features=hidden_dim,
                    r_min=r_min,
                    r_max=r_max,
                    max_phase=max_phase,
                ),
            )
        for index in block_ids:
            setattr(
                self,
                f"linear{index}_1",
                Linear_MLP1(
                    in_features=encoded_dim,
                    out_features=layer_dim,
                    rngs=rngs,
                    kernel_init=scaled_init(MLP1_sigma),
                ),
            )
            # The GLU between the two layers halves the feature dimension.
            setattr(
                self,
                f"linear{index}_2",
                Linear_MLP2(
                    in_features=layer_dim // 2,
                    out_features=encoded_dim,
                    rngs=rngs,
                    kernel_init=scaled_init(MLP2_sigma),
                ),
            )
        for index in block_ids:
            setattr(
                self,
                f"batchnorm{index}",
                nnx.BatchNorm(num_features=encoded_dim, rngs=rngs),
            )

        # Linear layers
        if pool:  # If pooling layer takes the average over the token sequence length
            self.linear3 = lambda x: jnp.mean(x, axis=1)
        else:  # learn the parameters of the linear transformation
            self.linear3 = nnx.Linear(in_features=token_len, out_features=1, rngs=rngs)
        self.linear4 = Linear_out(
            in_features=encoded_dim,
            out_features=out_dim,
            rngs=rngs,
            kernel_init=scaled_init(out_sigma),
        )
        self.out_dim = out_dim
        self.token_len = token_len
        self.dropout = nnx.Dropout(dropout, rngs=rngs)

    def _lru_mlp_block(self, index, x):
        """Apply block `index` (LRU -> linear -> GLU -> dropout -> linear) to one sequence."""
        x = getattr(self, f"rnn{index}")(x)
        x = getattr(self, f"linear{index}_1")(x)
        x = nnx.glu(x, axis=-1)
        # NOTE(review): the module is broadcast (in_axes=None) by the vmapped
        # callers, so presumably every batch element shares one dropout mask --
        # confirm.
        x = self.dropout(x)
        return getattr(self, f"linear{index}_2")(x)

    # One vmapped entry point per block: the module is broadcast and the
    # leading axis of `x` is mapped over.
    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm1(self, x):
        return self._lru_mlp_block(1, x)

    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm2(self, x):
        return self._lru_mlp_block(2, x)

    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm3(self, x):
        return self._lru_mlp_block(3, x)

    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm4(self, x):
        return self._lru_mlp_block(4, x)

    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm5(self, x):
        return self._lru_mlp_block(5, x)

    @nnx.vmap(in_axes=(None, 0))
    def block_after_batchnorm6(self, x):
        return self._lru_mlp_block(6, x)

    @nnx.vmap(in_axes=(None, 0))
    def final_linear_projections(self, x):
        """Reduce one sequence over the token axis, then project to `out_dim` scores."""
        # Transposing puts the token axis last so `linear3` reduces over it.
        x = self.linear3(x.T)
        x = self.linear4(x.T)
        return x.reshape(self.out_dim)

    def __call__(self, x):
        """Return the class scores for a batch of token sequences.

        Args:
          x: batch of sequences; presumably of shape
            (batch, token_len, token_size) -- confirm against the caller.

        Returns:
          One vector of `out_dim` scores per batch element.
        """
        x = self.lin_encoder(x)
        encoded = x.copy()

        # LRU+MLP block*6
        for index in range(1, self.NUM_BLOCKS + 1):
            x = getattr(self, f"batchnorm{index}")(x)
            x = getattr(self, f"block_after_batchnorm{index}")(x)
            # NOTE(review): the skip term is the encoder output for every
            # block, not the input of the current block -- confirm that this
            # is intended.
            x = x + encoded

        return self.final_linear_projections(x)



# Build the model eagerly; keyword arguments make the mapping from the
# notebook settings to the constructor parameters explicit.
model = MLP(
    token_size=train_x_size,
    token_len=train_x_len,
    encoded_dim=encoded_size,
    hidden_dim=hidden_size,
    layer_dim=hidden_neuron,
    out_dim=train_y_class,
    rngs=rngs1,
)

#nnx.display(model)

In [None]:
def group_tuples_to_nested_dict(params):
    """Group ``(outer_key, inner_key)`` pairs into ``{outer_key: {inner_key: inner_key}}``.

    Every inner key maps to itself.  Each item of `params` must unpack into
    exactly two elements.
    """
    grouped = {}
    for module_name, param_name in params:
        grouped.setdefault(module_name, {})[param_name] = param_name
    return grouped
# Flattened view of the model's `nnx.Param` state; its keys are presumably
# path tuples of the form (sub-module name, parameter name) -- confirm.
param = nnx.state(model,nnx.Param).flat_state()
# Nested {sub-module name: {parameter name: parameter name}} mapping built from those keys.
gr=group_tuples_to_nested_dict(list(param.keys()))

In [None]:

#Set optimization method per layer
def _warmup_cosine(scale=1):
    """Warmup + cosine-decay schedule with all three learning rates multiplied by ``scale``.

    Warmup lasts ``train_steps//10`` steps; the rate rises from ``1e-7*scale`` to
    ``learning_rate*scale`` and decays back to ``1e-7*scale`` at ``train_steps``.
    """
    return optax.warmup_cosine_decay_schedule(
        init_value=1e-7*scale,
        peak_value=learning_rate*scale,
        warmup_steps=train_steps//10,
        decay_steps=train_steps,
        end_value=1e-7*scale,
    )


def _adamw_decayed(schedule):
    """AdamW driven by ``schedule`` with ``weight_decay=0.05``."""
    return optax.adamw(schedule, weight_decay=0.05)


rnn_schedule = _warmup_cosine(lr_factor)
lin_schedule = _warmup_cosine()
lin_encoder_schedule = _warmup_cosine(lin_encoder_lr)
MLP1_schedule = _warmup_cosine(MLP1_lr)
MLP2_schedule = _warmup_cosine(MLP2_lr)
out_schedule = _warmup_cosine(out_lr)

# One optimizer per parameter label (labels are the parameter names stored in `gr`).
# The lr factor is applied only to B and the log-parameters below; C and D follow
# the plain linear schedule.
# Author's observation (open question): this climbs more slowly but overfits
# less - possibly because the fit is less sensitive?
_layer_specific = (
    ("lin_encoder_kernel", lin_encoder_schedule),
    ("MLP1_kernel", MLP1_schedule),
    ("MLP2_kernel", MLP2_schedule),
    ("out_kernel", out_schedule),
    ("lin_encoder_bias", lin_encoder_schedule),
    ("MLP1_bias", MLP1_schedule),
    ("MLP2_bias", MLP2_schedule),
    ("out_bias", out_schedule),
)

d = {}
d.update({label: optax.adamw(rnn_schedule) for label in ("B_re", "B_im")})
d.update({label: _adamw_decayed(lin_schedule) for label in ("C_im", "C_re", "D")})
d.update({label: optax.adamw(rnn_schedule) for label in ("gamma_log", "nu_log", "theta_log")})
# Generic names, used if linear3 is not pooling.
d.update({label: _adamw_decayed(lin_schedule) for label in ("kernel", "bias")})
d.update({label: _adamw_decayed(schedule) for label, schedule in _layer_specific})
d["scale"] = _adamw_decayed(lin_schedule)

# Route every parameter to the optimizer registered under its label.
tx = optax.multi_transform(d, nnx.State(gr))

optimizer = nnx.Optimizer(model, tx)
metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
    accuracy=nnx.metrics.Accuracy(),
)

In [None]:
def loss_fn(model: MLP, batch):
  """Mean softmax cross-entropy of ``model`` on one batch.

  Args:
    model: network called on ``batch[0]`` to produce logits.
    batch: indexable pair; ``batch[0]`` holds the inputs and ``batch[1]`` the
      integer class labels.

  Returns:
    ``(loss, logits)``: the batch-mean loss and the raw logits. The logits are
    returned as auxiliary output so callers can update accuracy metrics.
  """
  inputs, labels = batch[0], batch[1]
  logits = model(inputs)
  per_example = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=labels
  )
  return per_example.mean(), logits

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step.

  Computes the loss and its gradients on ``batch``, accumulates the batch's
  loss/accuracy into ``metrics`` and applies one optimizer update. All effects
  are in-place on ``model``, ``optimizer`` and ``metrics``; nothing is returned.

  Args:
    model: network being trained.
    optimizer: optimizer wrapping ``model``; updated with the new gradients.
    metrics: running metrics, updated with this batch's loss and logits.
    batch: indexable pair of inputs (``batch[0]``) and integer labels (``batch[1]``).
  """
  # has_aux=True because loss_fn returns (loss, logits); only the loss is differentiated.
  grad_fn = nnx.value_and_grad(loss_fn,has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch[1])  # In-place updates.
  optimizer.update(grads)  # In-place updates.
  # Debugging aid - uncomment to compare a few predictions with their labels:
  #jax.debug.print("Predictions: {}", jnp.argmax(logits, axis=-1)[:5].astype(int))
  #jax.debug.print("Actual Labels: {}", batch[1][:5].astype(int))

@nnx.jit
def eval_step(model: MLP, metrics: nnx.MultiMetric, batch):
  """Accumulate the loss and accuracy of ``model`` on one batch into ``metrics``.

  No gradients are taken and no parameters are updated by this function itself;
  ``metrics`` is modified in place and nothing is returned.

  Args:
    model: network to evaluate.
    metrics: running metrics to update.
    batch: indexable pair of inputs (``batch[0]``) and integer labels (``batch[1]``).
  """
  batch_loss, batch_logits = loss_fn(model, batch)
  metrics.update(labels=batch[1], logits=batch_logits, loss=batch_loss)  # In-place updates.

In [None]:
#Train the model + evaluation with the test data
# One value is appended to each list at every evaluation point.
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Training mode: dropout active, batch-norm uses (and updates from) batch statistics.
  # NOTE(review): this resets the mode flags to the Flax training defaults; confirm
  # the layers were not deliberately constructed with non-default flags.
  model.train()

  # Run the optimization for one step and make a stateful update to the following:
  # - The train state's model parameters
  # - The optimizer state
  # - The training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  # Evaluate every `eval_every` steps and once more on the final training step.
  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
    # Log the training metrics accumulated since the previous evaluation.
    for metric, value in metrics.compute().items():  # Compute the metrics.
      metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
    metrics.reset()  # Reset the metrics for the test set.

    # Evaluation mode: without this the test metrics would be computed with
    # dropout still sampling and batch-norm normalising with (and updating its
    # running statistics from) the test batches.
    model.eval()
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # Reset the metrics for the next training interval.

    print(
      f"[train] step: {step}       , "
      f"loss: {metrics_history['train_loss'][-1]}      , "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}     "
    )
    print(
      f"[test] step: {step}        , "
      f"loss: {metrics_history['test_loss'][-1]}       , "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}      "
    )

In [None]:
#Save the training results into csv
import pandas as pd

if leave_data:
    # Rebuild the steps at which the training loop actually evaluated: every
    # multiple of `eval_every` plus the final step `train_steps - 1`. Only as many
    # steps as there are recorded values are kept, so the columns always have
    # equal length (a guessed arange can be one entry off and mislabels the last
    # evaluation as `train_steps`).
    n_recorded = len(metrics_history['train_loss'])
    recorded_steps = [
        s for s in range(1, n_recorded*eval_every + 1)
        if s % eval_every == 0 or s == train_steps - 1
    ][:n_recorded]

    data=pd.DataFrame({"step":recorded_steps,"train_loss":metrics_history['train_loss'],
                       "test_loss":metrics_history['test_loss'],"train_accuracy":metrics_history['train_accuracy'],
                       "test_accuracy":metrics_history['test_accuracy']})
    # The file name encodes the run configuration; it is written to the current working directory.
    data.to_csv(method_name+"_H"+str(encoded_size)+"_nr"+str(hidden_neuron)+"_D"+str(hidden_size)+"_"+dataset_name+"_lr"+str(learning_rate)+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+"_rand"+str(rand)+".csv")

In [None]:
#Plot the loss
import matplotlib.pyplot as plt

# X axis: the steps at which the training loop actually evaluated (every multiple
# of `eval_every` plus the final step `train_steps - 1`), trimmed to the number of
# recorded values so x and y always have the same length.
n_loss_points = len(metrics_history['train_loss'])
loss_steps = [
    s for s in range(1, n_loss_points*eval_every + 1)
    if s % eval_every == 0 or s == train_steps - 1
][:n_loss_points]

plt.plot(loss_steps,metrics_history['train_loss'],label="train loss")
plt.plot(loss_steps,metrics_history['test_loss'],label="test loss")
# Both curves are drawn, so the title and y label describe the loss in general.
plt.title("Loss of "+dataset_name+" dataset with "+method_name+
              ", \nhidden dimension="+str(hidden_size)+", number of neuron="+str(hidden_neuron))
plt.xlabel("Training step")
plt.ylabel("Loss (cross entropy)")
plt.legend()
if leave_data:
    # Save before plt.show(), which may clear the current figure.
    plt.savefig("loss_"+method_name+"_"+str(encoded_size)+"_"+str(hidden_neuron)+"_"+dataset_name+"_lr"+str(learning_rate)+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+"_rand"+str(rand)+".jpg")
plt.show()

In [None]:
#Plot the accuracy
# X axis: the steps at which the training loop actually evaluated (every multiple
# of `eval_every` plus the final step `train_steps - 1`), trimmed to the number of
# recorded values so x and y always have the same length.
n_accuracy_points = len(metrics_history['train_accuracy'])
accuracy_steps = [
    s for s in range(1, n_accuracy_points*eval_every + 1)
    if s % eval_every == 0 or s == train_steps - 1
][:n_accuracy_points]

plt.plot(accuracy_steps,metrics_history['train_accuracy'],label="train")
plt.plot(accuracy_steps,metrics_history['test_accuracy'],label="test")
plt.title("Accuracy of "+dataset_name+" dataset with "+method_name+", \nhidden dimension="+
          str(hidden_size)+", number of neuron="+str(hidden_neuron))
plt.xlabel("Training step")  # same x label as the loss figure
plt.ylabel("Accuracy")
plt.legend()
if leave_data:
    # Save before plt.show(), which may clear the current figure.
    plt.savefig("accuracy_"+method_name+"_"+str(encoded_size)+"_"+str(hidden_neuron)+"_"+dataset_name+"_lr"+str(learning_rate)+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+"_rand"+str(rand)+".jpg")
plt.show()