In [1]:
!pip install -q brax tyro flax optax

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.2/14.2 MB[0m [31m48.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m124.3/124.3 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.4/172.4 kB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m45.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import jax
import jax.numpy as jnp
import numpy as np
import functools
import time
from brax.training.agents.sac.train import train as sac_train
from brax import envs
from brax.io import html

In [3]:
def progress_fn(num_steps, metrics):
    """Print the evaluation reward reported by the SAC trainer at a checkpoint.

    Args:
        num_steps: Environment steps taken so far.
        metrics: Mapping of metric name to value. The ``"eval/episode_reward"``
            entry is printed with two decimals; ``nan`` is shown if it is missing.
    """
    eval_reward = metrics.get("eval/episode_reward", float('nan'))
    line = "[{} steps] Eval reward: {:.2f}".format(num_steps, eval_reward)
    print(line)

In [4]:
def train_fn(environment, progress_fn=None, **overrides):
    """Train a SAC agent on ``environment`` with Brax's SAC trainer.

    The values in ``config`` are this notebook's default hyperparameters. Any of
    them, or any other keyword accepted by ``sac_train``, can be replaced per
    call, e.g. ``train_fn(env, num_timesteps=100_000)`` for a quick run.
    Calling with only ``environment``/``progress_fn`` behaves exactly as before.

    Args:
        environment: Brax environment to train on.
        progress_fn: Optional callback forwarded to ``sac_train``; this notebook
            passes one with signature ``(num_steps, metrics)``.
        **overrides: Keyword arguments that replace or extend the defaults.

    Returns:
        The result of ``sac_train`` unchanged. The calling cell unpacks it as
        ``make_inference_fn, params, _``.
    """
    config = dict(
        num_timesteps=1_000_000,
        num_evals=50,
        episode_length=1000,
        normalize_observations=True,
        reward_scaling=30,
        action_repeat=1,
        discounting=0.997,
        learning_rate=6e-4,
        num_envs=128,
        batch_size=512,
        grad_updates_per_step=64,
        seed=1,
    )
    config.update(overrides)
    return sac_train(environment=environment, progress_fn=progress_fn, **config)


In [5]:
# Build the Hopper environment and train the SAC policy on it.
env = envs.create('hopper')
make_inference_fn, params, _ = train_fn(env, progress_fn)

# Turn the trained parameters into a policy, then JIT-compile the env's
# reset/step functions and the policy for the rollout loop.
inference_fn = make_inference_fn(params)
jit_reset, jit_step, jit_infer = map(jax.jit, (env.reset, env.step, inference_fn))

[0 steps] Eval reward: 13.84
[20480 steps] Eval reward: 197.52
[40960 steps] Eval reward: 201.06
[61440 steps] Eval reward: 26.85
[81920 steps] Eval reward: 498.81
[102400 steps] Eval reward: 484.50
[122880 steps] Eval reward: 493.93
[143360 steps] Eval reward: 494.91
[163840 steps] Eval reward: 474.12
[184320 steps] Eval reward: 511.02
[204800 steps] Eval reward: 571.46
[225280 steps] Eval reward: 271.98
[245760 steps] Eval reward: 575.11
[266240 steps] Eval reward: 512.23
[286720 steps] Eval reward: 530.68
[307200 steps] Eval reward: 483.63
[327680 steps] Eval reward: 538.22
[348160 steps] Eval reward: 603.31
[368640 steps] Eval reward: 500.64
[389120 steps] Eval reward: 569.72
[409600 steps] Eval reward: 530.24
[430080 steps] Eval reward: 651.09
[450560 steps] Eval reward: 622.69
[471040 steps] Eval reward: 713.74
[491520 steps] Eval reward: 527.45
[512000 steps] Eval reward: 721.07
[532480 steps] Eval reward: 531.54
[552960 steps] Eval reward: 619.03
[573440 steps] Eval reward: 656

In [7]:
# Collect a "proposal" dataset from the trained policy. For each rollout, store
# the observation right after reset (the context) and the `horizon` actions the
# policy then takes (the trajectory).
contexts = []
trajectories = []

num_trajectories = 1000
horizon = 50
rng = jax.random.PRNGKey(0)
num_terminated = 0  # rollouts where `state.done` was set within `horizon` steps

for i in range(num_trajectories):
    # Give reset and action sampling independent keys. Splitting the key that
    # was already passed to reset would reuse it for a second random consumer.
    rng, reset_rng, act_rng = jax.random.split(rng, 3)
    state = jit_reset(reset_rng)
    initial_obs = state.obs
    traj = []
    dones = []

    for _ in range(horizon):
        act_rng, step_rng = jax.random.split(act_rng)
        action, _ = jit_infer(state.obs, step_rng)
        traj.append(action)
        state = jit_step(state, action)
        dones.append(state.done)

    contexts.append(np.array(initial_obs))
    trajectories.append(np.array(jnp.stack(traj)))

    # NOTE(review): if the env auto-resets on termination (brax's `envs.create`
    # appears to do so by default — confirm for the installed version), actions
    # after a `done` come from a fresh reset state, not a continuation of
    # `initial_obs`. Count such rollouts so the issue is visible.
    if bool(jnp.any(jnp.stack(dones))):
        num_terminated += 1

    if i % 10 == 0:
        print(f"Collected {i}/{num_trajectories}")

if num_terminated:
    print(f"Warning: {num_terminated}/{num_trajectories} rollouts hit done "
          f"within {horizon} steps")

Collected 0/1000
Collected 10/1000
Collected 20/1000
Collected 30/1000
Collected 40/1000
Collected 50/1000
Collected 60/1000
Collected 70/1000
Collected 80/1000
Collected 90/1000
Collected 100/1000
Collected 110/1000
Collected 120/1000
Collected 130/1000
Collected 140/1000
Collected 150/1000
Collected 160/1000
Collected 170/1000
Collected 180/1000
Collected 190/1000
Collected 200/1000
Collected 210/1000
Collected 220/1000
Collected 230/1000
Collected 240/1000
Collected 250/1000
Collected 260/1000
Collected 270/1000
Collected 280/1000
Collected 290/1000
Collected 300/1000
Collected 310/1000
Collected 320/1000
Collected 330/1000
Collected 340/1000
Collected 350/1000
Collected 360/1000
Collected 370/1000
Collected 380/1000
Collected 390/1000
Collected 400/1000
Collected 410/1000
Collected 420/1000
Collected 430/1000
Collected 440/1000
Collected 450/1000
Collected 460/1000
Collected 470/1000
Collected 480/1000
Collected 490/1000
Collected 500/1000
Collected 510/1000
Collected 520/1000
Coll

In [8]:
# Stack the per-rollout arrays and write them to a single .npz archive.
dataset_path = "proposal_dataset_hopper.npz"
dataset = {
    "contexts": np.stack(contexts),
    "trajectories": np.stack(trajectories),
}
np.savez(dataset_path, **dataset)
print(f"✅ Dataset saved: {dataset_path}")

✅ Dataset saved: proposal_dataset_hopper.npz
