In [2]:

from fit_a_nef.nef.utils import custom_uniform


from config import load_cfgs, store_cfg
from dataset import path_from_name_idxs
from dataset.data_creation import get_dataset
from dataset.image_dataset import load_attributes, load_images

from tasks.utils import find_seed_idx, get_num_nefs_list, get_signal_idx


from ml_collections import ConfigDict, config_dict
from absl import flags, logging

# Documentation
import json

from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import h5py
import jax
import jax.numpy as jnp
import optax
import os

from flax import linen as nn

# JAX

from fit_a_nef.initializers import InitModel, MetaLearnedInit, RandomInit, SharedInit
from fit_a_nef.nef import param_key_dict

# Misc
from fit_a_nef.utils import (
    TrainState,
    flatten_params,
    get_nef,
    get_optimizer,
    get_scheduler,
    unflatten_params,
)

from fit_a_nef.metrics import mae, mse, psnr, simse, ssim

In [3]:
def find_env_path(env_var: str = "NEF_PATH", default_path: str = "saved_models"):
    if env_var not in os.environ:
        env_path = Path(default_path).absolute()
        env_path.mkdir(parents=True, exist_ok=True)
        logging.warning(f"{env_var} environment variable not set, using default value {env_path}")
        return str(env_path)
    else:
        return str(Path(os.environ[env_var]).absolute())

In [4]:
def get_config():
    cfg = ConfigDict()
    cfg.task = "image"
    cfg.nef_dir = 'saved_models/example'
    # Create dir for saving meta-learned initialisations
    cfg.meta_nef_dir = find_env_path("NEF_PATH", "saved_meta_models")
    cfg.seeds = '(0,1,2,3,4)'

    # Train
    cfg.train = ConfigDict()
    cfg.train.start_idx = 0
    cfg.train.end_idx = 10
    cfg.train.num_steps = 500
    cfg.train.num_parallel_nefs = 2000
    cfg.train.masked_portion = 1.0
    cfg.train.multi_gpu = False
    # put train_to_target_psnr is an optional argument
    cfg.train.train_to_target_psnr = config_dict.placeholder(float)
    cfg.train.check_every = 10
    cfg.train.fixed_init = True
    cfg.train.verbose = True

    # Whether to use meta-learned initialization
    cfg.train.from_meta_init = False
    cfg.train.meta_init_epoch = 10

    # Logging
    cfg.log = ConfigDict()
    cfg.log.images = 500
    cfg.log.metrics = 10
    cfg.log.loss = 10
    cfg.log.use_wandb = False

    # Wandb
    cfg.wandb = ConfigDict()
    cfg.wandb.entity = "neuralfield-wandb"
    cfg.wandb.project = "Classification_tuning_MNIST"
    cfg.wandb.name = "logging_MNIST_dset"

    # Dataset
    cfg.dataset = ConfigDict()
    cfg.dataset.path = "."
    cfg.dataset.name = "MNIST"
    cfg.dataset.out_channels = 1

    return cfg

In [5]:
def load_cfgs(
    nef='SIREN',
    _SCHEDULER_FILE: Optional[flags.FlagHolder] = None,
    _OPTIMIZER_FILE: Optional[flags.FlagHolder] = None,
):
    cfg = get_config()
    if nef == 'SIREN':
        nef_cfg = ConfigDict(
        {
            "name": "SIREN",
            "params": {
                "hidden_dim": 8,
                "num_layers": 3,
                "omega_0": 8.0,
            },
        }
    )
    elif nef == 'HYPLLMLP':
        nef_cfg = ConfigDict(
        {
            "name": "HYPLLMLP",
            "params": {
                "hidden_dim": 15,
                "num_layers": 3
            },
        }
    )

    # TODO find a way to not have to do this
    nef_cfg.unlock()
    nef_cfg.params.output_dim = cfg.dataset.get("out_channels", 1)
    nef_cfg.lock()

    scheduler_cfg = ConfigDict(
        {
            "name": "constant_schedule",
            "params": {
                "value": 5e-4,
            },
        }
    )
    cfg.unlock()
    cfg["scheduler"] = scheduler_cfg
    cfg.lock()

    optimizer_cfg = ConfigDict(
        {
            "name": "adamw",
            "params": {
                "b1": 0.9,
                "b2": 0.999,
                "eps": 1e-8,
                "eps_root": 0.0,
                "weight_decay": 1e-3,
            },
        }
    )
    cfg.unlock()
    cfg["optimizer"] = optimizer_cfg
    cfg.lock()

    return cfg, nef_cfg

# Setup configuration

In [6]:
cfg, nef_cfg = load_cfgs()



In [7]:
seeds = eval(cfg.seeds)

In [8]:
storage_folder = Path(cfg.nef_dir) / Path(cfg.dataset.name) / Path(f"{nef_cfg.name}")
meta_storage_folder = Path(cfg.meta_nef_dir) / Path(cfg.dataset.name) / Path(f"{nef_cfg.name}")
storage_folder.mkdir(parents=True, exist_ok=True)

In [9]:
store_cfg(nef_cfg, storage_folder, "nef.json", overwrite=True)
store_cfg(cfg, storage_folder, "cfg.json", overwrite=True)

In [10]:
available_datasets = [
        "CIFAR10",
        "MNIST",
        "CelebA",
        "ImageNet",
        "TinyImageNet",
        "STL10",
    ]

In [11]:
import os
from pathlib import Path
from absl import logging

from pathlib import Path
from typing import Any, Callable, Dict, Literal, Sequence, Tuple, Union

import torchvision
from absl import logging
from ml_collections import ConfigDict
from torch.utils import data

# image datasets
from dataset.image_dataset.CelebA import CelebA
from dataset.image_dataset.image_data import load_images
from dataset.image_dataset.ImageNet import ImageNetKaggle
from dataset.image_dataset.MicroImageNet import MicroImageNet
from dataset.image_dataset.TinyImageNet import TinyImageNet
from dataset.image_dataset.utils import (
    MEAN_STD_IMAGE_DATASETS,
    fast_normalize,
    image_to_numpy,
)

# shape datasets
from dataset.shape_dataset.shape_data import load_shapes
from dataset.shape_dataset.shapenet import ShapeNet
from dataset.shape_dataset.shapenet_val import ShapeNetVal

In [12]:
if "DATA_PATH" not in os.environ:
        data_path = Path("data").absolute()
        data_path.mkdir(parents=True, exist_ok=True)
        logging.warning(f"DATA_PATH environment variable not set, using default value {data_path}")
        DATA = data_path
else:
        DATA = Path(os.environ["DATA_PATH"])



In [None]:
if cfg.dataset.name == "MNIST":
        mean, std = MEAN_STD_IMAGE_DATASETS["MNIST"]

        normalize_fn = lambda x: fast_normalize(x, mean, std)

        train_dataset = torchvision.datasets.MNIST(
            root=DATA / Path(cfg.dataset.path),
            train=True,
            transform=torchvision.transforms.Compose(
                [image_to_numpy, lambda x: x.reshape(28, 28, 1), normalize_fn]
            ),
            download=True,
        )
        test_dataset = torchvision.datasets.MNIST(
            root=DATA / Path(cfg.dataset.path),
            train=False,
            transform=torchvision.transforms.Compose(
                [image_to_numpy, lambda x: x.reshape(28, 28, 1), normalize_fn]
            ),
            download=True,
        )
        source_dataset = data.ConcatDataset([train_dataset, test_dataset])

In [1]:
source_dataset = get_dataset(cfg.dataset)

NameError: name 'get_dataset' is not defined

In [1]:
source_dataset

NameError: name 'source_dataset' is not defined

In [38]:
signals_in_dset = len(source_dataset)

In [39]:
total_nefs = cfg.train.end_idx - cfg.train.start_idx

In [40]:
num_nefs_list = get_num_nefs_list(
        nef_start_idx=cfg.train.start_idx,
        nef_end_idx=cfg.train.end_idx,
        num_parallel_nefs=cfg.train.num_parallel_nefs,
        signals_in_dset=signals_in_dset,
    )

In [41]:
num_nefs_list

[10]

In [42]:
init_rngs_per_seed = [jax.random.PRNGKey(seed) for seed in seeds]

In [43]:
initializers = [SharedInit(init_rngs_per_seed[i]) for i in range(len(seeds))]

# Training

## Prepare training

In [44]:
nef_start_idx = 0
nef_end_idx = 10

In [45]:
seed_idx = find_seed_idx(nef_start_idx, signals_in_dset)
seed = seeds[seed_idx]
train_rng = jax.random.PRNGKey(seed)

start_idx = get_signal_idx(nef_start_idx, signals_in_dset)
end_idx = get_signal_idx(nef_end_idx - 1, signals_in_dset) + 1

In [46]:
def load_images(
    source_dataset: data.Dataset,
    start_idx: int,
    end_idx: int,
    rng: Optional[jax.random.PRNGKey] = None,
    force_shuffle: bool = False,
) -> Tuple[jnp.array, jnp.array, Tuple[int, int, int], Optional[jax.random.PRNGKey]]:
    """Load images from the dataset and create the coordinates. The returned images will have shape
    (num_images, num_pixels, num_channels) while the coordinates will have shape (num_pixels,
    num_channels).

    When shuffling is required, a Jax key is used to get the randomness.
    A new key is returned which should be further split before the next use.

    Args:
        source_dataset (data.Dataset): The dataset.
        start_idx (int): The index of the first image to load.
        end_idx (int): The last image loaded will be the one with index end_idx - 1.
        rng (jax.random.PRNGKey): a random key used in the pseudo-random number generation. Default: None.
        force_shuffle (bool): whether to shuffle the dataset. Default: False.

    Returns:
        coords: The coordinates of the pixels with shape
                (num_pixels, num_channels).
        images: The images with shape
                (num_images, num_pixels, num_channels).
        image_shape: The original shape of the images.
        new_rng: A new PRNG key if one was provided and
                shuffling was performed, otherwise returns the
                provided key in `rng`.
    """
    # Create a subset of the dataset
    dset = data.Subset(source_dataset, range(start_idx, end_idx))

    # Create a loader with a single worker to get the data from disk to RAM
    loader = data.DataLoader(
        dset, batch_size=1, shuffle=False, num_workers=0, collate_fn=lambda x: x
    )
    # make them into a jnp array with the correct shape, assumes the data is in [0, 1]
    images = jnp.stack([x[0][0] for x in iter(loader)], axis=0)
    # remember the original shape of the images for plotting and coordinate making
    images_shape = list(images.shape[-3:])
    # reshape to (num_images, num_pixels, num_channels)
    images = images.reshape(-1, images_shape[0] * images_shape[1], images_shape[2])

    if force_shuffle:
        if rng is None:
            raise RuntimeError(
                "force_shuffle is set to True, but the rng key has not been provided."
            )
        else:
            rng, new_rng = jax.random.split(rng)
            index_perm = jax.random.permutation(new_rng, jnp.arange(0, images.shape[0]), axis=0)
            images = images[index_perm]
    else:
        index_perm = jnp.arange(0, images.shape[0])

    # coordinates
    x = jnp.linspace(-1, 1, images_shape[0])
    y = jnp.linspace(-1, 1, images_shape[1])
    x, y = jnp.meshgrid(x, y)
    coords = jnp.stack([x, y], axis=-1)
    # reshape to (1, num_pixels, num_channels)
    coords = coords.reshape(images_shape[0] * images_shape[1], 2)

    return coords, images, images_shape, rng, index_perm

In [47]:
coords, images, images_shape, _, _ = load_images(source_dataset, start_idx, end_idx)

In [48]:
dataset_mean, dataset_std = MEAN_STD_IMAGE_DATASETS[cfg.dataset.name]

In [49]:
signals=images
coords=coords
nef_cfg=nef_cfg
scheduler_cfg=cfg.scheduler
optimizer_cfg=cfg.optimizer
log_cfg=cfg.log
train_rng=train_rng
initializer=initializers[seed_idx]
num_steps=cfg.train.num_steps
masked_portion=cfg.train.masked_portion
images_mean=dataset_mean
images_std=dataset_std
images_shape=images_shape
verbose=cfg.train.verbose

In [50]:
max_images_logged = 5
max_recons_metrics = 10

In [51]:
num_signals = signals.shape[0]

In [52]:
# Initialize optimizer and learning rate schedule
lr_schedule = get_scheduler(scheduler_cfg)
optimizer = get_optimizer(optimizer_cfg, lr_schedule)

## Set up model

In [53]:
def SIREN_key(param_name, nef_cfg):
    # bias before kernel, ordered based on layer number
    if param_name.startswith("output_linear."):
        index = 2 * nef_cfg.get("num_layers") - 2
    else:
        index = 2 * int(param_name.split(".")[0].split("_")[-1])

    if param_name.endswith(".bias"):
        return index
    elif param_name.endswith(".kernel"):
        return index + 1
    else:
        raise ValueError(f"param_name (`{param_name}`) must end with either `.bias` or `.kernel`.")


In [54]:
class SIREN(nn.Module):
    output_dim: int
    hidden_dim: int
    num_layers: int
    omega_0: float

    def setup(self):
        self.kernel_net = [
            HSirenLayer(
                output_dim=self.hidden_dim,
                omega_0=self.omega_0,
                is_first_layer=True,
            )
        ] + [
            HSirenLayer(
                output_dim=self.hidden_dim,
                omega_0=self.omega_0,
            )
            for _ in range(self.num_layers - 2)
        ]

        self.output_linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=1, mode="fan_in", distribution="normal"),
            bias_init=nn.initializers.zeros,
        )

    def __call__(self, x):
        for layer in self.kernel_net:
            x = layer(x)

        out = self.output_linear(x)

        return out

In [55]:
class HSirenLayer(nn.Module):
    output_dim: int
    omega_0: float
    is_first_layer: bool = False

    def setup(self):
        c = 1 if self.is_first_layer else 6 / self.omega_0**2
        distrib = "uniform_squared" if self.is_first_layer else "uniform"
        self.linear = nn.Dense(
            features=self.output_dim,
            use_bias=True,
            kernel_init=custom_uniform(numerator=c, mode="fan_in", distribution=distrib),
            bias_init=nn.initializers.zeros,
        )

    def __call__(self, x):
        print(x.shape)
        after_linear = self.omega_0 * self.linear(x)
        print(self.linear.variables['params'])
        return jnp.sin(after_linear)
    
    def construct_dl_parameters(self, in_features: int, out_features: int, bias: bool = True
        ):
        weight = jax.numpy.empty(in_features, out_features)
        if bias:
            b = jax.numpy.empty(out_features)
        else:
            b = None
        return weight, b
        

In [56]:
model = SIREN(**nef_cfg['params'])

In [57]:
model

SIREN(
    # attributes
    output_dim = 1
    hidden_dim = 8
    num_layers = 3
    omega_0 = 8.0
)

In [59]:
example_input = coords[: int(masked_portion * coords.shape[0])]

# Initialize model parameters
params = initializer(model, example_input, num_signals)

(784, 2)
{'kernel': Traced<ShapedArray(float32[2,8])>with<BatchTrace(level=1/0)> with
  val = Array([[[ 0.2923069 ,  0.28966415, -0.2603742 ,  0.48741376,
          0.07832241, -0.4425907 , -0.14212322, -0.1031878 ],
        [-0.08754194, -0.4592371 ,  0.47663093, -0.3745258 ,
          0.34914184, -0.01427758, -0.28593063,  0.08828318]],

       [[ 0.2923069 ,  0.28966415, -0.2603742 ,  0.48741376,
          0.07832241, -0.4425907 , -0.14212322, -0.1031878 ],
        [-0.08754194, -0.4592371 ,  0.47663093, -0.3745258 ,
          0.34914184, -0.01427758, -0.28593063,  0.08828318]],

       [[ 0.2923069 ,  0.28966415, -0.2603742 ,  0.48741376,
          0.07832241, -0.4425907 , -0.14212322, -0.1031878 ],
        [-0.08754194, -0.4592371 ,  0.47663093, -0.3745258 ,
          0.34914184, -0.01427758, -0.28593063,  0.08828318]],

       [[ 0.2923069 ,  0.28966415, -0.2603742 ,  0.48741376,
          0.07832241, -0.4425907 , -0.14212322, -0.1031878 ],
        [-0.08754194, -0.4592371 ,  0.4

In [134]:
print(model.tabulate(jax.random.key(0), example_input))


[3m                                 SIREN Summary                                  [0m
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃[1m [0m[1mpath          [0m[1m [0m┃[1m [0m[1mmodule    [0m[1m [0m┃[1m [0m[1minputs       [0m[1m [0m┃[1m [0m[1moutputs       [0m[1m [0m┃[1m [0m[1mparams       [0m[1m [0m┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│                │ SIREN      │ [2mfloat32[0m[784,… │ [2mfloat32[0m[784,1] │               │
├────────────────┼────────────┼───────────────┼────────────────┼───────────────┤
│ kernel_net_0   │ SirenLayer │ [2mfloat32[0m[784,… │ [2mfloat32[0m[784,8] │               │
├────────────────┼────────────┼───────────────┼────────────────┼───────────────┤
│ kernel_net_0/… │ Dense      │ [2mfloat32[0m[784,… │ [2mfloat32[0m[784,8] │ bias:         │
│                │            │               │                │ [2mfloat32[0m[8]    │
│    

In [60]:
class Curvature():
    value: jnp.ndarray
    c: jnp.ndarray
    
    def __init__(self, value=1.0):
        self.value = jnp.array(value)
        self.c = self.softplus(self.value, )
        
    def softplus(self, value):
        return jnp.log1p(jnp.exp(value))

In [61]:
curvature = Curvature(0.1)

In [62]:
def expmap0(x: jnp.ndarray, curv: jnp.ndarray, dim: int = 1):
        v_norm = jnp.linalg.norm(x, axis=dim, keepdims=True)
        v_norm_clamped = jnp.maximum(v_norm, 1e-15)
        v_norm_c_sqrt = v_norm_clamped * jnp.sqrt(curv)
        return jnp.tanh(v_norm_c_sqrt) * x / v_norm_c_sqrt

def project(x: jnp.ndarray, curv: jnp.ndarray, dim: int = 1, eps: float = -1.0):
    if eps < 0:
            if x.dtype == jnp.float32:
                eps = 4e-3
            else:
                eps = 1e-5
    maxnorm = (1 - eps) / ((curv + 1e-15) ** 0.5)
    maxnorm = jnp.where(curv > 0, maxnorm, jnp.full((), 1e15, dtype=curv.dtype))
    norm = jnp.linalg.norm(x, axis=dim, keepdims=True)
    norm = jnp.clip(norm, a_min=1e-15)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return jnp.where(cond, projected, x)

In [63]:
def reflatten_images(images, shapes):
    return images.reshape(images.shape[0], shapes[0], shapes[1], shapes[2])

def flatten_images(images, shapes):
    return images.reshape(images.shape[0], shapes[0]*shapes[1], shapes[2])

In [64]:
img = reflatten_images(signals, images_shape)

In [65]:
projected_img = jax.numpy.array([project(expmap0(image, curvature.c), curvature.c) for image in img])

In [66]:
input = flatten_images(projected_img, images_shape)

In [67]:
input.shape

(10, 784, 1)

In [68]:
# Create the train state
state = TrainState.create(
            apply_fn=model.apply, params=params, tx=optimizer, rng=train_rng
        )

In [69]:
param_key = partial(
            param_key_dict[nef_cfg.get("name", None)], nef_cfg=nef_cfg.get("params", None)
        )

In [70]:
def create_loss():
    def loss_fn(params, coords, images):
            y = model.apply({"params": params}, coords)
            loss = (images - y) ** 2
            recon_loss = loss.mean()
            return recon_loss

    return jax.vmap(loss_fn, in_axes=(0, None, 0), out_axes=0)

loss_fn = create_loss()


In [71]:
def process_batch(
        state: TrainState, coords: jnp.ndarray, images: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:

        rng, step_rng = jax.random.split(state.rng)
        mask = jax.random.permutation(step_rng, coords.shape[0])
        mask = mask[: int(masked_portion * coords.shape[0])]
        # apply_mask
        coords = coords[mask]
        images = images[:, mask]

        return coords, images, rng

In [72]:
def create_train_step():
    def train_step(stat, coord, sign):
        coords, signals, rng = process_batch(stat, coord, sign)
        # pass the coordinates and the signals all at once
        my_loss = lambda params: loss_fn(params, coords, signals).sum()
        # compute gradients wrt state.params
        loss, grads = jax.value_and_grad(my_loss, has_aux=False)(stat.params)
        stat = stat.apply_gradients(grads=grads, rng=rng)
        return stat, loss

    return jax.jit(train_step)

train_step = create_train_step()

**Compile TrainStep**

In [73]:
#Train model
state, losses = train_step(state, coords, signals)

(784, 2)
{'bias': Traced<ShapedArray(float32[8])>with<BatchTrace(level=4/0)> with
  val = Traced<ShapedArray(float32[10,8])>with<JVPTrace(level=3/0)> with
    primal = Traced<ShapedArray(float32[10,8])>with<DynamicJaxprTrace(level=1/0)>
    tangent = Traced<ShapedArray(float32[10,8])>with<JaxprTrace(level=2/0)> with
      pval = (ShapedArray(float32[10,8]), None)
      recipe = LambdaBinding()
  batch_dim = 0, 'kernel': Traced<ShapedArray(float32[2,8])>with<BatchTrace(level=4/0)> with
  val = Traced<ShapedArray(float32[10,2,8])>with<JVPTrace(level=3/0)> with
    primal = Traced<ShapedArray(float32[10,2,8])>with<DynamicJaxprTrace(level=1/0)>
    tangent = Traced<ShapedArray(float32[10,2,8])>with<JaxprTrace(level=2/0)> with
      pval = (ShapedArray(float32[10,2,8]), None)
      recipe = LambdaBinding()
  batch_dim = 0}
(784, 8)
{'bias': Traced<ShapedArray(float32[8])>with<BatchTrace(level=4/0)> with
  val = Traced<ShapedArray(float32[10,8])>with<JVPTrace(level=3/0)> with
    primal = Tr

## Training Loop

In [94]:
def get_lr():
    schedule = lr_schedule
    if schedule is None:
        logging.warning("No learning rate schedule found.")
        return
    opt_state = state.opt_state
    opt_state = [s for s in opt_state if isinstance(s, optax.ScaleByScheduleState)]

    if len(opt_state) == 0:
        logging.warning("No state of a learning rate schedule found.")
        return
    if len(opt_state) > 1:
        logging.warning(
            "Found multiple states of a learning rate schedule. Using the last one."
            )
    step = opt_state[-1].count
    lr = schedule(step)
    return lr

In [95]:
def apply_model(model, params, coords):
    return jax.vmap(lambda params: model.apply({"params": params}, coords))(params)

In [96]:
def get_psnr(model, params, coords, signals):
    recon = apply_model(model, params, coords)
    metric = psnr(recon, signals, images_mean, images_std)
    return jnp.mean(metric), jnp.mean(jnp.square(metric))

In [97]:
for step_num in range(1, num_steps + 1):
    # Train model for one epoch, and log avg loss
    state, losses = train_step(state, coords, signals)

    if log_cfg is not None:
        if step_num % log_cfg.loss == 0 or (step_num == num_steps):
            learning_rate = get_lr()
            logging.info(f"Step: {step_num}. Loss: {losses.mean()}. LR {learning_rate}")
        if step_num % log_cfg.images == 0 or (step_num == num_steps):
            recons = apply_model(model, state.params, coords)
            logging.info("Wandb not available. Skipping logging images.")
        if step_num % log_cfg.metrics == 0 or (step_num == num_steps):
            psnr_mean, psnr_squared_mean = get_psnr(model, state.params, coords, signals)
            logging.info(f"Step: {step_num}. PSNR: {psnr_mean}")


In [68]:
attributes = load_attributes(source_dataset, start_idx, end_idx)

In [333]:
def save(path, **kwargs):
    param_config, comb_params = flatten_params(
            state.params, num_batch_dims=1, param_key=param_key
        )
    comb_params = jax.device_get(comb_params)
    param_config = json.dumps(param_config)

    with h5py.File(path, "w") as f:
        f.create_dataset("params", data=comb_params)
        dt = h5py.special_dtype(vlen=str)
        data = f.create_dataset("param_config", (1,), dtype=dt)
        data[0] = param_config
        for key, value in kwargs.items():
            if isinstance(value, jnp.ndarray):
                value = jax.device_get(value)
            f.create_dataset(key, data=value)

In [334]:
save(storage_folder / Path(path_from_name_idxs("nefs", nef_start_idx, nef_end_idx)), **attributes)