# 🚀 GAS Demo (ICML 2025)
Paper: https://arxiv.org/abs/2506.07744  
GitHub: https://github.com/qortmdgh4141/GAS  
Project page: https://qortmdgh4141.github.io/projects/GAS  
Talk (10min): https://www.youtube.com/watch?v=6mxRlbn2_6s

This notebook demonstrates how to run pretrained GAS checkpoints and visualize trajectories.

## 1. **Setup**

In [None]:
# Fetch the GAS repository and switch into it; the K_utils / M_utils imports
# used later are presumably resolved from this checkout — TODO confirm.
!git clone -q https://github.com/qortmdgh4141/GAS.git
# NOTE(review): the absolute path assumes the clone landed in /content
# (the usual Colab working directory) — adjust when running elsewhere.
%cd /content/GAS

# Install the pinned third-party dependencies quietly (-q).
!pip -q install ogbench==1.1.5
!pip -q install distrax==0.1.5
!pip -q install ml_collections==0.1.1

In [None]:
import os
import cv2
os.environ['MUJOCO_GL'] = 'egl'

import jax
import random
import mujoco
import ogbench
import imageio
import numpy as np
from collections import defaultdict
from IPython.display import Video, display
from huggingface_hub import snapshot_download

from K_utils.keygraph_utils import KeyGraph
from M_utils.agents import gas, agents_dict
from M_utils.flax_utils import restore_agent

## 2. **Define helper functions**

In [None]:
def supply_rng(f, rng=jax.random.PRNGKey(0)):
    """Wrap `f` so that every call is handed a freshly split PRNG key.

    The wrapper owns the current key. On each call it splits that key, keeps
    one half for the next call and forwards the other half to `f` through the
    `seed` keyword argument; all other arguments are passed through untouched.
    """
    key_state = {"rng": rng}

    def wrapped(*args, **kwargs):
        # Advance the stored key first so a failing `f` still consumes a key.
        carried_rng, call_key = jax.random.split(key_state["rng"])
        key_state["rng"] = carried_rng
        return f(*args, seed=call_key, **kwargs)

    return wrapped

def _render_demo_frame(env, shadow_renderer, top_cam, ego_cam, front_cam, env_name, info):
    """Render one video frame of `env`, with the layout chosen from `env_name`."""
    if "antmaze" in env_name:
        assert top_cam is not None and ego_cam is not None, "Antmaze cameras not initialized."
        shadow_renderer.update_scene(env.unwrapped.data, camera=top_cam)
        top_frame = shadow_renderer.render()
        # Point the ego camera at the position reported in info["xy"].
        ego_cam.lookat[0] = float(info["xy"][0])
        ego_cam.lookat[1] = float(info["xy"][1])
        shadow_renderer.update_scene(env.unwrapped.data, camera=ego_cam)
        ego_frame = shadow_renderer.render()
        # Top-down view on the left, ego view on the right.
        return np.concatenate([top_frame, ego_frame], axis=1)
    if "scene" in env_name:
        assert front_cam is not None, "Scene front camera not initialized."
        shadow_renderer.update_scene(env.unwrapped.data, camera=front_cam)
        return shadow_renderer.render()
    raise RuntimeError(f"Unsupported environment: {env_name}")


def demo_evaluate_with_graph(
    agent, key_graph, env, shadow_renderer, top_cam, ego_cam, front_cam, env_name,
    task_id, seed, eval_on_cpu, eval_subgoal_threshold, eval_final_goal_threshold,
):
    """
    Evaluate GAS for a single episode in the environment and record video frames.

    In OGBench environments, the final goal includes slight random noise per episode.
    Empirically, we found little to no performance difference between:
    (1) Recomputing the shortest path every episode
    (2) Reusing a precomputed path for the same task_id

    The function `demo_evaluate_with_graph()` follows (2) by computing the shortest path once via `precompute_shortest_paths_to_all_tasks()`
    When using (2), we recommend setting `eval_final_goal_threshold >= 2` to allow agents to reach slightly perturbed final goals.
    If final goals vary significantly per episode, or if GAS is extended to online RL, strategy (1) is preferred.

    NOTE(review): the loop below queries `key_graph.get_shortest_path` on every step and keeps
    the previous path whenever that call returns None; whether the call is served from a
    precomputed cache is not visible from this function.

    Args:
        agent: Object exposing `get_phi` and `sample_actions`.
        key_graph: Object exposing `get_shortest_path(task_id=..., source=..., force_closest=...)`.
        env: Environment whose `reset` returns `(observation, info)` and whose `step` returns
            `(observation, reward, terminated, truncated, info)`.
        shadow_renderer: Renderer exposing `update_scene(data, camera=...)` and `render()`.
        top_cam, ego_cam: Cameras used when `env_name` contains "antmaze".
        front_cam: Camera used when `env_name` contains "scene".
        env_name: Environment name; selects the rendering layout.
        task_id: Task passed to `env.reset` and to the graph path queries.
        seed: Seed for `env.reset` and for the action-sampling PRNG key.
        eval_on_cpu: If truthy, the agent is moved to the first CPU device before evaluation.
        eval_subgoal_threshold: Maximum distance between the current `get_phi` embedding and a
            path node for that node to be eligible as the subgoal; the last eligible node on
            the path is used (the first node if none is eligible).
        eval_final_goal_threshold: Once the path has at most this many nodes, the embedded
            final goal is targeted for the rest of the episode.

    Returns:
        Tuple `(frames, infos, steps)`: the rendered frames and the matching `info` dicts,
        recorded on every third step and on the last step, plus the number of steps taken.

    Raises:
        RuntimeError: When a frame is rendered and `env_name` contains neither "antmaze"
            nor "scene".
    """
    eval_agent = jax.device_put(agent, device=jax.devices('cpu')[0]) if eval_on_cpu else agent
    get_phi_fn = eval_agent.get_phi
    actor_fn = supply_rng(eval_agent.sample_actions, rng=jax.random.PRNGKey(seed))

    steps = 0
    infos = []
    frames = []
    observation, info = env.reset(seed=seed, options=dict(task_id=task_id, render_goal=True))
    goal = info.get('goal')
    done = False

    epsilon = 1e-10  # Keeps the direction normalisation finite when the subgoal is reached.
    phi_obs = np.array(get_phi_fn(observation))
    phi_goal = np.array(get_phi_fn(goal))
    final_goal_on = False
    # Initial path; `force_closest=True` is only requested for this first query.
    shortest_path = key_graph.get_shortest_path(task_id=task_id, source=phi_obs, force_closest=True)
    while not done:
        phi_obs = np.array(get_phi_fn(observation))
        if final_goal_on:
            cur_obs_goal = phi_goal
        else:
            # Keep the previous path when no path is returned for the current embedding.
            cached_shortest_path = key_graph.get_shortest_path(task_id=task_id, source=phi_obs)
            if cached_shortest_path is not None:
                shortest_path = cached_shortest_path

            distances = np.linalg.norm(np.array(shortest_path) - phi_obs, axis=1)
            valid_indices = np.where(distances <= eval_subgoal_threshold)[0]
            cur_node_idx = valid_indices[-1] if len(valid_indices) > 0 else 0
            if len(shortest_path) <= eval_final_goal_threshold:
                # Few nodes remain: latch onto the final goal for the rest of the episode.
                final_goal_on = True
                cur_obs_goal = phi_goal
            else:
                cur_obs_goal = shortest_path[cur_node_idx]

        # The policy is conditioned on the unit direction towards the current subgoal.
        skills = (cur_obs_goal - phi_obs) / (np.linalg.norm(cur_obs_goal - phi_obs) + epsilon)
        action = actor_fn(observations=observation, goals=skills, temperature=0.0)
        action = np.clip(np.array(action), -1, 1)

        next_observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        steps += 1
        # Record every third step, and always the last one so infos[-1] is the final info.
        if steps % 3 == 0 or done:
            infos.append(info)
            frames.append(
                _render_demo_frame(env, shadow_renderer, top_cam, ego_cam, front_cam, env_name, info)
            )
        observation = next_observation

    return frames, infos, steps

## 3. Configure environment & agent

In [None]:
# Environments selectable in this demo: four base variants followed by their
# "visual-" counterparts.
ENV_NAME_LIST = [
    "antmaze-giant-navigate-v0",
    "antmaze-giant-stitch-v0",
    "antmaze-large-explore-v0",
    "scene-play-v0",
    "visual-antmaze-giant-navigate-v0",
    "visual-antmaze-giant-stitch-v0",
    "visual-antmaze-large-explore-v0",
    "visual-scene-play-v0",
]
# Change the index to select the desired environment 🌍
ENV_NAME = ENV_NAME_LIST[4]

# Task IDs selectable in this demo.
TASK_ID_LIST = list(range(1, 6))
# Change the index to select the desired task 🎯
TASK_ID = TASK_ID_LIST[0]

In [None]:
# Set random seeds and configuration (usually no need to change)
seed = 0
random.seed(seed)
np.random.seed(seed)
eval_final_goal_threshold = 2
config = gas.get_config()

# "visual-" environments use the impala_small encoder and are evaluated without
# moving the agent to CPU; the others evaluate on CPU with the encoder unused.
if ENV_NAME.startswith("visual-"):
    eval_on_cpu = 0
    config["encoder"] = "impala_small"
else:
    eval_on_cpu = 1
    config["encoder"] = "not_used"

# Per-environment `way_steps`; it is later passed to the evaluation as the
# subgoal threshold.
if ENV_NAME.startswith("scene-play"):
    config["way_steps"] = 48
elif ENV_NAME.startswith("visual-scene-play"):
    config["way_steps"] = 24
else:
    config["way_steps"] = 8

# Set up environment.
env = ogbench.make_env_and_datasets(ENV_NAME, env_only=True)
# Size the offscreen buffer to match the 400x400 renderer created below.
env.unwrapped.model.vis.global_.offwidth  = 400
env.unwrapped.model.vis.global_.offheight = 400
shadow_renderer = mujoco.Renderer(env.unwrapped.model, width=400, height=400)

# Cameras stay None unless the selected environment family needs them.
top_cam = None
ego_cam = None
front_cam = None
if "antmaze" in ENV_NAME:
    top_cam = mujoco.MjvCamera()
    # (lookat x, lookat y, distance, elevation) of the top-down camera per maze size.
    top_cam_params_list = [("giant", (26, 18, 70, -90)), ("large", (18, 12, 50, -90)),]
    for key, (lx, ly, dist, elev) in top_cam_params_list:
        if key in ENV_NAME:
            (top_cam.lookat[0], top_cam.lookat[1], top_cam.distance, top_cam.elevation) = (lx, ly, dist, elev)
            break
    # The ego camera's lookat point is updated during evaluation.
    ego_cam = mujoco.MjvCamera()
    (ego_cam.distance, ego_cam.elevation) = (10, -50)
elif "scene" in ENV_NAME:
    # NOTE(review): "front" is used for every scene variant; "front_pixels" was
    # previously considered for the "visual-" variants — confirm which is intended.
    cam_name = "front"
    cam_id = mujoco.mj_name2id(env.unwrapped.model, mujoco.mjtObj.mjOBJ_CAMERA, cam_name)
    if cam_id < 0:
        # mj_name2id returns -1 when no object with that name exists; fail here
        # with a clear message instead of rendering through an invalid camera id.
        raise RuntimeError(f"Camera '{cam_name}' not found in environment: {ENV_NAME}")
    front_cam = mujoco.MjvCamera()
    mujoco.mjv_defaultCamera(front_cam)
    front_cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
    front_cam.fixedcamid = cam_id
else:
    raise RuntimeError(f"Unsupported environment: {ENV_NAME}")

In [None]:
# Initialize agent.
# Random batch-of-one examples matching the environment's spaces are passed to
# the agent constructor: uint8 for rank-3 observation spaces, float32 otherwise.
obs_space, act_space = env.observation_space, env.action_space
ex_obs = (
    np.random.randint(0, 256, size=(1, *obs_space.shape), dtype=np.uint8)
    if len(obs_space.shape) == 3
    else np.random.randn(1, obs_space.shape[0]).astype(np.float32)
)
ex_act = np.random.randn(1, act_space.shape[0]).astype(np.float32)

agent_class = agents_dict[config['agent_name']]
agent = agent_class.create(seed, ex_obs, ex_act, config)

# Download official GAS checkpoints.
# Only the sub-folder for the selected environment is fetched.
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)
env_folder = ENV_NAME.replace("-v0", "")
snapshot_download(
    repo_id="qortmdgh4141/GAS",
    local_dir=ckpt_dir,
    allow_patterns=[f"{env_folder}/*"],
)

# Restore graph.
# The loader receives the directory and a name derived from the file name
# (text after the last "_" and before the first ".").
key_graph = KeyGraph()
keygraph_path = os.path.join(ckpt_dir, env_folder, "keygraph.pkl")
keygraph_load_path, keygraph_basename = os.path.split(keygraph_path)
keygraph_load_filename = keygraph_basename.rpartition('_')[2].partition('.')[0]
key_graph.load_keygraph(keygraph_load_path, keygraph_load_filename)

# Restore low-level policy.
# The checkpoint file differs between "visual-" and other environments; the
# step count embedded in its name is handed to `restore_agent` as a string.
if ENV_NAME.startswith("visual-"):
    params_file = "params_500000.pkl"
else:
    params_file = "params_1000000.pkl"
policy_path = os.path.join(ckpt_dir, env_folder, params_file)
policy_restore_path, policy_basename = os.path.split(policy_path)
policy_restore_epoch = policy_basename.rpartition('_')[2].partition('.')[0]
agent = restore_agent(agent, policy_restore_path, policy_restore_epoch)

## 4. Run evaluation

In [None]:
# Evaluate GAS (single episode)
# `way_steps` from the config doubles as the subgoal threshold.
frames, infos, steps = demo_evaluate_with_graph(
    agent=agent,
    key_graph=key_graph,
    env=env,
    shadow_renderer=shadow_renderer,
    top_cam=top_cam,
    ego_cam=ego_cam,
    front_cam=front_cam,
    env_name=ENV_NAME,
    task_id=TASK_ID,
    seed=seed,
    eval_on_cpu=eval_on_cpu,
    eval_subgoal_threshold=config['way_steps'],
    eval_final_goal_threshold=eval_final_goal_threshold,
)

## 5. Visualize results

In [None]:
# Visualize evaluation result
print(f"🌍 Environment: {ENV_NAME}")
print(f"🎯 Task ID: {TASK_ID}")

# The last recorded info dict carries the episode's success flag.
episode_succeeded = infos[-1]["success"] == 1.0
outcome_message = (
    f"✅ Episode succeeded in {steps} steps\n"
    if episode_succeeded
    else "❌ Episode failed to reach the final goal\n"
)
print(outcome_message)

# Write the recorded frames to an MP4 and embed it in the notebook output.
video_path = "/tmp/demo.mp4"
imageio.mimsave(video_path, frames, fps=30)
display(Video(video_path, embed=True, height=400))