In [None]:
import os

import cv2
import jax
import mediapy
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tqdm

# Turn off the Hugging Face `tokenizers` internal thread pool (silences its fork warning).
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Optional: enable JAX's persistent compilation cache.
#   from octo.utils.jax_utils import initialize_compilation_cache
#   initialize_compilation_cache()

# TensorFlow is only used for data loading here, so hide every GPU from it
# and leave GPU memory to JAX.
tf.config.set_visible_devices([], "GPU")


## 1. 不训练测试：预训练模型 (Evaluation without fine-tuning: pretrained model)

In [None]:
import os

# Point Hugging Face Hub access at the hf-mirror.com endpoint.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from huggingface_hub import hf_hub_download
from octo.model.octo_model import OctoModel

# Baseline: the pretrained Octo checkpoint from the Hub (no fine-tuning).
# The fine-tuned alternative lives locally at "/root/autodl-tmp/model/".
PRETRAINED_CHECKPOINT = "hf://rail-berkeley/octo-small-1.5"
model = OctoModel.load_pretrained(PRETRAINED_CHECKPOINT)
print('Successfully!')

In [None]:
# Load the first episode of the local ALOHA sim RLDS dataset and resize its top-camera frames.
# Other dataset locations tried earlier:
#   gs://gresearch/robotics/bridge/0.1.0/
#   /root/octo/examples/bridge/0.1.0/
#   /root/autodl-tmp/tensorflow_datasets/duck_killer/1.0.0/
DATASET_DIR = '/root/autodl-tmp/aloha_sim_dataset/aloha_sim_cube_scripted_dataset/1.0.0/'
IMAGE_SIZE = (256, 256)  # cv2.resize target (width, height); 256x256 is the default third-person cam resolution

# create the dataset builder from the local directory and load the train split
builder = tfds.builder_from_directory(builder_dir=DATASET_DIR)
ds = builder.as_dataset(split='train')
print(ds)

# sample the first episode
episode = next(iter(ds))
print(episode)
steps = list(episode['steps'])
if not steps:
    # Fail with a clear message instead of an IndexError on images[-1] / steps[0] below.
    raise ValueError(f'First episode in {DATASET_DIR} has no steps; cannot extract a goal image or instruction.')
images = [cv2.resize(np.array(step['observation']['top']), IMAGE_SIZE) for step in steps]
print(len(images))

# The last frame is the goal image; the language instruction is read from the first step.
goal_image = images[-1]
language_instruction = steps[0]['language_instruction'].numpy().decode()

# visualize episode
print(f'Instruction: {language_instruction}')
mediapy.show_video(images, fps=10)

In [None]:
WINDOW_SIZE = 2  # number of consecutive frames fed to the model per prediction

# Choose ONE conditioning mode. Building both and overwriting the first would waste a model call.
# False -> condition on the language instruction; True -> condition on the episode's final frame.
USE_GOAL_CONDITIONING = False

# create `task` dict
if USE_GOAL_CONDITIONING:
    task = model.create_tasks(goals={"image_primary": goal_image[None]})  # add batch dim
else:
    task = model.create_tasks(texts=[language_instruction])

In [None]:
# Run inference over a sliding window of WINDOW_SIZE frames and collect predicted vs. ground-truth actions.
# Only the top-camera frames (as `image_primary`) are fed to the model.
#
# sample_actions returns *normalized* actions, so they are unnormalized with dataset statistics.
# NOTE(review): these are the pretrained checkpoint's Bridge statistics. The predictions are therefore
# in Bridge's action space, which may not match this dataset's action layout. TODO confirm.
action_statistics = model.dataset_statistics['bridge_dataset']["action"]  # loop-invariant, looked up once
sample_rng = jax.random.PRNGKey(0)  # same fixed key for every window, i.e. deterministic sampling

pred_actions, true_actions = [], []
for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):
    input_images = np.stack(images[step:step + WINDOW_SIZE])[None]  # (1, WINDOW_SIZE, *frame_shape)
    observation = {
        'image_primary': input_images,
        'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool),
    }

    actions = model.sample_actions(
        observation,
        task,
        unnormalization_statistics=action_statistics,
        rng=sample_rng,
    )
    pred_actions.append(actions[0])  # remove batch dim

    # Ground truth is the action recorded at the last frame of the window.
    final_window_step = step + WINDOW_SIZE - 1
    true_actions.append(np.asarray(steps[final_window_step]['action']))

In [None]:
import matplotlib.pyplot as plt

ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']
IMAGE_STRIP_STRIDE = 30  # show every 30th frame in the strip above the plots

assert len(pred_actions) > 0, 'no predictions to plot; run the inference cell first'

# build image strip to show above actions: subsampled frames placed side by side along the width axis
img_strip = np.concatenate(images[::IMAGE_STRIP_STRIDE], axis=1)

# Stack without .squeeze(): squeezing would also drop the window axis for a single-window episode
# (or a size-1 action horizon) and break the [:, 0, d] / [:, d] indexing below.
# New names keep the collected lists intact, so this cell can be re-run safely.
pred_actions_arr = np.asarray(pred_actions)  # expected (num_windows, horizon, action_dim)
true_actions_arr = np.asarray(true_actions)  # expected (num_windows, action_dim)

# Compare only the action dimensions present in both arrays. Use the named 7-D labels only
# when the size matches; otherwise label panels by index.
num_dims = min(pred_actions_arr.shape[-1], true_actions_arr.shape[-1])
if num_dims == len(ACTION_DIM_LABELS):
    dim_labels = list(ACTION_DIM_LABELS)
else:
    dim_labels = [f'dim {i}' for i in range(num_dims)]

# set up plt figure: one wide image row on top, one panel per action dimension below
figure_layout = [
    ['image'] * len(dim_labels),
    dim_labels,
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# plot actions; for each window only the first action of the predicted horizon is shown
for action_dim, action_label in enumerate(dim_labels):
    ax = axs[action_label]
    ax.plot(pred_actions_arr[:, 0, action_dim], label='predicted action')
    ax.plot(true_actions_arr[:, action_dim], label='ground truth')
    ax.set_title(action_label)
    ax.set_xlabel('Time in one episode')

axs['image'].imshow(img_strip)
axs['image'].set_xlabel('Time in one episode (subsampled)')
# Attach the legend explicitly to the last action panel instead of pyplot's implicit "current" axes.
axs[dim_labels[-1]].legend()
plt.show()

## 2. 训练过的模型 (Evaluation of the fine-tuned model)

In [None]:
import os

# Point Hugging Face Hub access at the hf-mirror.com endpoint.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

from huggingface_hub import hf_hub_download
from octo.model.octo_model import OctoModel

# Fine-tuned checkpoint saved on local disk.
# The untrained baseline is the Hub model "hf://rail-berkeley/octo-small-1.5".
FINETUNED_CHECKPOINT_DIR = "/root/autodl-tmp/model/"
model = OctoModel.load_pretrained(FINETUNED_CHECKPOINT_DIR)
print('Successfully!')

In [None]:
# Load the first episode of the local ALOHA sim RLDS dataset and resize its top-camera frames.
# Other dataset locations tried earlier:
#   gs://gresearch/robotics/bridge/0.1.0/
#   /root/octo/examples/bridge/0.1.0/
#   /root/autodl-tmp/tensorflow_datasets/duck_killer/1.0.0/
DATASET_DIR = '/root/autodl-tmp/aloha_sim_dataset/aloha_sim_cube_scripted_dataset/1.0.0/'
IMAGE_SIZE = (256, 256)  # cv2.resize target (width, height); 256x256 is the default third-person cam resolution

# create the dataset builder from the local directory and load the train split
builder = tfds.builder_from_directory(builder_dir=DATASET_DIR)
ds = builder.as_dataset(split='train')
print(ds)

# sample the first episode
episode = next(iter(ds))
steps = list(episode['steps'])
if not steps:
    # Fail with a clear message instead of an IndexError on images[-1] / steps[0] below.
    raise ValueError(f'First episode in {DATASET_DIR} has no steps; cannot extract a goal image or instruction.')
images = [cv2.resize(np.array(step['observation']['top']), IMAGE_SIZE) for step in steps]
print(len(images))

# The last frame is the goal image; the language instruction is read from the first step.
goal_image = images[-1]
language_instruction = steps[0]['language_instruction'].numpy().decode()

# visualize episode
print(f'Instruction: {language_instruction}')
mediapy.show_video(images, fps=10)

In [None]:
WINDOW_SIZE = 2  # number of consecutive frames fed to the model per prediction

# Choose ONE conditioning mode. Building both and overwriting the first would waste a model call.
# False -> condition on the language instruction; True -> condition on the episode's final frame.
USE_GOAL_CONDITIONING = False

# create `task` dict
if USE_GOAL_CONDITIONING:
    task = model.create_tasks(goals={"image_primary": goal_image[None]})  # add batch dim
else:
    task = model.create_tasks(texts=[language_instruction])

In [None]:
# Run inference over a sliding window of WINDOW_SIZE frames and collect predicted vs. ground-truth actions.
# Only the top-camera frames (as `image_primary`) are fed to the model.
#
# sample_actions returns *normalized* actions, so they are unnormalized with the statistics
# stored alongside the fine-tuned checkpoint (a flat dict with an "action" entry).
action_statistics = model.dataset_statistics["action"]  # loop-invariant, looked up once
sample_rng = jax.random.PRNGKey(0)  # same fixed key for every window, i.e. deterministic sampling

pred_actions, true_actions = [], []
for step in tqdm.trange(len(images) - (WINDOW_SIZE - 1)):
    input_images = np.stack(images[step:step + WINDOW_SIZE])[None]  # (1, WINDOW_SIZE, *frame_shape)
    observation = {
        'image_primary': input_images,
        'timestep_pad_mask': np.full((1, input_images.shape[1]), True, dtype=bool),
    }

    actions = model.sample_actions(
        observation,
        task,
        unnormalization_statistics=action_statistics,
        rng=sample_rng,
    )
    pred_actions.append(actions[0])  # remove batch dim

    # Ground truth is the action recorded at the last frame of the window.
    final_window_step = step + WINDOW_SIZE - 1
    true_actions.append(np.asarray(steps[final_window_step]['action']))

In [None]:
import matplotlib.pyplot as plt

ACTION_DIM_LABELS = ['x', 'y', 'z', 'yaw', 'pitch', 'roll', 'grasp']
IMAGE_STRIP_STRIDE = 30  # show every 30th frame in the strip above the plots

assert len(pred_actions) > 0, 'no predictions to plot; run the inference cell first'

# build image strip to show above actions: subsampled frames placed side by side along the width axis
img_strip = np.concatenate(images[::IMAGE_STRIP_STRIDE], axis=1)

# Stack without .squeeze(): squeezing would also drop the window axis for a single-window episode
# (or a size-1 action horizon) and break the [:, 0, d] / [:, d] indexing below.
# New names keep the collected lists intact, so this cell can be re-run safely.
pred_actions_arr = np.asarray(pred_actions)  # expected (num_windows, horizon, action_dim)
true_actions_arr = np.asarray(true_actions)  # expected (num_windows, action_dim)

# Compare only the action dimensions present in both arrays. Use the named 7-D labels only
# when the size matches; otherwise label panels by index.
num_dims = min(pred_actions_arr.shape[-1], true_actions_arr.shape[-1])
if num_dims == len(ACTION_DIM_LABELS):
    dim_labels = list(ACTION_DIM_LABELS)
else:
    dim_labels = [f'dim {i}' for i in range(num_dims)]

# set up plt figure: one wide image row on top, one panel per action dimension below
figure_layout = [
    ['image'] * len(dim_labels),
    dim_labels,
]
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplot_mosaic(figure_layout)
fig.set_size_inches([45, 10])

# plot actions; for each window only the first action of the predicted horizon is shown
for action_dim, action_label in enumerate(dim_labels):
    ax = axs[action_label]
    ax.plot(pred_actions_arr[:, 0, action_dim], label='predicted action')
    ax.plot(true_actions_arr[:, action_dim], label='ground truth')
    ax.set_title(action_label)
    ax.set_xlabel('Time in one episode')

axs['image'].imshow(img_strip)
axs['image'].set_xlabel('Time in one episode (subsampled)')
# Attach the legend explicitly to the last action panel instead of pyplot's implicit "current" axes.
axs[dim_labels[-1]].legend()
plt.show()