In [1]:
from pathlib import Path
from typing import Any, Literal, NewType, TypedDict

import gym_pusht  # noqa: F401
import gymnasium as gym

# for performing runtime typechecking in an IPython environment.
import jaxtyping
import numpy as np
import rerun as rr
import torch
from beartype.door import die_if_unbearable
from einops import rearrange
from jaxtyping import Float32, Float64, UInt8
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

# Load jaxtyping's IPython extension and select beartype as its runtime type checker.
# NOTE(review): per the jaxtyping docs this affects cells executed after this one —
# confirm against the installed jaxtyping version.
%load_ext jaxtyping
%jaxtyping.typechecker beartype.beartype

# Load the pretrained model

In [2]:
# Create the directory used for this evaluation's outputs
output_directory = Path("outputs/eval/example_pusht_diffusion")
output_directory.mkdir(parents=True, exist_ok=True)

# Select your device: prefer the GPU, but fall back to the CPU so the notebook
# still runs on machines without CUDA instead of failing at load time.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Provide the [hugging face repo id](https://huggingface.co/lerobot/diffusion_pusht):
pretrained_policy_path = "lerobot/diffusion_pusht"
# OR a path to a local outputs/train folder.
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")

# Load the pretrained weights directly onto the selected device
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)

### Initialize the eval environment
The environment provides an image of the scene and the state/position of the agent.
An episode stops after 300 interactions/steps.

In [3]:
# Build the PushT environment: observations contain both the rendered pixels and
# the agent position, and an episode is capped at 300 steps.
env = gym.make("gym_pusht/PushT-v0", obs_type="pixels_agent_pos", max_episode_steps=300)

# Show what the environment emits next to what the policy expects, so that any
# shape/dtype mismatch is visible before running the rollout.
print("Inputs")
print(env.observation_space)
print(policy.config.input_features)
print()

print("Individual Inputs Seperated out")
for env_key, policy_key in (
    ("agent_pos", "observation.state"),
    ("pixels", "observation.image"),
):
    print(env.observation_space[env_key])
    print(policy.config.input_features[policy_key])
print()

print("Outputs")
print(env.action_space)
print(policy.config.output_features)

Inputs
Dict('agent_pos': Box(0.0, 512.0, (2,), float64), 'pixels': Box(0, 255, (96, 96, 3), uint8))
{'observation.image': PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 96, 96)), 'observation.state': PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,))}

Individual Inputs Seperated out
Box(0.0, 512.0, (2,), float64)
PolicyFeature(type=<FeatureType.STATE: 'STATE'>, shape=(2,))
Box(0, 255, (96, 96, 3), uint8)
PolicyFeature(type=<FeatureType.VISUAL: 'VISUAL'>, shape=(3, 96, 96))

Outputs
Box(0.0, 512.0, (2,), float32)
{'action': PolicyFeature(type=<FeatureType.ACTION: 'ACTION'>, shape=(2,))}


In [5]:
# Roll out the pretrained policy for one episode, validating every intermediate
# value with beartype/jaxtyping and streaming frames and rewards to rerun.

# Reset the policy and environments to prepare for rollout
policy.reset()
numpy_observation, info = env.reset(seed=42)
die_if_unbearable(numpy_observation, dict[str, np.ndarray])

rr.init("evaluate_policy")

# Collect the reward of every step of the episode.
rewards = []

# Log the frame of the initial state at step 0; the frame produced by the k-th
# environment step is logged at step k, so the initial frame is never overwritten.
rr.set_time_sequence("step", 0)
rr.log("frame", rr.Image(env.render()).compress(jpeg_quality=70))
rr.notebook_show(width=1200)

step = 0
done = False
while not done:
    # Prepare observation for the policy running in Pytorch
    agent_pos: Float64[torch.Tensor, "2"] = torch.from_numpy(numpy_observation["agent_pos"])
    pixels: UInt8[torch.Tensor, "96 96 3"] = torch.from_numpy(numpy_observation["pixels"])
    die_if_unbearable(agent_pos, Float64[torch.Tensor, "2"])
    die_if_unbearable(pixels, UInt8[torch.Tensor, "96 96 3"])

    # Convert to float32, with the image going from channel-last (h, w, c) in
    # [0, 255] to channel-first (c, h, w) in [0, 1].
    state_f32: Float32[torch.Tensor, "2"] = agent_pos.to(torch.float32)
    die_if_unbearable(state_f32, Float32[torch.Tensor, "2"])

    image_chw: Float32[torch.Tensor, "3 96 96"] = rearrange(
        pixels.to(torch.float32) / 255, "h w c -> c h w"
    )
    die_if_unbearable(image_chw, Float32[torch.Tensor, "3 96 96"])

    # Send data tensors to the policy's device and add the extra (size 1) batch
    # dimension required to forward the policy.
    state: Float32[torch.Tensor, "1 2"] = rearrange(
        state_f32.to(device, non_blocking=True), "d -> 1 d"
    )
    image: Float32[torch.Tensor, "1 3 96 96"] = rearrange(
        image_chw.to(device, non_blocking=True), "c h w -> 1 c h w"
    )
    die_if_unbearable(state, Float32[torch.Tensor, "1 2"])
    die_if_unbearable(image, Float32[torch.Tensor, "1 3 96 96"])

    # Create the policy input dictionary
    observation: dict[str, torch.Tensor] = {
        "observation.state": state,
        "observation.image": image,
    }

    # Predict the next action with respect to the current observation
    with torch.inference_mode():
        action: Float32[torch.Tensor, "1 2"] = policy.select_action(observation)
        die_if_unbearable(action, Float32[torch.Tensor, "1 2"])

    # Prepare the action for the environment (drop the batch dim, move to host)
    numpy_action: Float32[np.ndarray, "2"] = rearrange(action, "1 d -> d").numpy(force=True)
    die_if_unbearable(numpy_action, Float32[np.ndarray, "2"])

    # Step through the environment and receive a new observation
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    step += 1

    die_if_unbearable(numpy_observation, dict[str, np.ndarray])
    die_if_unbearable(reward, float)
    die_if_unbearable(terminated, bool)
    die_if_unbearable(truncated, bool)
    die_if_unbearable(info, dict[str, Any])

    # Keep track of all the rewards, and log the result of this step at its own
    # timeline index (step k -> time k).
    rewards.append(reward)
    rr.set_time_sequence("step", step)
    rr.log("frame", rr.Image(env.render()).compress(jpeg_quality=70))
    rr.log("reward", rr.Scalar(reward))

    # The rollout is considered done when the success state is reached (i.e. terminated is True),
    # or the maximum number of iterations is reached (i.e. truncated is True)
    done = terminated or truncated

if terminated:
    print("Success!")
else:
    print("Failure!")

Viewer()

Success!
