In [None]:
import optax
import jax
import jax.numpy as jnp

def map_nested_fn(fn):
  """Lift `fn(key, value)` into a function over arbitrarily nested dicts.

  Args:
    fn: Callable invoked as `fn(key, value)` for every entry whose value is
      not a `dict`.

  Returns:
    A function that takes a nested dict and returns a new dict with the same
    keys and nesting, where each non-dict value is replaced by the result of
    `fn(key, value)`.
  """
  def _walk(tree):
    mapped = {}
    for name, child in tree.items():
      if isinstance(child, dict):
        # Descend; keys of intermediate levels are never passed to `fn`.
        mapped[name] = _walk(child)
      else:
        mapped[name] = fn(name, child)
    return mapped
  return _walk

# Toy two-layer parameter tree; every leaf is named either 'w' or 'b'.
params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
          'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
gradients = jax.tree.map(jnp.ones_like, params)  # dummy gradients: all ones, same tree structure as params

# Label each leaf with its own key, so the label tree mirrors `params` with 'w'/'b' at the leaves.
label_fn = map_nested_fn(lambda k, _: k)
# Leaves labelled 'w' are handled by Adam, leaves labelled 'b' by SGD.
tx = optax.partition(
    {'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
state = tx.init(params)
# One optimiser step on the dummy gradients, then apply the resulting updates.
updates, new_state = tx.update(gradients, state, params)
new_params = optax.apply_updates(params, updates)

In [None]:
import numpy as np
# Read every array out of the warm-start checkpoint and dump it to a text file.
# `np.load` on an .npz returns a lazily-read archive object that keeps the file
# open, so it is used as a context manager and the arrays are copied into a
# plain dict; `params` stays usable as a mapping after the archive is closed.
with np.load('./results/repl_uci/mclmc_debug_20250305-140952/warmstart/params_0.npz') as npz_file:
    params = {k: npz_file[k] for k in npz_file.files}

with open("params_end.txt", "w") as f:
    for k, v in params.items():
        # One "name: array" entry per stored array, using NumPy's default str().
        f.write(f"{k}: {v}\n")

In [None]:
def delete_layer0(params):
    """Strip every 'layer0' entry from a nested dict, in place.

    The key is removed at each nesting level reachable through `dict` values.

    Args:
        params: Nested dict; it is modified directly.

    Returns:
        The same `params` object that was passed in.
    """
    if 'layer0' in params:
        params.pop('layer0')
    # Gathered after the removal, so a dropped 'layer0' subtree is not visited.
    nested = [child for child in params.values() if isinstance(child, dict)]
    for sub_tree in nested:
        delete_layer0(sub_tree)
    return params
# Two-layer toy tree used to exercise delete_layer0.
params = delete_layer0({
    'fcn': {
        'layer0': {'kernel': 1, 'bias': 2},
        'layer1': {'kernel': 3, 'bias': 4},
    },
})
print(params)

In [None]:
def partition_params(params):
    """Split the layers under ``params['fcn']`` into boundary and hidden groups.

    A layer counts as a boundary (input/output) layer when its key is
    ``'layer0'`` or ``'layer{N-1}'``, with ``N`` the number of entries in
    ``params['fcn']``. Every other entry is treated as a hidden layer.

    Args:
        params: Mapping with an ``'fcn'`` entry that maps layer names to values.

    Returns:
        A tuple ``(input_output, hidden)`` of dicts, each of the form
        ``{'fcn': {...}}``. Layer values are shared with ``params``, not copied,
        and keep the iteration order of ``params['fcn']``.

    Raises:
        KeyError: If ``params`` has no ``'fcn'`` entry.
    """
    layers = params['fcn']
    # Computed once instead of on every iteration. The last-layer name is
    # derived from the entry count, so it only matches when the keys are
    # numbered 'layer0' .. 'layer{N-1}'.
    boundary_keys = ('layer0', f'layer{len(layers) - 1}')

    input_output_layers = {}
    hidden_layers = {}
    for key, value in layers.items():
        if key in boundary_keys:
            input_output_layers[key] = value
        else:
            hidden_layers[key] = value

    return {'fcn': input_output_layers}, {'fcn': hidden_layers}
# Three-layer toy tree used to exercise partition_params.
params = {
    'fcn': {
        'layer0': {'kernel': 1, 'bias': 2},
        'layer1': {'kernel': 3, 'bias': 4},
        'layer2': {'kernel': 5, 'bias': 6},
    },
}
a, b = partition_params(params)
print(a, b)

In [None]:
import flax.linen as nn
import jax.numpy as jnp
import jax
from flax import traverse_util
import jax.lax as lax
from flax.training import train_state
import optax

def loss_fn(params, x, y):
    """Mean squared error of the module-level `model` on one batch.

    Args:
        params: Variables passed to `model.apply`.
        x: Model inputs.
        y: Targets, subtracted elementwise from the predictions.

    Returns:
        The mean of the squared residuals, as computed by `jnp.mean`.
    """
    # `model` is looked up in the notebook's global scope at call time.
    residual = model.apply(params, x) - y
    return jnp.mean(residual ** 2)

class FCN(nn.Module):
    """Small dense network: Dense(5) -> stop_gradient -> ReLU -> Dense(1)."""

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # stop_gradient sits directly on the first layer's output, so no
        # gradient flows back through that layer.
        hidden = lax.stop_gradient(nn.Dense(features=5)(x))
        hidden = nn.relu(hidden)
        return nn.Dense(features=1)(hidden)
    
# Build the model and initialise its variables from a dummy batch of ten 5-feature rows.
model = FCN()
x = jnp.ones((10, 5))
params = model.init(jax.random.PRNGKey(0), x)
print(params)
# Dummy regression targets, one value per row of `x`.
y = jnp.ones((10, 1))
# Gradient of loss_fn with respect to its first argument (`params`).
grads = jax.grad(loss_fn)(params, x, y)
optimizer = optax.adam(1e-3)
# Create a training state
class TrainState(train_state.TrainState):
    # Re-declares the `params` field with a plain `dict` annotation.
    params: dict

state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

# Update the parameters with a single optimiser step
state = state.apply_gradients(grads=grads)

# Show the parameters after the step.
print(state.params)

In [None]:
class PartitionFCN(nn.Module):
    """Wraps `FCN` so that the variables of its 'Dense_0' layer receive no gradient.

    The wrapped module's 'params' collection is passed through
    `selective_stop_grad` before `FCN` sees it.
    """

    def some_filter_fn(self, k):
        """Return True when `k` refers to a variable of the 'Dense_0' layer.

        Args:
            k: Either a bare module name, or a key-path tuple as produced by
                `traverse_util.flatten_dict` (e.g. ``('Dense_0', 'kernel')``,
                possibly prefixed with the collection name).
        """
        # A plain helper: it takes `self` because it is called as
        # `self.some_filter_fn(k)`, and it is not `@nn.compact` because it
        # creates no submodules or variables.
        if isinstance(k, str):
            return k == 'Dense_0'
        # flatten_dict keys are tuples of path components, so test membership
        # rather than equality with a single string.
        return 'Dense_0' in k

    def selective_stop_grad(self, variables):
        """Apply `lax.stop_gradient` to every variable selected by `some_filter_fn`.

        Args:
            variables: Nested variable dict handed over by `nn.map_variables`.

        Returns:
            A nested dict with the same structure as `variables`.
        """
        flat_vars = traverse_util.flatten_dict(variables)
        new_vars = {
            k: lax.stop_gradient(v) if self.some_filter_fn(k) else v
            for k, v in flat_vars.items()
        }
        return traverse_util.unflatten_dict(new_vars)

    def setup(self):
        # Temporary class: FCN with its 'params' collection routed through
        # selective_stop_grad. `init=self.is_initializing()` follows the
        # pattern shown in the Flax `map_variables` documentation so the
        # wrapped module can create its variables during `init`.
        _FCN = nn.map_variables(
            FCN, "params", self.selective_stop_grad,
            init=self.is_initializing())
        self.fcn = _FCN()

    def __call__(self, x: jnp.ndarray):
        """Run the wrapped FCN on `x`."""
        return self.fcn(x)

In [None]:
from flax.core import freeze, unfreeze
# Print every top-level entry of `params`; `unfreeze` is applied first so the loop
# iterates over the dict it returns rather than over `params` directly.
for name, param in unfreeze(params).items():
    print(f"Parameter name: {name}, value: {param}")

In [None]:
import numpy as np
# Open one saved sample archive for inspection.
# NOTE(review): np.load on an .npz returns an archive object that is never closed here —
# presumably left open on purpose for interactive exploration of `data`; confirm.
data = np.load('results/dataset/partition_warmstart_0/samples/0/sample_0.npz')
print(data)
# Last expression of the cell: shows the `files` attribute of the loaded archive.
data.files

In [None]:
import jax
import jax.numpy as jnp
import blackjax
import matplotlib.pyplot as plt
import jax.lax as lax

# Define 2D Gaussian log probability function
def log_prob_fn(x):
    """Unnormalised log density of a standard normal (zero mean, unit variance).

    Args:
        x: Position array; every coordinate contributes to the density.

    Returns:
        The scalar ``-0.5 * sum(x ** 2)``.
    """
    # Sum over all coordinates. Indexing x[0] here would leave the density
    # flat in the remaining coordinates, i.e. not a proper 2D Gaussian.
    return -0.5 * jnp.sum(x ** 2)

# Initialize random key
rng_key = jax.random.PRNGKey(0)

# Derive separate keys for the initial position and for the sampler
# initialisation. Drawing twice from the same key returns the same number, so
# both coordinates of the start point would otherwise be identical.
rng_key, position_key, init_key = jax.random.split(rng_key, 3)

# Initial position: two independent standard-normal coordinates
init_position = jax.random.normal(position_key, (2,))

# Initialize MCLMC sampler
kernel = blackjax.mcmc.mclmc.build_kernel(
    logdensity_fn=log_prob_fn,
    sqrt_diag_cov=jnp.ones(2),
    integrator=blackjax.mcmc.integrators.isokinetic_mclachlan
)

state = blackjax.mcmc.mclmc.init(
    position=init_position,
    logdensity_fn=log_prob_fn,
    rng_key=init_key
)

# Run sampler
n_samples = 1000
samples = []

for i in range(n_samples):
    # Fresh subkey per transition; rng_key itself is only ever split.
    rng_key, step_key = jax.random.split(rng_key)
    state, info = kernel(step_key, state, L=0.5, step_size=0.1)
    samples.append(state.position)

# Stack the visited positions into one array, one row per step.
samples = jnp.array(samples)

# Plot results: scatter of the two sample coordinates on equal axes
fig, ax = plt.subplots(figsize=(10, 10))
ax.scatter(samples[:, 0], samples[:, 1], alpha=0.1)
ax.set_title('MCLMC Samples from 2D Standard Normal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.axis('equal')
plt.show()

# Plot marginal histograms: one panel per coordinate, side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (0, 'Marginal Distribution (x)'),
    (1, 'Marginal Distribution (y)'),
)
for axis, (column, title) in zip(axes, panels):
    axis.hist(samples[:, column], bins=50, density=True)
    axis.set_title(title)
plt.show()

In [8]:
from functools import partial
def fn(x, y, z):
    """Add the three arguments left to right."""
    running = x + y
    return running + z

# Pin both keyword arguments; only `x` is left to supply.
fn = partial(fn, y=1, z=2)
fn(3)

6