In [1]:
# Make the repository root importable so that `gfn_maxent_rl` resolves without
# installing the package (only needed when running from the notebook).
# '..' is the parent of the kernel's current working directory — normally the
# notebook's own folder. It is resolved to an absolute path once, so a later
# `os.chdir` cannot change what it points to. The membership check keeps
# re-runs of this cell from stacking duplicate entries on `sys.path`.
import os
import sys

REPO_ROOT = os.path.abspath('..')
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

In [2]:
import os
import tempfile
from pathlib import Path

import hydra
import jax
import jax.numpy as jnp
import wandb
from numpy.random import default_rng

from gfn_maxent_rl.algos.detailed_balance_vanilla import DBVParameters
from gfn_maxent_rl.utils import io
from gfn_maxent_rl.utils.evaluations import get_samples_from_env

In [3]:
# Client for the W&B public API, used below to look up the trained run.
# NOTE(review): presumably authenticates with the local `wandb login`
# credentials / WANDB_API_KEY — confirm on the machine running this notebook.
api = wandb.Api()

In [4]:
# W&B run whose saved model is evaluated in this notebook, given as a
# "entity/project/run_id" path. Change this constant to evaluate another run.
RUN_PATH = 'tristandeleu_mila_01/gfn_maxent_rl/aw4ru2l7'
run = api.run(RUN_PATH)

In [5]:
# Check which algorithm produced this run. The saved parameters have to be
# packed into the matching parameters container, and the loading cell below
# uses `DBVParameters`, so this should display 'dbv'.
run.config['exp_name_algorithm']

'dbv'

In [6]:
# The checkpoint is unpacked into `DBVParameters`, which only matches runs
# trained with the 'dbv' algorithm. Fail early rather than mis-pack the params.
assert run.config['exp_name_algorithm'] == 'dbv', (
    "Expected a 'dbv' run to unpack into DBVParameters, "
    f"got {run.config['exp_name_algorithm']!r}"
)

# Download location: the job's scratch space when running under SLURM,
# otherwise the system temp directory. `Path(None)` would raise a TypeError
# when SLURM_TMPDIR is unset.
root = Path(os.getenv('SLURM_TMPDIR') or tempfile.gettempdir()) / run.id
model_file = run.file('model.npz').download(root=root, exist_ok=True)
# `download` hands back an open (text-mode) handle to the local copy. Close it
# instead of leaking it; the checkpoint is re-opened in binary mode just below.
if model_file is not None:
    model_file.close()

with open(root / 'model.npz', 'rb') as f:
    params = DBVParameters(**io.load(f))

# Patch: unwrap every leaf of the loaded parameters with `.item()`.
# NOTE(review): presumably `io.load` yields 0-d object arrays wrapping the
# nested parameter dicts, and `.item()` recovers the underlying Python objects.
# Confirm against `gfn_maxent_rl.utils.io`.
params = jax.tree_util.tree_map(lambda x: x.item(), params)

I0000 00:00:1706026710.670022   35344 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.


In [7]:
# Rebuild the environment from the run's saved Hydra config, using the number
# of parallel environments and the seed recorded in that config. The seed is
# passed both directly (`seed`) and as a seeded NumPy Generator (`rng`).
# NOTE(review): `infos` is unpacked here but not used in the cells shown.
env, infos = hydra.utils.instantiate(
    run.config['env'],
    num_envs=run.config['num_envs'],
    seed=run.config['seed'],
    rng=default_rng(run.config['seed']),
)

In [8]:
# Rebuild the algorithm from the run's saved Hydra config, bound to the
# environment created above.
algorithm = hydra.utils.instantiate(run.config['algorithm'], env=env)

In [9]:
# Patch: a hand-built network state, passed to the sampler in place of a state
# restored from the run. It reuses the `DBVParameters` container with a
# `policy` and a `flow` entry.
# NOTE(review): presumably the checkpoint only stores parameters, not network
# state. Here the policy's single state entry — a float32 `normalization`
# scalar under the top-level '~' module — is set to 1.0, and the flow network
# is given an empty state. Confirm these match the state at the end of training.
net_state = DBVParameters(
    policy={'~': {'normalization': jnp.array(1., dtype=jnp.float32)}},
    flow={}
)

In [10]:
# Number of samples drawn from the trained policy for evaluation.
NUM_SAMPLES = 1000

# Derive the sampling key from the run's configured seed so that re-running
# this cell draws the same samples. `verbose=True` shows a progress bar.
key = jax.random.PRNGKey(run.config['seed'])
samples, returns = get_samples_from_env(
    env,
    algorithm,
    params,
    net_state,
    key,
    num_samples=NUM_SAMPLES,
    verbose=True,
)

  0%|          | 0/1000 [00:00<?, ?it/s]