### This notebook is based on https://github.com/simpler-env/SimplerEnv/blob/main/example.ipynb

In [None]:
# Seed passed to env.reset(seed=...) in the cells below.
# NOTE(review): None presumably leaves the episode unseeded, so rollouts are not
# reproducible across runs — set an integer here for repeatable results; confirm
# against the environment's reset semantics.
SEED = None

## Installation


In [None]:
#@title [Important] Please use a GPU runtime.
!nvidia-smi

In [None]:
# @title Make sure vulkan is installed correctly
!vulkaninfo | head -n 5

In [None]:
# @title [Important] Post Installation

# Re-run the site module's path setup so that local pip installs are recognized:
# site.main() adds the standard site-specific directories to sys.path.
import site
site.main()

In [None]:
import numpy as np

## Create a Simulated Environment and Take Random Actions

In [None]:
import simpler_env
from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
import mediapy
import sapien.core as sapien

## Run Inference on Simulated Environments

In [None]:
# @title Setup

import os
import numpy as np
import simpler_env
from simpler_env.utils.env.observation_utils import get_image_from_maniskill2_obs_dict
import mediapy

# Directory under which the RT-1 checkpoint folders listed below are looked up.
ckpt_dir = "./checkpoints"

# Short RT-1 model name -> checkpoint directory name (joined onto ckpt_dir
# when an RT-1 policy is selected).
RT_1_CHECKPOINTS = dict(
    rt_1_x="rt_1_x_tf_trained_for_002272480_step",
    rt_1_400k="rt_1_tf_trained_for_000400120",
    rt_1_58k="rt_1_tf_trained_for_000058240",
    rt_1_1k="rt_1_tf_trained_for_000001120",
)

In [None]:
# @title Select your environment

task_name = "google_robot_pick_customizable"  # @param ["google_robot_pick_customizable", "google_robot_pick_coke_can", "google_robot_move_near", "google_robot_open_drawer", "google_robot_close_drawer", "widowx_spoon_on_towel", "widowx_carrot_on_plate", "widowx_stack_cube", "widowx_put_eggplant_in_basket"]

# When this cell is re-run, close and drop the previously created environment
# before building a new one.
if "env" in locals():
    print("Closing existing env")
    env.close()
    del env
env = simpler_env.make(task_name)

# NOTE(review): the ray-tracing denoiser is switched ON here. The note this cell
# previously carried said the denoiser had been turned off because the Colab
# kernel crashes when it is enabled; set this to False when running on Colab,
# and keep True only for local evaluations.
sapien.render_config.rt_use_denoiser = True

obs, reset_info = env.reset(seed=SEED)
instruction = env.get_language_instruction()
print("Reset info", reset_info)
print("Instruction", instruction)

# The policy preset follows the task name: any task containing "google" uses
# the Google Robot setup, everything else uses the WidowX/Bridge setup.
policy_setup = "google_robot" if "google" in task_name else "widowx_bridge"

In [None]:
# @title Select your model

model_name = "openvla-7b"

# Each policy backend is imported only inside the branch that uses it, so only
# the selected model's module is loaded. Branches are matched by substring, in
# this order.
if "rt_1" in model_name:
    from simpler_env.policies.rt1.rt1_model import RT1Inference

    # RT-1 is loaded from a checkpoint directory under ckpt_dir.
    ckpt_path = os.path.join(ckpt_dir, RT_1_CHECKPOINTS[model_name])
    model = RT1Inference(
        saved_model_path=ckpt_path,
        policy_setup=policy_setup,
    )
elif "octo" in model_name:
    from simpler_env.policies.octo.octo_model import OctoInference

    model = OctoInference(
        model_type=model_name,
        policy_setup=policy_setup,
        init_rng=0,
    )
elif "openvla" in model_name:
    from simpler_env.policies.openvla.openvla_model import OpenVLAInference

    model = OpenVLAInference(
        model_type=model_name,
        policy_setup=policy_setup,
    )
else:
    # Unknown model family: fail loudly with the offending name.
    raise ValueError(model_name)


In [None]:
def run_inference(seed, options, env, model):
    #@title Run inference
    """Roll out ``model`` in ``env`` for one episode and collect the image frames.

    The environment is reset with ``seed`` and ``options``, the policy is reset
    with the episode's language instruction, and the policy is then stepped
    until it predicts termination or the environment reports truncation. The
    instruction, reset info, per-step info and final success flag are printed.

    Args:
        seed: Forwarded to ``env.reset(seed=...)``.
        options: Forwarded to ``env.reset(options=...)``.
        env: Environment providing ``reset``, ``step`` and
            ``get_language_instruction``.
        model: Policy providing ``reset(instruction)`` and ``step(image)``;
            ``step`` returns ``(raw_action, action)`` where ``action`` is
            indexed by "world_vector", "rot_axangle", "gripper" and
            "terminate_episode".

    Returns:
        A list of frames: the image from the reset observation followed by one
        image per environment step.
    """
    obs, reset_info = env.reset(seed=seed, options=options)
    instruction = env.get_language_instruction()
    model.reset(instruction)
    print(instruction)
    print("Reset info", reset_info)

    frame = get_image_from_maniskill2_obs_dict(env, obs)
    frames = [frame]
    success = False
    step_idx = 0
    # At least one step is always taken; the exit test runs after each step.
    while True:
        # The first return value is the raw model output (unused here); the
        # second is the processed action that is sent to the environment.
        _raw_action, action = model.step(frame)
        policy_done = bool(action["terminate_episode"][0] > 0)
        env_action = np.concatenate(
            [action["world_vector"], action["rot_axangle"], action["gripper"]]
        )
        obs, reward, success, truncated, info = env.step(env_action)
        print(step_idx, info)

        # Refresh the image observation fed to the next policy step.
        frame = get_image_from_maniskill2_obs_dict(env, obs)
        frames.append(frame)
        step_idx += 1

        if policy_done or truncated:
            break

    # Looked up from the last step's info but not used further in this function.
    episode_stats = info.get("episode_stats", {})
    print(f"Episode success: {success}")
    return frames

In [None]:
# Roll out one episode with the environment and policy created above;
# `images` receives the list of frames returned by run_inference.
images = run_inference(seed=SEED, options=None, env=env, model=model)

In [None]:
# Play back the collected episode frames as a video at 10 frames per second.
mediapy.show_video(images, fps=10)