In [None]:
import tensorflow_datasets as tfds  # TFDS to download MNIST.
import tensorflow as tf  # TensorFlow / `tf.data` operations.
from flax import nnx  # The Flax NNX API.
from functools import partial
import jax.numpy as jnp  # JAX NumPy
import jax
import optax
from linearRNN import forward_h
from linearRNN import forward
from linearRNN import init_lru_parameters
from linearRNN import binary_operator_diag
from linearRNN import LRU
import numpy as np
from flax import linen as nn


In [None]:
pool=0 #pooling layer after MLP is taking the average over the numbers
transformation=0 #transformation of the data from decimals between 0 and 255 to binary 8 bit numbers
leave_data=1 #download csv data of the results
hidden_neuron=128 #no details in the 2023 paper => 2024 paper fixed to 512
encoded_size=256
hidden_size=128
learning_rate = 0.004
momentum = 0.9
train_steps=3000
eval_every = 50
batch_size=50
r_min = 0
r_max = 1
max_phase = 6.28
depth=1
method_name="LRUMLP"
dataset_name="MNIST"

In [None]:
def lr(mode:str,d,m,l,k):
    lr=0
    sigma=0
    if mode=="input":
        lr=m/(jnp.power(l,3/2)*d)
        sigma=1/jnp.sqrt(d)
    if mode=="hidden":
        lr=1/jnp.power(l,3/2)
        sigma=2/jnp.sqrt((m+d)/2)
    if mode=="output":
        lr=k/(jnp.power(l,3/2)*m)
        sigma=jnp.sqrt(k)/m
    else:
        raise ValueError
    return lr, sigma

In [None]:
def vec_bin_array(arr, m):#https://stackoverflow.com/questions/22227595/convert-integer-to-binary-array-with-suitable-padding
    """
    Arguments: 
    arr: Numpy array of positive integers
    m: Number of bits of each integer to retain

    Returns a copy of arr with every element replaced with a bit vector.
    Bits encoded as int8's.
    """
    to_str_func = np.vectorize(lambda x: np.binary_repr(x).zfill(m))
    strs = to_str_func(arr)
    ret = np.zeros(list(arr.shape) + [m], dtype=np.int8)
    for bit_ix in range(0, m):
        fetch_bit_func = np.vectorize(lambda x: x[bit_ix] == '1')
        ret[...,bit_ix] = fetch_bit_func(strs).astype("int8")

    return ret 

In [None]:
#Import data

if dataset_name=="MNIST":
    dataset=tf.keras.datasets.mnist.load_data()
    train=dataset[0]
    test=dataset[1]

    train_x_seq=train[0].shape[0]
    train_x_len=int(jnp.prod(jnp.array(train[0].shape[1:])))
    train_x_size=1
    test_x_seq=test[0].shape[0]
    test_x_len=int(jnp.prod(jnp.array(test[0].shape[1:])))
    test_x_size=1
    if transformation: #transform the information of the pixel to 8-bit binary numbers
        train_x=train[0].reshape((train_x_seq,train_x_len,train_x_size))
        train_x=vec_bin_array(train_x,8)
        train_y=train[1].reshape(train_x_seq)
        train_y_class=len(jnp.unique(train_y))
        test_x=test[0].reshape((test_x_seq,test_x_len,test_x_size))
        test_x=vec_bin_array(test_x,8)
        test_y=test[1].reshape(test_x_seq)
    else:
        train_x=train[0].reshape((train_x_seq,train_x_len,train_x_size))/255
        train_y=train[1].reshape(train_x_seq)
        train_y_class=len(jnp.unique(train_y))
        test_x=test[0].reshape((test_x_seq,test_x_len,test_x_size))/255
        test_y=test[1].reshape(test_x_seq)

if dataset_name=="CIFAR10":
    dataset=tf.keras.datasets.cifar10.load_data()
    train=dataset[0]
    test=dataset[1]

    train_x_seq=train[0].shape[0]
    train_x_len=int(jnp.prod(jnp.array(train[0].shape[1:-1])))
    train_x_size=int(jnp.prod(jnp.array(train[0].shape[-1])))

    test_x_seq=test[0].shape[0]
    test_x_len=int(jnp.prod(jnp.array(test[0].shape[1:-1])))
    test_x_size=int(jnp.prod(jnp.array(train[0].shape[-1])))

    if transformation:#transform the information of the pixel to 3*8-bit binary numbers
        train_x=train[0].reshape((train_x_seq,train_x_len,train_x_size))
        train_x=vec_bin_array(train_x,8)
        train_y=train[1].reshape(train_x_seq)
        train_y_class=len(jnp.unique(train_y))
        test_x=test[0].reshape((test_x_seq,test_x_len,test_x_size))
        test_x=vec_bin_array(test_x,8)
        test_y=test[1].reshape(test_x_seq)
    else:
        train_x=train[0].reshape((train_x_seq,train_x_len,train_x_size))/255
        train_y=train[1].reshape(train_x_seq)
        train_y_class=len(jnp.unique(train_y))
        test_x=test[0].reshape((test_x_seq,test_x_len,test_x_size))/255
        test_y=test[1].reshape(test_x_seq)


print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)


In [None]:
train_ds=tf.data.Dataset.from_tensor_slices((jnp.real(train_x),jnp.array(train_y,dtype=int)))
test_ds=tf.data.Dataset.from_tensor_slices((jnp.real(test_x),jnp.array(test_y,dtype=int)))

train_ds = train_ds.repeat().shuffle(100)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)
# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.
test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
from flax import nnx
import optax
from flax.nnx.nn.recurrent import LSTMCell,GRUCell
import copy
import random
class MLP(nnx.Module):
  #DON'T FORGET TO CHANGE THE MODEL NAME BEFORE RUNNING
  #According to the scheme of the paper (Figure 1), input_size=M, encoded_size=H,layer_dim=number of neurons in MLP, out_dim=number of classes
  def __init__(self, token_size, token_len,encoded_dim,hidden_dim, layer_dim, out_dim, rngs: nnx.Rngs):

    #linear encoder
    self.lin_encoder = nnx.Linear(in_features=token_size, out_features=encoded_dim,rngs=rngs)
    #self.lin_encoder=nnx.Param(jnp.array(np.random.rand(token_size,encoded_dim)))
    #LRU+MLP block
    self.rnn = LRU(in_features=encoded_dim, hidden_features=hidden_dim, r_min=r_min,r_max=r_max,max_phase=max_phase)
    self.linear1 = nnx.Linear(in_features=encoded_dim, out_features=layer_dim, rngs=rngs)
    self.linear2 = nnx.Linear(in_features=layer_dim//2,out_features=encoded_dim,rngs=rngs)
    self.batchnorm=nnx.BatchNorm(num_features=encoded_dim,rngs=rngs,use_running_average=True)
    #Linear layers
    if pool: #If pooling layer takes the average over the token sequence length 
      self.linear3=(lambda x: jnp.mean(x,axis=1))
    else: #learn the parameters of the linear transformation
      self.linear3= nnx.Linear(in_features=token_len,out_features=1,rngs=rngs)
    self.linear4= nnx.Linear(in_features=encoded_dim,out_features=out_dim,rngs=rngs)
    #self.weight = nnx.Param(jnp.array(np.random.rand(token_len,1)))
    #self.bias = nnx.Param(jnp.array(np.random.rand(encoded_dim,1)))
    #self.weight2 = nnx.Param(jnp.array(np.random.rand(out_dim,encoded_dim)))
    #self.bias2 = nnx.Param(jnp.array(np.random.rand(out_dim,1)))
    self.out_dim = out_dim
    self.token_len=token_len

    
  @nnx.vmap(in_axes=(None,0)) 
  def __call__(self, x):
    x = self.lin_encoder(x)
    #x=x@self.lin_encoder
    y = x.copy()
    #LRU+MLP block
    for i in range(depth):
      x = self.rnn(x)
      x = self.linear1(x)
      x = nnx.glu(x,axis=-1)
      x = self.linear2(x)
      x += y #Skip connection -> p.21 adding for each block
      x=self.batchnorm(x)#batch normalization

    #x = x.T@self.weight #+ self.bias #project from L*H to H*1
    #x = self.weight2@x #+ self.bias2#project from H*1 to out_dim
    x=self.linear3(x.T)
    x=self.linear4(x.T)
    return x.reshape(self.out_dim)


model = MLP(train_x_size,train_x_len,encoded_size,hidden_size, hidden_neuron, train_y_class, rngs=nnx.Rngs(0))  # eager initialization

nnx.display(model)

MLP(
  lin_encoder=Linear(
    kernel=Param(
      value=Array(shape=(1, 256), dtype=float32)
    ),
    bias=Param(
      value=Array(shape=(256,), dtype=float32)
    ),
    in_features=1,
    out_features=256,
    use_bias=True,
    dtype=None,
    param_dtype=<class 'jax.numpy.float32'>,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x0000024A433FAE80>,
    bias_init=<function zeros at 0x0000024A400263E0>,
    dot_general=<function dot_general at 0x0000024A3FB90FE0>
  ),
  rnn=LRU(
    in_features=256,
    hidden_features=128,
    nu_log=Param(
      value=Array(shape=(128,), dtype=float64)
    ),
    theta_log=Param(
      value=Array(shape=(128,), dtype=float64)
    ),
    B_re=Param(
      value=Array(shape=(128, 256), dtype=float64)
    ),
    B_im=Param(
      value=Array(shape=(128, 256), dtype=float64)
    ),
    C_re=Param(
      value=Array(shape=(256, 128), dtype=float64)
    ),
    C_im=Param(
      value=Array(shape=(256, 128), dtype=flo

In [71]:
#Test the model with the first batch
for step, batch in enumerate(train_ds.as_numpy_iterator()):
    batch1=batch
    break
a=model(batch[0])
print(a)

[[ 3.03168893e-01 -9.66499567e-01  1.49352655e-01 -8.09825361e-01
   4.07850713e-01 -9.96469975e-01 -4.30168621e-02  2.73884200e-02
  -8.04168224e-01 -6.10215425e-01]
 [ 1.59678176e-01 -4.98017341e-01  1.36498183e-01 -3.53908539e-01
   1.80679873e-01 -4.69490588e-01 -4.97057177e-02  1.37789287e-02
  -4.29928005e-01 -3.55306953e-01]
 [ 3.59912932e-01 -8.56598556e-01  2.32310012e-01 -7.57348418e-01
   3.99321884e-01 -8.78117383e-01  6.85079098e-02  1.50126338e-01
  -8.27721655e-01 -4.97802675e-01]
 [ 4.93077710e-02 -7.27952600e-01  1.02801889e-01 -5.36433995e-01
   2.78145850e-01 -6.34684324e-01  7.33526349e-02 -6.96805641e-02
  -5.64702094e-01 -3.73918504e-01]
 [ 3.80657285e-01 -5.62055886e-01  1.57801181e-01 -4.98182625e-01
   2.35917374e-01 -6.69640064e-01 -3.03118993e-02  1.45991340e-01
  -5.50191522e-01 -3.86757404e-01]
 [-1.12192608e-01 -8.37540090e-01  1.55021325e-01 -6.01021886e-01
   2.32720032e-01 -6.22119606e-01  1.71095565e-01 -2.57776558e-01
  -4.74768907e-01 -3.05242956e-01

In [None]:
import optax

#scheduler = optax.piecewise_constant_schedule(init_value=learning_rate, boundaries_and_scales={int(train_steps*0.1):0.1})
#optimizer = nnx.Optimizer(model, optax.adamw(scheduler, momentum,weight_decay=0.05))
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum,weight_decay=0.05))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)

nnx.display(optimizer)

In [None]:
import jax
def loss_fn(model: MLP, batch):
  logits = model(batch[0])
  loss = optax.softmax_cross_entropy_with_integer_labels(
    logits=logits, labels=batch[1]
  ).mean()
  #print(logits.shape)
  #print(batch[1].shape)
  return loss, logits

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
  """Train for a single step."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch[1])  # In-place updates.
  optimizer.update(grads)  # In-place updates.

  #Print the predicted labels and the actual labels of the first five images from the batch
  #predicted_labels = jnp.argmax(logits, axis=-1)
  #actual_labels = batch[1]
  #jax.debug.print("Predictions: {}",predicted_labels[:5].astype(int))
  #jax.debug.print("Actual Labels: {}",actual_labels[:5].astype(int))
  
@nnx.jit
def eval_step(model: MLP, metrics: nnx.MultiMetric, batch):
  loss, logits = loss_fn(model, batch)
  metrics.update(loss=loss, logits=logits, labels=batch[1])  # In-place updates.

In [None]:
#Train the model + evaluation with the test data
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
}

for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:
  # - The train state's model parameters
  # - The optimizer state
  # - The training loss and accuracy batch metrics
  train_step(model, optimizer, metrics, batch)

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1):  # One training epoch has passed.
    # Log the training metrics.
    for metric, value in metrics.compute().items():  # Compute the metrics.
      metrics_history[f'train_{metric}'].append(value)  # Record the metrics.
    metrics.reset()  # Reset the metrics for the test set.

    # Compute the metrics on the test set after each training epoch.
    for test_batch in test_ds.as_numpy_iterator():
      eval_step(model, metrics, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
      metrics_history[f'test_{metric}'].append(value)
    metrics.reset()  # Reset the metrics for the next training epoch.

    print(
      f"[train] step: {step}, "
      f"loss: {metrics_history['train_loss'][-1]}, "
      f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
    )
    print(
      f"[test] step: {step}, "
      f"loss: {metrics_history['test_loss'][-1]}, "
      f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
    )

In [None]:
#Save the training results into csv file
import pandas as pd

if leave_data:
    data=pd.DataFrame({"step":np.arange(eval_every,train_steps+eval_every,eval_every),"train_loss":metrics_history['train_loss'],
                       "test_loss":metrics_history['test_loss'],"train_accuracy":metrics_history['train_accuracy'],
                       "test_accuracy":metrics_history['test_accuracy']})

    data.to_csv(method_name+"_enc"+str(encoded_size)+"_nr"+str(hidden_neuron)+"_d"+str(hidden_size)+"_"+dataset_name+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+".csv")

In [None]:
#Plot the loss 
import matplotlib.pyplot as plt

plt.plot(np.arange(eval_every,train_steps+eval_every,eval_every),metrics_history['train_loss'],label="train loss")
plt.plot(np.arange(eval_every,train_steps+eval_every,eval_every),metrics_history['test_loss'],label="test loss")
plt.title("Train loss of "+dataset_name+" dataset with "+method_name+
              ", \nhidden dimension="+str(hidden_size)+", number of neuron="+str(hidden_neuron))
plt.xlabel("Training step")
plt.ylabel("Train loss (cross entropy)")
plt.legend()
if leave_data:
    plt.savefig("loss_"+method_name+"_"+str(encoded_size)+"_"+str(hidden_neuron)+"_"+dataset_name+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+".jpg")
plt.show()

In [None]:
#Plot the accuracy
plt.plot(np.arange(eval_every,train_steps+eval_every,eval_every),metrics_history['train_accuracy'],label="train")
plt.plot(np.arange(eval_every,train_steps+eval_every,eval_every),metrics_history['test_accuracy'],label="test")
plt.title("Accuracy of "+dataset_name+" dataset with "+method_name+", \nhidden dimension="+
          str(hidden_size)+", number of neuron="+str(hidden_neuron))
plt.ylabel("Accuracy")
plt.legend()
if leave_data:
    plt.savefig("accuracy_"+method_name+"_"+str(encoded_size)+"_"+str(hidden_neuron)+"_"+dataset_name+"_step"+str(train_steps)+"r_min_"+str(r_min)+"r_max"+str(r_max)+".jpg")
plt.show()