#### **Import Libraries**

In [None]:
from typing import Tuple, Sequence
import jax 
import jax.numpy as jnp
from jax import Array
import flax.linen as nn 
from functools import partial 
from tqdm import tqdm 
from einops import rearrange, repeat
import optax 
from diffrax import diffeqsolve, ODETerm, Dopri5
from tqdm import tqdm 
from rfp import MLP, Model, ModelParams
from rfp.losses import Supervised_Loss, mse, Cluster_Loss
from rfp.train import Trainer

#### **Path**

In [None]:
# Directory (relative to this notebook) that the plotting cells prepend to the
# file names passed to `fig.savefig`.
figure_folders = './../../../rfp_paper/figures/'

#### **Set Up Plotting**

In [None]:
# Matplotlib configuration for the whole notebook: global rcParams, inline SVG
# rendering, a seaborn style sheet, and the custom "Newsreader" font.
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False
# IPython magics: draw figures inline and render them as SVG.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
# NOTE: applied after the rcParams assignments above, so any key that the
# style sheet defines takes precedence over them.
plt.style.use('seaborn-v0_8-dark-palette')

# Register every font file found under `locations` with matplotlib's font
# manager, then make "Newsreader" the default font family.
from matplotlib import font_manager 
locations = './../../styles/Newsreader'
font_files = font_manager.findSystemFonts(fontpaths=locations)
print(locations)
# NOTE(review): raises IndexError when no font file is found under `locations`.
print(font_files[0])
for f in font_files: 
    font_manager.fontManager.addfont(f)
plt.rcParams["font.family"] = "Newsreader"

#### **Data Hyperparameters**

In [None]:
# Global settings read by the data-generation, training and plotting cells.
n = 200             # observations drawn per cluster (`shape=(n,)` in `sample`)
d = 1               # feature dimension (length of the mean vector / size of the identity covariance)
c = 100             # number of clusters generated per simulation
k = 10              # number of clusters selected as the training group in `get_data_params`
init_key = jax.random.PRNGKey(0)  # starting PRNG key of both simulation loops
nodes = 128         # width of each hidden layer in the MLPs
lr = 1e-3           # learning rate passed to every `optax.sgd` optimizer
epochs = 2000       # passed to the outer `Trainer`s (presumably the number of training epochs -- confirm in rfp.train)
inner_epochs = 2    # passed to the inner `Trainer` used inside `Cluster_Loss`
simulations = 100   # number of iterations of each simulation loop
reg_value = 0.9     # second argument of the training `Cluster_Loss` (presumably a regularization weight -- confirm in rfp.losses)

#### **Value to Key**

In [None]:
def value_to_key(value: float) -> jax.random.PRNGKey:
    """Derive a PRNG key from a scalar by using it as a scaled integer seed.

    `value` is multiplied by the largest int32, cast to int32, and the result
    is passed to `jax.random.PRNGKey` as the seed.

    Args:
        value: Scalar intended to lie in [0, 1]. The range is NOT checked or
            enforced here; callers are responsible for it.

    Returns:
        The key produced by `jax.random.PRNGKey` for the scaled seed.
    """
    int32_max = jnp.iinfo(jnp.int32).max
    seed = jnp.array(value * int32_max, dtype=jnp.int32)
    return jax.random.PRNGKey(seed)

#### **Model**

In [None]:
class SimpleMLP(nn.Module):
  """Fully connected network built from `nn.Dense` layers.

  One dense layer named `layers_{i}` is created per entry of `features`;
  `nn.hard_swish` is applied after every layer except the last one.
  """
  features: Sequence[int]

  @nn.compact
  def __call__(self, inputs):
    last = len(self.features) - 1
    hidden = inputs
    for idx, width in enumerate(self.features):
      hidden = nn.Dense(width, name=f'layers_{idx}')(hidden)
      # The output layer stays linear.
      if idx != last:
        hidden = nn.hard_swish(hidden)
    return hidden

#### **Generate Conditional Expectation Functions**

In [None]:
# Flax MLP with two hidden layers of `nodes` units and a single output unit.
# NOTE(review): no live code in the visible cells uses `base_model` -- confirm it is still needed.
base_model = SimpleMLP( [nodes, nodes, 1])

def gen_outcome(params_key, features):
    """Evaluate a randomly drawn quadratic on `features`.

    Three coefficients are sampled from a standard normal using `params_key`;
    the polynomial `quad * x**2 + lin * x + const` is then mapped over the
    leading axis of `features`.

    Args:
        params_key: PRNG key that determines the three coefficients.
        features: Array whose leading axis is mapped over with `jax.vmap`.

    Returns:
        Array with the same shape as `features` holding the polynomial values.
    """
    quad, lin, const = jax.random.normal(params_key, shape=(3,))

    def polynomial(x):
        return quad * x**2 + lin * x + const

    return jax.vmap(polynomial)(features)

#### **Sample Data**

In [None]:
def sample(key, n, d):
    """Draw the features and outcomes of one cluster.

    `key` is split into three independent keys: one for the cluster's feature
    mean, one for the scalar cluster feature that seeds the outcome function
    (via `value_to_key`), and one for the feature draws themselves. Using a
    separate key for each draw avoids feeding the same key to two different
    samplers, which JAX's PRNG design requires callers to avoid.

    Args:
        key: PRNG key for this cluster.
        n: Number of observations to draw.
        d: Feature dimension.

    Returns:
        Tuple `(features, outcomes)`: `features` has shape `(n, d)` and
        `outcomes` is `gen_outcome` evaluated on those features.
    """
    mean_key, cluster_key, feature_key = jax.random.split(key, 3)
    means = jax.random.normal(mean_key, shape=(d,))
    # The uniform draw identifies the cluster; it is turned into the key that
    # fixes the cluster's outcome function.
    cluster_feature = jax.random.uniform(cluster_key)
    params_key = value_to_key(cluster_feature)
    features = jax.random.multivariate_normal(
        feature_key, mean=means, cov=jnp.eye(d), shape=(n,)
    )
    outcomes = gen_outcome(params_key, features)
    return features, outcomes

In [None]:
# Network, loss and trainer objects shared by the simulation functions.
mlp = MLP([nodes, nodes], jax.nn.relu)
# Identity output activation; the print only echoes the function object.
final_activation_fn = lambda x: x; print(f"Final Activation Function: {final_activation_fn}")
model = Model(mlp, final_activation_fn)
supervised_loss = Supervised_Loss(mse, model.fwd_pass) # NOTE(review): the original author was unsure this construction is correct -- confirm against rfp.losses
# Baseline trainer: SGD with momentum on the supervised loss, configured with `epochs`.
standard_yuri = Trainer(supervised_loss, optax.sgd(learning_rate=lr, momentum=0.9), epochs)


# Inner trainer handed to Cluster_Loss: plain SGD (no momentum), configured with `inner_epochs`.
inner_yuri = Trainer(supervised_loss, optax.sgd(learning_rate=lr), inner_epochs)
# The training loss receives `reg_value`; the validation loss receives 0.0 in the same position
# (presumably switching the regularization off -- confirm against rfp.losses.Cluster_Loss).
cluster_loss_train = Cluster_Loss(inner_yuri, reg_value)
cluster_loss_val = Cluster_Loss(inner_yuri, 0.0)
# Outer trainer for the cluster approach; validation is scored with `cluster_loss_val`.
cluster_yuri = Trainer(cluster_loss_train, optax.sgd(learning_rate=lr, momentum=0.9), epochs, val_loss_fn=cluster_loss_val)

In [None]:
def get_data_params(key, n, c, d):
    """Generate one simulated dataset, its train/validation masks and initial parameters.

    `c` clusters of `n` observations are drawn with `sample`. Inputs are
    standardized along axis 0 of the flattened (all clusters pooled) array.
    The module-level `k` clusters are chosen without replacement as the
    training group; all remaining clusters form the validation group.

    Args:
        key: PRNG key; split into keys for the data, the group choice and the
            parameter initialization.
        n: Observations per cluster.
        c: Number of clusters.
        d: Feature dimension.

    Returns:
        Dict with flattened and per-cluster ("batch_") versions of the
        standardized inputs, the outcomes and the boolean masks, plus the
        per-cluster `training_group` indicator and the initial `params`.
    """
    data_key, group_key, param_key = jax.random.split(key, 3)

    # One dataset per cluster, each generated from its own key.
    sample_cluster = partial(sample, n=n, d=d)
    batch_inputs, batch_outcomes = jax.vmap(sample_cluster)(jax.random.split(data_key, c))

    # Pool the clusters, standardize, then restore the per-cluster layout.
    flat_inputs = rearrange(batch_inputs, 'a b c -> (a b) c')
    flat_outcomes = rearrange(batch_outcomes, 'a b c -> (a b) c')
    flat_inputs_std = jax.nn.standardize(flat_inputs, axis=0)
    batch_inputs_std = rearrange(flat_inputs_std, '(a b) c -> a b c', a=c)

    # Boolean indicator per cluster, expanded to one entry per observation.
    cluster_ids = jnp.arange(c)
    chosen = jax.random.choice(group_key, cluster_ids, replace=False, shape=(k,))
    is_training_cluster = jnp.isin(cluster_ids, chosen)
    row_is_training = jnp.repeat(is_training_cluster, n)
    row_is_validation = ~row_is_training

    def per_cluster(row_mask):
        return rearrange(row_mask, '(a b) -> a b', a=c)

    return {
        'inputs_standarized': flat_inputs_std,
        'batch_inputs_standarized': batch_inputs_std,
        'outcomes': flat_outcomes,
        'batch_outcomes': batch_outcomes,
        'training_group': is_training_cluster,
        'training_mask': row_is_training,
        'batch_training_mask': per_cluster(row_is_training),
        'validation_mask': row_is_validation,
        'batch_validation_mask': per_cluster(row_is_validation),
        'params': ModelParams.init_fn(param_key, mlp, d),
    }


In [None]:
def cluster_simulate(key, n, c, d):
    """Run one simulation with the cluster trainer (`cluster_yuri`).

    The dataset comes from `get_data_params(key, n, c, d)`; training uses the
    per-cluster ("batch_") inputs, outcomes and masks.

    Returns:
        `(training_loss, validation_loss, extras)` where `extras` is the tuple
        `(batch_inputs, batch_outcomes, batch_predictions, training_group)`;
        predictions are `model.fwd_pass` mapped over the cluster axis with the
        trained parameters.
    """
    setup = get_data_params(key, n, c, d)
    batch_x = setup['batch_inputs_standarized']
    batch_y = setup['batch_outcomes']

    # The optimizer state returned by the trainer is not needed here.
    fitted, _, training_loss, validation_loss = cluster_yuri.train_with_val(
        setup['params'],
        batch_x,
        batch_y,
        mask=jnp.ones_like(batch_y),
        train_idx=setup['batch_training_mask'],
        val_idx=setup['batch_validation_mask'],
    )

    predict_per_cluster = jax.vmap(model.fwd_pass, in_axes=(None, 0))
    batch_yhat = predict_per_cluster(fitted, batch_x)
    extras = (batch_x, batch_y, batch_yhat, setup['training_group'])
    return training_loss, validation_loss, extras

def standard_simulate(key, n, c, d):
    """Run one simulation with the standard trainer (`standard_yuri`).

    The dataset comes from `get_data_params(key, n, c, d)`; training uses the
    flattened (all clusters pooled) inputs, outcomes and masks.

    Returns:
        `(training_loss, validation_loss, extras)` where `extras` is the tuple
        `(batch_inputs, batch_outcomes, batch_predictions, training_group)`;
        predictions are `model.fwd_pass` mapped over the cluster axis with the
        trained parameters.
    """
    setup = get_data_params(key, n, c, d)
    flat_y = setup['outcomes']

    # The optimizer state returned by the trainer is not needed here.
    fitted, _, training_loss, validation_loss = standard_yuri.train_with_val(
        setup['params'],
        setup['inputs_standarized'],
        flat_y,
        mask=jnp.ones_like(flat_y),
        train_idx=setup['training_mask'],
        val_idx=setup['validation_mask'],
    )

    batch_x = setup['batch_inputs_standarized']
    predict_per_cluster = jax.vmap(model.fwd_pass, in_axes=(None, 0))
    batch_yhat = predict_per_cluster(fitted, batch_x)
    extras = (batch_x, setup['batch_outcomes'], batch_yhat, setup['training_group'])
    return training_loss, validation_loss, extras

In [109]:

# Visual sanity check, only meaningful for one-dimensional features: for two
# seeds, fit both trainers on the same data (same key) and plot the fitted
# curves against the true cluster curves.
if d == 1:
    for kk in range(2):
        ts, vs, (xss, yss, yhatss, tg) = standard_simulate(jax.random.key(kk), n, c, d)
        tsc, vsc, (xss, yss, yhatsc, tg) = cluster_simulate(jax.random.key(kk), n, c, d)
        # Both values use the `.3f` format spec (it must sit inside the braces).
        print(f"Minimum Validation Loss: Standard --> {jnp.min(vs):.3f} | Cluster -->  {jnp.min(vsc):.3f}")

        # Pooled inputs and predictions, sorted by x so the lines are drawn left to right.
        xs = rearrange(xss, 'a b c -> (a b) c')[:,0].reshape(-1,1)
        order = jnp.argsort(xs.reshape(-1,))
        yhats = rearrange(yhatss, 'a b c -> (a b) c')
        yhatc = rearrange(yhatsc, 'a b c -> (a b) c')

        # First figure: training clusters in black; second figure: validation clusters.
        for show_training in (True, False):
            for z, (i, j) in enumerate(zip(xss, yss)):
                if bool(tg[z]) == show_training:
                    idx = jnp.argsort(i[:,0])
                    plt.plot(i[:,0][idx], j[idx], color='black', linewidth=2.)
            print(xs.shape)
            plt.plot(xs[order], yhats[order], linestyle='--', label='Standard', linewidth=2.5)
            plt.plot(xs[order], yhatc[order], linestyle='--', label='Cluster', linewidth=2.5)
            plt.legend()
            plt.show()

        # Validation-loss traces of both trainers.
        plt.plot(vs, label='Standard')
        plt.plot(vsc, label='Cluster')
        plt.legend()
        plt.show()

In [None]:
# Per-method containers for the loss traces collected by the simulation loops.
results = {
    method: {'training_loss': [], 'validation_loss': []}
    for method in ('standard', 'cluster')
}

In [None]:
# Repeat the standard experiment `simulations` times, starting from `init_key`
# and moving to the first half of a key split after every run.
key = init_key
for _ in tqdm(range(simulations)):
    train_loss, val_loss, _ = standard_simulate(key, n, c, d)
    results['standard']['training_loss'].append(train_loss)
    results['standard']['validation_loss'].append(val_loss)
    key = jax.random.split(key)[0]

In [None]:
# Repeat the cluster experiment `simulations` times, starting from `init_key`
# and moving to the first half of a key split after every run.
key = init_key
for _ in tqdm(range(simulations)):
    train_loss, val_loss, _ = cluster_simulate(key, n, c, d)
    results['cluster']['training_loss'].append(train_loss)
    results['cluster']['validation_loss'].append(val_loss)
    key = jax.random.split(key)[0]

In [None]:
# Display the last entry of every standard run's validation-loss trace.
jnp.array([i[-1] for i in results['standard']['validation_loss']])

In [None]:
# Display the mean, across cluster runs, of the last entry of each validation-loss trace.
jnp.mean(jnp.array([i[-1] for i in results['cluster']['validation_loss']]))

In [None]:
# Line plot of the last validation-loss value of every simulation run,
# Standard vs. RFP (cluster), saved under `figure_folders`.
fig = plt.figure(dpi=300, tight_layout=True, figsize=(4, 4.5))
ax = plt.axes(facecolor=(.95, .96, .97))
ax.xaxis.set_tick_params(length=0, labeltop=False, labelbottom=True)

# Hide every spine except the bottom one. The loop variable is deliberately
# not called `key`, so the notebook-level PRNG `key` is not overwritten.
for spine in ('left', 'right', 'top'):
    ax.spines[spine].set_visible(False)

subtitle = 'Validation Loss'
ax.text(0., 1.02, s=subtitle, transform=ax.transAxes, size=14)
ax.yaxis.set_tick_params(length=0)
ax.yaxis.grid(True, color='white', linewidth=2)
ax.set_axisbelow(True)
plt.plot([i[-1] for i in results['standard']['validation_loss']], label='Standard')
plt.plot([i[-1] for i in results['cluster']['validation_loss']], label='RFP')
plt.xlabel('Random Seeds', size=14)
plt.legend()
fig.savefig(figure_folders + f'naive_exp1_{n}_{c}_{d}_{epochs}_{inner_epochs}.png')
plt.show()

In [None]:
# Plot the validation-loss traces of the first three standard runs. The traces
# are stacked into one array once instead of on every iteration.
standard_val_curves = jnp.array(results['standard']['validation_loss'])
for i in range(3):
    plt.plot(standard_val_curves[i])

In [None]:
# For the first standard run, compare the smallest validation loss along the
# trace with the value at its last entry.
print(jnp.min(jnp.array(results['standard']['validation_loss'])[0]))
print(jnp.array(results['standard']['validation_loss'])[0][-1])

In [None]:
# Smallest validation loss along the last axis of the stacked traces, i.e. one
# value per simulation run (assumes each trace is one-dimensional -- TODO confirm).
min_standard = jnp.min( jnp.array(results['standard']['validation_loss']), axis=-1)
min_rfp = jnp.min( jnp.array(results['cluster']['validation_loss']), axis=-1)

In [None]:
# Line plot of the per-run minimum validation loss (`min_standard`, `min_rfp`),
# Standard vs. RFP (cluster), saved under `figure_folders`.
fig = plt.figure(dpi=300, tight_layout=True, figsize=(4, 4.5))
ax = plt.axes(facecolor=(.95, .96, .97))
ax.xaxis.set_tick_params(length=0, labeltop=False, labelbottom=True)

# Hide every spine except the bottom one. The loop variable is deliberately
# not called `key`, so the notebook-level PRNG `key` is not overwritten.
for spine in ('left', 'right', 'top'):
    ax.spines[spine].set_visible(False)

subtitle = 'Validation Loss'
ax.text(0., 1.02, s=subtitle, transform=ax.transAxes, size=14)
ax.yaxis.set_tick_params(length=0)
ax.yaxis.grid(True, color='white', linewidth=2)
ax.set_axisbelow(True)
plt.plot(min_standard, label='Standard')
plt.plot(min_rfp, label='RFP')
plt.xlabel('Random Seeds', size=14)
plt.legend()
fig.savefig(figure_folders + f'exp1_{n}_{c}_{d}_{epochs}_{inner_epochs}.png')
plt.show()