#### **Import Libraries**

In [35]:
from typing import Tuple, Sequence
import jax 
import jax.numpy as jnp
from jax import Array
import flax.linen as nn 
from functools import partial 
from einops import rearrange, repeat
import optax 
from diffrax import diffeqsolve, ODETerm, Dopri5
from tqdm import tqdm 
from rfp import MLP, Model, ModelParams
from rfp.losses import Supervised_Loss, mse
from rfp.train import Trainer

#### **Path**

In [3]:
# Presumably the output directory where figures for the paper are saved.
# NOTE(review): relative to the kernel's working directory, so it only resolves
# when the notebook is run from its own folder — confirm.
figure_folders: str = './../../../rfp_paper/figures/'

#### **Set Up Plotting**

In [4]:
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
plt.style.use('seaborn-v0_8-dark-palette')

from matplotlib import font_manager 
locations = './../../styles/Newsreader'
font_files = font_manager.findSystemFonts(fontpaths=locations)
print(locations)
print(font_files[0])
for f in font_files: 
    font_manager.fontManager.addfont(f)
plt.rcParams["font.family"] = "Newsreader"

./../../styles/Newsreader
/home/ubuntu/rfp/styles/Newsreader/static/Newsreader_9pt/Newsreader_9pt-LightItalic.ttf


#### **Data, Model, and Training Hyperparameters**

In [52]:
# Problem sizes: `d` is the length of the input vector `x`, and `c` is the
# number of cluster values `cs` drawn below.
# NOTE(review): `n` is not used in this section — presumably a per-cluster
# sample count; confirm.
n, d, c = 20, 10, 20
init_key = jax.random.PRNGKey(0)  # key used to draw the cluster values `cs`
# `nodes` is the hidden width of SimpleMLP. `lr` and `epochs` are presumably
# consumed by training code later in the notebook.
nodes = 32
lr, epochs = 1e-3, 100

#### **Map a Value in [0, 1] to a PRNG Key**

In [69]:
def value_to_key(value: float) -> Array:
    """Deterministically map a scalar in [0, 1] to a JAX PRNG key.

    The value is scaled onto the non-negative int32 range, and the result is
    used as the seed of `jax.random.PRNGKey`. Equal values therefore always
    yield equal keys. The scaling is done in float32, so the same arithmetic
    runs for concrete inputs and for traced inputs (e.g. under `jax.vmap` or
    `jax.jit`).

    * Concrete value (Python number or concrete 0-d array): the range is
      validated, and a ValueError is raised outside [0, 1] (or for NaN).
    * Traced value: Python control flow cannot inspect it (doing so raises a
      concretization error), so the value is clipped to [0, 1] instead.

    Args:
        value: Scalar in [0, 1].

    Returns:
        The key produced by `jax.random.PRNGKey` for the scaled seed.

    Raises:
        ValueError: If a concrete `value` lies outside [0, 1].
    """
    try:
        concrete_value = float(value)
    except jax.errors.ConcretizationTypeError:
        # Traced inside vmap/jit: the value cannot be checked in Python.
        concrete_value = None

    if concrete_value is not None and not (0.0 <= concrete_value <= 1.0):
        raise ValueError("Value must be between 0.0 and 1.0")

    max_int = jnp.iinfo(jnp.int32).max
    unit = jnp.clip(jnp.asarray(value, dtype=jnp.float32), 0.0, 1.0)
    scaled = unit * max_int
    # float32 cannot represent 2**31 - 1 exactly: products near 1.0 round up to
    # 2**31, which is out of int32 range. Saturate those to max_int explicitly
    # instead of relying on an overflowing cast.
    scaled_value = jnp.where(scaled >= max_int, max_int, scaled.astype(jnp.int32))

    return jax.random.PRNGKey(scaled_value)

#### **Model**

In [54]:
class SimpleMLP(nn.Module):
  """Stack of Dense layers, one per entry of `features`.

  A ReLU follows every Dense layer except the last, so the final layer's
  output is returned without an activation. Layers are named `layers_0`,
  `layers_1`, ... in order.

  Attributes:
    features: Output width of each Dense layer, in order.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    """Apply the layers to `inputs`; returns `inputs` unchanged if `features` is empty."""
    n_layers = len(self.features)
    h = inputs
    for idx in range(n_layers):
      h = nn.Dense(self.features[idx], name=f'layers_{idx}')(h)
      # Hidden layers get the nonlinearity; the output layer stays linear.
      if idx < n_layers - 1:
        h = nn.relu(h)
    return h

#### **Generate Conditional Expectation Functions**

In [64]:
model = SimpleMLP([nodes, nodes, 1])
# Example input for `model.init`; presumably only its shape (d,) matters for
# initialisation.
# NOTE(review): PRNGKey(0) is the same key as `init_key`, so `x` and `cs` are
# drawn from the same random stream — split the key if they must be independent.
x = jax.random.normal(jax.random.PRNGKey(0), (d,))
cs = jax.random.normal(init_key, shape=(c,))
# `value_to_key` only accepts values in [0, 1], but `cs` is standard normal.
# The logistic sigmoid maps each cluster value monotonically into (0, 1), so
# distinct values keep distinct seeds (up to float32 precision). `cs` itself is
# left unchanged. Calling `value_to_key` eagerly on Python floats keeps its
# range check active for every cluster.
cluster_keys = jnp.stack([value_to_key(float(u)) for u in jax.nn.sigmoid(cs)])
cluster_params = jax.vmap(model.init, in_axes=(0, None))(cluster_keys, x)