In [1]:
import dataclasses

import jax

from openpi.models import model as _model
from openpi.policies import droid_policy
from openpi.policies import policy_config as _policy_config
from openpi.shared import download
from openpi.training import config as _config
from openpi.training import data_loader as _data_loader

In [2]:
import os
import subprocess


def load_proxy_env(script="/etc/network_turbo"):
    """Copy proxy settings exported by a shell script into ``os.environ``.

    Sources ``script`` in a bash subprocess, keeps the lines of the resulting
    environment listing that contain ``proxy``, and writes every
    ``NAME=value`` pair found there into this process's environment.

    Args:
        script: Path of the shell script to source.

    Returns:
        A dict of the variables that were applied. Empty when the script is
        missing, bash cannot be started, or no matching line was produced.
    """
    if not os.path.isfile(script):
        # Best effort: machines without the script simply keep their environment.
        return {}

    try:
        # An argv list replaces `shell=True`, which ran `bash -c` inside a second
        # shell. The path travels as $1, so it needs no quoting or escaping.
        completed = subprocess.run(
            ["bash", "-c", 'source "$1" && env | grep proxy', "bash", script],
            capture_output=True,
            text=True,
            check=False,  # grep exits with status 1 when nothing matches
        )
    except OSError:
        # bash is not available; keep the environment untouched.
        return {}

    # Split on the first "=" only, because proxy URLs may contain "=" themselves.
    proxy_vars = dict(
        entry.split("=", 1) for entry in completed.stdout.splitlines() if "=" in entry
    )
    os.environ.update(proxy_vars)
    return proxy_vars


# Bound to a name instead of being the cell's last expression, so the values
# (which may embed credentials) are not rendered into the notebook output.
_proxy_vars = load_proxy_env()

# Policy inference

The following example shows how to create a policy from a checkpoint and run inference on a dummy example.

In [2]:

# Name of the training config and location of its pretrained checkpoint.
POLICY_CONFIG_NAME = "pi0_fast_droid"
POLICY_CHECKPOINT_URI = "s3://openpi-assets/checkpoints/pi0_fast_droid"

config = _config.get_config(POLICY_CONFIG_NAME)
checkpoint_dir = download.maybe_download(POLICY_CHECKPOINT_URI)
print("Checkpoint downloaded to:", checkpoint_dir)

# Build a policy from the config and the checkpoint directory.
policy = _policy_config.create_trained_policy(config, checkpoint_dir)

# Feed the policy a dummy example; it corresponds to observations produced
# by the DROID runtime.
example = droid_policy.make_droid_example()
result = policy.infer(example)

# The policy is no longer needed once the result is in hand; drop it to
# free up memory.
del policy

print("Actions shape:", result["actions"].shape)

Checkpoint downloaded to: /root/.cache/openpi/openpi-assets/checkpoints/pi0_fast_droid


Some kwargs in processor config are unused and will not have any effect: min_token, action_dim, time_horizon, scale, vocab_size. 
Some kwargs in processor config are unused and will not have any effect: min_token, action_dim, time_horizon, scale, vocab_size. 


Actions shape: (10, 8)


# Working with a live model


The following example shows how to create a live model from a checkpoint and compute training loss. First, we are going to demonstrate how to do it with fake data.


In [6]:
# Switch to the ALOHA sim config and fetch its checkpoint.
config = _config.get_config("pi0_aloha_sim")
checkpoint_dir = download.maybe_download("s3://openpi-assets/checkpoints/pi0_aloha_sim")

# Instantiate the model with the parameters restored from the checkpoint.
# The restored params are passed straight through (not bound to a name) so
# that deleting `model` later does not leave another reference behind.
model = config.model.load(
    _model.restore_params(checkpoint_dir / "params")
)

# Fake observations and actions are enough to exercise the model.
obs = config.model.fake_obs()
act = config.model.fake_act()

# Fixed PRNG key for the loss computation.
key = jax.random.key(0)

# Compute the training loss on the fake batch.
loss = model.compute_loss(key, obs, act)
print("Loss shape:", loss.shape)
print(loss)

Loss shape: (1, 50)
[[9.128082  8.807123  8.616392  8.909349  8.866441  8.331093  8.623332
  8.269854  8.772354  8.809507  7.9495163 8.170952  8.852371  8.389737
  8.961611  8.289936  8.176186  7.990792  8.255247  8.0972595 8.1045
  8.130877  8.136462  7.9967995 7.969477  7.990366  8.517185  8.204456
  7.3827057 8.35801   8.100222  7.6314535 7.79031   8.697946  8.020273
  7.918974  8.298997  7.7337446 8.041628  8.003708  7.8390265 8.053684
  7.9783554 8.584342  8.673449  8.223232  8.627295  8.290242  7.978609
  9.183986 ]]


Now, we are going to create a data loader and use a real batch of training data to compute the loss.

In [7]:
# The proxy variables needed for the dataset download were already written to
# os.environ by the proxy setup cell near the top of the notebook. They persist
# for the lifetime of the kernel, so they are not sourced again here. Re-running
# that snippet in this cell would also overwrite the global `result` holding
# the policy inference output.

# Reduce the batch size to reduce memory usage.
config = dataclasses.replace(config, batch_size=2)

# Load a single batch of data. This is the same data that will be used during training.
# NOTE: In order to make this example self-contained, we are skipping the normalization step
# since it requires the normalization statistics to be generated using `compute_norm_stats`.
loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)
obs, act = next(iter(loader))

# Compute the training loss on the real batch. This uses the same PRNG `key`
# as the fake-data example above.
loss = model.compute_loss(key, obs, act)

# Delete the model to free up memory.
del model

print("Loss shape:", loss.shape)
print(loss)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Fetching 106 files:   0%|          | 0/106 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/50 [00:00<?, ?it/s]

Loss shape: (2, 50)
[[0.16305491 0.16936925 0.17258719 0.17509273 0.16549909 0.17259395
  0.15390736 0.15149112 0.15859419 0.15400508 0.14487183 0.14287466
  0.14395398 0.14725338 0.15673023 0.15144402 0.15445954 0.1614206
  0.16689792 0.16941643 0.17953211 0.19078353 0.1857508  0.20530383
  0.21838504 0.22986919 0.24887243 0.27364165 0.27926794 0.30435103
  0.3344978  0.36285368 0.36504626 0.38305753 0.41475707 0.45635572
  0.48581845 0.52128434 0.54498327 0.6158477  0.6252788  0.6834529
  0.68625534 0.7348676  0.79612076 0.79155254 0.81713223 0.8567008
  0.84930444 0.8213333 ]
 [0.10275582 0.10470946 0.1043144  0.11545829 0.11299162 0.11235681
  0.11203268 0.10946272 0.09947442 0.1066063  0.10297152 0.09772962
  0.10122444 0.09798466 0.10198066 0.09790149 0.10567708 0.111296
  0.10884148 0.12303942 0.12316433 0.13894933 0.13322476 0.14326176
  0.1546779  0.17019024 0.17846107 0.1934481  0.2125881  0.23278277
  0.2537644  0.27986088 0.32630974 0.34754717 0.3839236  0.41825342
  0.4655