# Checkpoint utils

In [None]:
# | default_exp utils.checkpoint

In [None]:
# | export

from pathlib import Path
from typing import Optional, Tuple

import flax.linen as nn
import hydra
import jax
import orbax.checkpoint as obc
import wandb
from flax.training import train_state
from omegaconf import OmegaConf
from wandb.apis import public

from physmodjax.scripts.train_rnn import create_train_state

In [None]:
# | export


def restore_experiment_state(
    run_path: Path,  # Path to the run directory (e.g. "outputs/2024-01-23/22-15-11")
    best: bool = True,  # If True, restore the best checkpoint instead of the latest
    step_to_restore: Optional[int] = None,  # If not None, restore the checkpoint at this step
    x0_shape: Tuple[int, ...] = (1, 101, 1),  # Shape of the initial condition (currently unused)
    x_shape: Tuple[int, ...] = (1, 1, 101, 1),  # Shape of the input data
    kwargs: Optional[dict] = None,  # Additional arguments to pass to the model
) -> Tuple[train_state.TrainState, nn.Module, obc.CheckpointManager]:
    """
    Restores the train state from a run.

    Args:
        run_path (Path): Path to the run directory (e.g. "outputs/2024-01-23/22-15-11").
            It must contain a ``checkpoints`` directory and ``.hydra/config.yaml``.
        best (bool): If True, restore the checkpoint with the lowest ``val/mse``;
            otherwise restore the latest checkpoint. Ignored when
            ``step_to_restore`` is given.
        step_to_restore (int, optional): Explicit checkpoint step to restore.
        x0_shape (Tuple[int, ...]): Currently unused; kept for backward compatibility.
        x_shape (Tuple[int, ...]): Input shape used to initialise the empty train
            state. Replaced by ``[1, *cfg.data_info]`` when the run config has a
            ``data_info`` entry.
        kwargs (dict, optional): Extra keyword arguments for the model constructor.

    Returns:
    -------
        train_state.TrainState: The train state of the experiment
        nn.Module: The model used in the experiment
        CheckpointManager: The checkpoint manager. It is returned open so it can
            still be queried (e.g. ``metrics`` / ``best_step``); call ``close()``
            on it when done.

    Raises:
        FileNotFoundError: If the run has no ``checkpoints`` directory.
    """
    # Avoid a shared mutable default argument.
    kwargs = {} if kwargs is None else kwargs

    # Make sure the path is a Path object
    run_path = Path(run_path)

    # These are hardcoded, do not change
    ckpt_path = run_path / "checkpoints"
    config_path = run_path / ".hydra" / "config.yaml"

    # Fail early: with `create=True` below, a wrong path would otherwise silently
    # produce an empty checkpoint directory instead of an error.
    if not ckpt_path.exists():
        raise FileNotFoundError(f"No checkpoint directory found at {ckpt_path}")

    options = obc.CheckpointManagerOptions(
        max_to_keep=1,
        create=True,
        best_fn=lambda x: float(x["val/mse"]),
        best_mode="min",
    )
    # Not used as a context manager: the manager is handed back to the caller,
    # and leaving a `with` block would close it before it is returned.
    checkpoint_manager = obc.CheckpointManager(
        ckpt_path,
        options=options,
        item_handlers={"state": obc.PyTreeCheckpointHandler()},
    )
    try:
        # Load the config
        cfg = OmegaConf.load(config_path)

        # The model config may be instantiated as a partial (`_partial_: true`),
        # so this is treated as a constructor that is called below.
        model_cls = hydra.utils.instantiate(cfg.model)
        grad_clip = hydra.utils.instantiate(cfg.gradient_clip)

        # initialise train state
        # try to get this information from the config
        if hasattr(cfg, "data_info"):
            print(f"Using data_info from config: {cfg.data_info}")
            # Plain list (prepending the batch dimension) rather than a ListConfig.
            x_shape = [1, *cfg.data_info]
        rng = jax.random.PRNGKey(cfg.seed)

        # Build the model once; the same instance initialises the state and is returned.
        model = model_cls(training=False, **kwargs)
        empty_state = create_train_state(
            model,
            rng=rng,
            x_shape=x_shape,
            # NOTE(review): presumably a placeholder since the LR schedule length
            # does not matter when only restoring — confirm.
            num_steps=666,
            learning_rate=cfg.optimiser.learning_rate,
            grad_clip=grad_clip,
            components_to_freeze=cfg.frozen,
            norm=cfg.model.norm,
            schedule_type=cfg.schedule_type,
        )

        if step_to_restore is not None:
            step = step_to_restore
        elif best:
            step = checkpoint_manager.best_step()
        else:
            step = checkpoint_manager.latest_step()
        print(f"Restoring checkpoint from step {step}...")
        state = checkpoint_manager.restore(
            step=step,
            args=obc.args.Composite(
                state=obc.args.PyTreeRestore(empty_state),
            ),
        )["state"]
    except BaseException:
        # Only release the manager ourselves when we cannot hand it to the caller.
        checkpoint_manager.close()
        raise

    return state, model, checkpoint_manager

In [None]:
# | export

def download_ckpt_single_run(
    run_name: str,
    project: str,
    tmp_dir: Path = Path("/tmp/physmodjax"),
    overwrite: bool = False,
) -> Tuple[Path, OmegaConf]:
    """Download the model checkpoint artifact of a W&B run, plus its config.

    Args:
        run_name: Display name of the run to look up (must match exactly one run).
        project: W&B project path passed to ``wandb.Api().runs``.
        tmp_dir: Directory the artifact is downloaded into (``tmp_dir / artifact.name``).
        overwrite: If False and the target directory already exists, skip the
            download and return the existing path.

    Returns:
        The local checkpoint directory and the run config as an OmegaConf object.
        The config is also saved to ``<checkpoint_path>/.hydra/config.yaml``
        after a fresh download.

    Raises:
        ValueError: If zero or several runs match ``run_name``, or the run has no
            artifacts / no artifact of type ``"model"``.
    """
    # Accept plain strings as well as Path objects.
    tmp_dir = Path(tmp_dir)

    filter_dict = {
        "display_name": run_name,
    }

    if wandb.run is None:
        wandb.init()
    # The API client is needed regardless of whether a run was already active
    # (it used to be created only inside the branch above -> NameError).
    api: public.Api = wandb.Api()

    runs: public.Runs = api.runs(project, filter_dict)

    if len(runs) == 0:
        raise ValueError(f"No runs found with name {run_name}")
    if len(runs) > 1:
        raise ValueError(f"More than one run found with name {run_name}")

    run: public.Run = runs[0]
    conf = OmegaConf.create(run.config)

    artifacts: public.RunArtifacts = run.logged_artifacts()

    artifact: wandb.Artifact

    # check if no artifacts
    if len(artifacts) == 0:
        raise ValueError(f"No artifacts found for run {run_name}")

    checkpoint_path = None
    for artifact in artifacts:
        if artifact.type != "model":
            continue
        checkpoint_path = tmp_dir / artifact.name
        if checkpoint_path.exists() and not overwrite:
            print(f"Checkpoint already exists at {checkpoint_path}, skipping")
            return checkpoint_path, conf
        artifact.download(checkpoint_path)

    # Artifacts exist but none is a model checkpoint.
    if checkpoint_path is None:
        raise ValueError(f"No model artifacts found for run {run_name}")

    # save config next to checkpoint
    conf_path = checkpoint_path / ".hydra" / "config.yaml"
    conf_path.parent.mkdir(parents=True, exist_ok=True)
    OmegaConf.save(conf, conf_path)

    print(f"Downloaded checkpoint to {checkpoint_path}")
    return checkpoint_path, conf

In [None]:
from hydra import initialize, compose
from hydra.core.hydra_config import HydraConfig
from physmodjax.scripts.train_rnn import train_rnn
from pathlib import Path

In [None]:
# | eval: false

data_array = "../data/ftm_string_nonlin_1000_Noise_4000Hz_1.0s.npy"
batch_size = 1
split = [0.01, 0.01, 0.01]
extract_channels = [0]
output_dir = ""

with initialize(version_base=None, config_path="../../conf"):
    cfg = compose(
        return_hydra_config=True,
        config_name="train_rnn",
        overrides=[
            "+experiment=1d_koopman",
            f"++datamodule.data_array={data_array}",
            f"++datamodule.batch_size={batch_size}",
            f"++datamodule.split={split}",
            f"++datamodule.extract_channels={extract_channels}",
            "++model.d_vars=1",
            "++epochs=1",
            "++epochs_val=1",
            "++wandb.project=physmodjax",
            "++wandb.entity=iir-modal",
        ],
    )
    # NOTE: the "eval" resolver runs arbitrary Python found in the config;
    # only use it with trusted config files.
    OmegaConf.register_new_resolver("eval", eval, replace=True)
    OmegaConf.resolve(cfg)

    # Print the experiment config without the hydra-specific entries.
    cfg_no_hydra = {k: v for (k, v) in cfg.items() if "hydra" not in k}
    print(OmegaConf.to_yaml(cfg_no_hydra))

    # Register the composed config with Hydra so HydraConfig.get() works
    # outside of a @hydra.main entry point (cfg is not modified afterwards,
    # so a single registration is enough).
    HydraConfig.instance().set_config(cfg)
    print(OmegaConf.to_yaml((HydraConfig.get().runtime)))

    output_dir = Path(cfg.hydra.run.dir).absolute()
    print(f"Output dir: {output_dir}")

    train_rnn(cfg)



model:
  _target_: physmodjax.models.autoencoders.BatchedKoopmanAutoencoder1D
  _partial_: true
  d_vars: 1
  d_model: 101
  norm: layer
  encoder_model:
    _target_: physmodjax.models.mlp.MLP
    _partial_: true
    hidden_channels:
    - 128
    - 128
    - 256
    kernel_init:
      _target_: flax.linen.initializers.orthogonal
  decoder_model:
    _target_: physmodjax.models.mlp.MLP
    _partial_: true
    hidden_channels:
    - 128
    - 128
    - 101
    kernel_init:
      _target_: flax.linen.initializers.orthogonal
  dynamics_model:
    _target_: physmodjax.models.recurrent.LRUDynamics
    _partial_: true
    d_hidden: 128
    r_min: 0.99
    r_max: 0.999
    max_phase: 6.28
    clip_eigs: true
datamodule:
  _target_: physmodjax.utils.data.DirectoryDataModule
  split:
  - 0.01
  - 0.01
  - 0.01
  batch_size: 1
  extract_channels:
  - 0
  total_num_train: 4000
  total_num_val: 4000
  total_num_test: 4000
  num_steps_train:
  - 1
  - 3999
  num_steps_val:
  - 1
  - 3999
  mode: s

InstantiationException: Error in call to target 'physmodjax.utils.data.DirectoryDataModule':
AssertionError('The data array does not exist')
full_key: datamodule

In [None]:
# | eval: false

# Build the datamodule from the composed config and expose its three loaders.
datamodule = hydra.utils.instantiate(cfg.datamodule)
(
    train_dataloader,
    val_dataloader,
    test_dataloader,
) = (
    datamodule.train_dataloader,
    datamodule.val_dataloader,
    datamodule.test_dataloader,
)


In [None]:
# | eval: false

# `project` is a required argument of download_ckpt_single_run.
# NOTE(review): assumed to be "<entity>/<project>" from the wandb.entity /
# wandb.project overrides used for training above — confirm for this run.
checkpoint_path, cfg = download_ckpt_single_run(
    "eager-valley-1758",
    project="iir-modal/physmodjax",
)
kwargs = {"n_steps": datamodule.num_steps_target_val}
state, model, ckpt_manager = restore_experiment_state(
    checkpoint_path,
    kwargs=kwargs,
)

Checkpoint already exists at /tmp/physmodjax/checkpoints_fiug7qv5:v0, skipping
Using data_info from config: [1, 101, 1]
Restoring checkpoint from step 1...




In [None]:
# | eval: false

from functools import partial
from physmodjax.utils.metrics import (
    mse,
    mae,
    mse_relative,
    mae_relative,
    accumulate_metrics,
)
import numpy as np

In [None]:
# | eval: false

@partial(jax.jit, static_argnames=("model", "norm"))
def val_step(
    state: train_state.TrainState,
    x,
    y,
    model,
    norm,
):
    """Apply ``model`` to one validation batch and score the prediction.

    ``model`` and ``norm`` are static arguments of the jitted function, so the
    Python branch on ``norm`` is resolved at trace time.

    Returns:
        A ``(metrics, pred)`` pair where ``metrics`` maps ``"val/mse"``,
        ``"val/mae"``, ``"val/mse_rel"`` and ``"val/mae_rel"`` to the values
        computed against ``y``.
    """
    # Batch-norm models also need their running statistics to be applied.
    variables = {"params": state.params}
    if norm in ["batch"]:
        variables["batch_stats"] = state.batch_stats
    pred = model.apply(variables, x)

    return {
        "val/mse": mse(y, pred),
        "val/mae": mae(y, pred),
        "val/mse_rel": mse_relative(y, pred),
        "val/mae_rel": mae_relative(y, pred),
    }, pred


# Re-run validation with the restored state, one jitted step per batch.
val_batch_metrics = []
for x, y in val_dataloader:
    metrics, pred = val_step(state, x=x, y=y, model=model, norm=cfg.model.norm)
    val_batch_metrics.append(metrics)
# Collapse the per-batch metric dicts into a single dict.
val_batch_metrics = accumulate_metrics(val_batch_metrics)

# Metrics stored alongside the best checkpoint; keep only the validation ones.
metrics = ckpt_manager.metrics(ckpt_manager.best_step())
val_metrics = {name: stored for name, stored in metrics.items() if "val" in name}

# The recomputed metrics must reproduce what was logged at training time.
for key, value in val_metrics.items():
    assert np.isclose(value, val_batch_metrics[key], atol=1e-6), (
        f"Metric {key} does not match: {value} != {val_batch_metrics[key]}"
    )