# 🔄 Replay Episode to Robot

This notebook replays **ground truth actions** from a recorded episode directly to the robot.

No trained policy is needed - this is useful for:
1. Verifying that recorded demonstrations are valid
2. Testing robot connectivity and action execution
3. Debugging the action space and robot interface

⚠️ **Warning**: This will move the physical robot! Ensure the workspace is clear.

## 1. Configuration

Set paths and replay parameters.

In [None]:
from pathlib import Path

# Candidate dataset locations, probed in order; the first one that exists is used.
possible_dataset_paths = [
    Path("/data/lerobot/[TODO]"),  # JupyterHub absolute
]

DATASET_DIR = None
for candidate in possible_dataset_paths:
    if not candidate.exists():
        print(f"‚ùå Not found: {candidate}")
        continue
    DATASET_DIR = candidate
    print(f"‚úÖ Found dataset at: {candidate}")
    break

# No candidate matched: DATASET_DIR stays None and has to be filled in by hand.
if DATASET_DIR is None:
    print("\n‚ö†Ô∏è Please set DATASET_DIR manually")

# TODO: Set robot IP address
SERVER_ENDPOINT = "<robot_ip_address>:50051"

# Which episode of the dataset to replay (0-indexed).
EPISODE_INDEX = 0

# Rate, in Hz, at which actions are replayed - should match the recording FPS.
REPLAY_FREQUENCY_HZ = 10.0

# Playback speed factor (1.0 = normal speed, 0.5 = half speed, 2.0 = double speed).
SPEED_MULTIPLIER = 1.0

# Echo the effective configuration so it is visible in the cell output.
summary_lines = [
    f"\nDataset: {DATASET_DIR}",
    f"Episode: {EPISODE_INDEX}",
    f"Robot server: {SERVER_ENDPOINT}",
    f"Replay frequency: {REPLAY_FREQUENCY_HZ} Hz",
    f"Speed multiplier: {SPEED_MULTIPLIER}x",
]
print("\n".join(summary_lines))

## 2. Load Dataset and Metadata

Load the dataset using LeRobotDataset and extract metadata for the action translator.

In [None]:
import numpy as np
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset

from example_policies.robot_deploy.deploy_core.policy_loader import load_metadata
from example_policies.robot_deploy.deploy_core.action_translator import ActionTranslator
from example_policies.utils.constants import ACTION, OBSERVATION_STATE
from example_policies.utils.action_order import ActionMode

# Use CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Read the metadata that belongs to the dataset directory chosen above.
metadata = load_metadata(DATASET_DIR)
print("‚úÖ Metadata loaded")

class FakeConfig:
    """Minimal stand-in for the config a policy loader would normally provide.

    Attributes set by ``__init__``:
        metadata: the metadata mapping passed in as ``m``.
        device: the notebook-level ``device`` string.
        input_features: ``{OBSERVATION_STATE: names}`` where ``names`` is a
            NumPy array built from ``m["features"][OBSERVATION_STATE]["names"]``.
        output_features: ``{ACTION: names}`` built the same way from
            ``m["features"][ACTION]["names"]``.
    """

    def __init__(self, m):
        features = m["features"]
        self.metadata = m
        self.device = device
        self.input_features = {
            OBSERVATION_STATE: np.asarray(features[OBSERVATION_STATE]["names"]),
        }
        self.output_features = {
            ACTION: np.asarray(features[ACTION]["names"]),
        }

    def get_tcp_from_state(self, state: np.ndarray) -> np.ndarray:
        """Pick the TCP pose columns of both arms out of ``state``.

        Columns are selected along the second axis of ``state`` in the order
        left position (x, y, z), left quaternion (x, y, z, w), right position,
        right quaternion - 14 columns in total. Each column is located by
        looking its name up in ``self.input_features[OBSERVATION_STATE]``.

        Raises:
            IndexError: if one of the expected names is not among the state
                feature names.
        """
        wanted_names = []
        for side in ("left", "right"):
            for kind, components in (("pos", "xyz"), ("quat", "xyzw")):
                wanted_names.extend(
                    f"tcp_{side}_{kind}_{component}" for component in components
                )

        known_names = self.input_features[OBSERVATION_STATE]
        columns = []
        for name in wanted_names:
            matches = np.where(known_names == name)[0]
            # First occurrence wins; an empty match raises IndexError here.
            columns.append(matches[0])
        return state[:, columns]

# Build the stand-in config and show the first few feature names as a sanity check.
cfg = FakeConfig(metadata)
state_name_preview = cfg.input_features[OBSERVATION_STATE][:5]
action_name_preview = cfg.output_features[ACTION][:5]
print(f"Input features: {state_name_preview}...")
print(f"Output features: {action_name_preview}...")

# Open the dataset, restricted to the single episode that will be replayed.
dataset = LeRobotDataset(
    root=DATASET_DIR,
    repo_id=DATASET_DIR.name,
    episodes=[EPISODE_INDEX],
)
frame_count = len(dataset)
print(f"‚úÖ Dataset loaded: {frame_count} frames")

# One frame per batch, in recorded order. num_workers=0 keeps loading in the
# main process, which the original author chose to avoid memory issues.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    drop_last=True,
    num_workers=0,
)
print("‚úÖ Dataloader created")

# The translator is constructed from the same config object.
action_translator = ActionTranslator(cfg)
print("‚úÖ Action translator created")
print(f"   Action mode: {action_translator.action_mode}")

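## 3. Load and Visualize Actions

Load the ground truth actions from the episode's parquet file and preview them before sending to the robot.

The plot itself is optional, but this cell must still be run: it defines `df` and `actions_array`, which sections 6 and 8 reuse.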

In [None]:
import matplotlib.pyplot as plt
import pandas as pd

def _flatten_columns(frame, columns):
    """Return the given columns of ``frame`` as one ``(rows, dims)`` float32 array.

    Every cell is flattened to 1-D (scalars become length-1), the cells of a
    column are stacked row-wise, and the columns are then joined side by side
    in the order given.
    """
    per_column = [
        np.stack([np.ravel(cell) for cell in frame[column]]) for column in columns
    ]
    return np.hstack(per_column).astype(np.float32)


# Load actions directly from parquet (avoids loading videos via LeRobotDataset).
# NOTE(review): the chunk directory is hard-coded to chunk-000; an episode
# stored under a different chunk directory will not be found - confirm the
# dataset layout if EPISODE_INDEX is large.
parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
df = pd.read_parquet(parquet_path)

# Every column whose name starts with "action", in sorted order.
action_cols = sorted(c for c in df.columns if c.startswith("action"))

# Fail early with a readable message instead of an opaque indexing error below.
if len(df) == 0:
    raise ValueError(f"Episode {EPISODE_INDEX} has no frames: {parquet_path}")
if not action_cols:
    raise ValueError(f"No 'action*' columns found in {parquet_path}")

# One row per frame, one column per action dimension.
actions_array = _flatten_columns(df, action_cols)
print(f"Actions shape: {actions_array.shape}")

# Timestamp of each frame, assuming frames are REPLAY_FREQUENCY_HZ apart.
times = np.arange(len(actions_array)) / REPLAY_FREQUENCY_HZ

# Names for the y-axis labels; dimensions without a name fall back to "dim N".
action_names = cfg.output_features[ACTION]
D = actions_array.shape[1]

# Plot at most the first 10 action dimensions, one subplot each.
# squeeze=False always yields a 2-D axes array, so a single subplot needs no
# special casing; the only column is then taken as a 1-D sequence.
n_plots = min(D, 10)
fig, axes = plt.subplots(
    n_plots, 1, figsize=(14, 2 * n_plots), sharex=True, squeeze=False
)
axes = axes[:, 0]

for d, ax in enumerate(axes):
    ax.plot(times, actions_array[:, d], label="Ground Truth", alpha=0.8)
    ax.set_ylabel(action_names[d] if d < len(action_names) else f"dim {d}", fontsize=8)
    ax.grid(True, alpha=0.3)

axes[0].set_title(f"Episode {EPISODE_INDEX}: Ground Truth Actions")
axes[-1].set_xlabel("Time (s)")
plt.tight_layout()
plt.show()

print(f"Episode duration: {times[-1]:.1f} seconds")
print(f"Total actions: {len(actions_array)}")

## 4. Connect to Robot

Establish connection to the robot service.

In [None]:
import grpc
from example_policies.robot_deploy.robot_io.robot_service import robot_service_pb2_grpc
from example_policies.robot_deploy.robot_io.robot_interface import RobotInterface
from example_policies.robot_deploy.robot_io.robot_client import RobotClient
from example_policies.robot_deploy.utils import print_info

# Open an insecure gRPC channel to the configured endpoint and wrap it in the
# robot service stub.
print(f"Connecting to robot at {SERVER_ENDPOINT}...")
channel = grpc.insecure_channel(SERVER_ENDPOINT)
stub = robot_service_pb2_grpc.RobotServiceStub(channel)

# The interface and the debug printer are both built from the shared config.
robot_interface = RobotInterface(stub, cfg)
dbg_printer = print_info.InfoPrinter(cfg)

# Smoke test: request one observation. Failures are reported, not raised, so
# the cell always finishes.
try:
    obs = robot_interface.get_observation(device)
    status = (
        "‚úÖ Connected to robot!"
        if obs
        else "‚ö†Ô∏è Connected but no observation received"
    )
    print(status)
except Exception as exc:
    print(f"‚ùå Connection failed: {exc}")

## 5. (Optional) Move Robot to Home Position

In [None]:
# Call the robot interface's move_home() before the replay. Any exception is
# reported in the cell output instead of being raised.
try:
    home_reply = robot_interface.move_home()
    print("‚úÖ Robot homing command sent")
    print(f"   Response: {home_reply}")
except Exception as exc:
    print(f"‚ùå Homing failed: {exc}")

## 6. Debug: Compare Dataset vs Live Observations

Compare the format of observations from the dataset vs live robot to identify mismatches.

In [None]:
import av
import matplotlib.pyplot as plt

banner = "=" * 60
print(banner)
print("DATASET OBSERVATION (from parquet + video)")
print(banner)

# First frame of the episode, taken from the DataFrame loaded in section 3.
row = df.iloc[0]

# Gather every "observation.state*" column (sorted) into one flat list of floats.
state_cols = sorted(c for c in df.columns if c.startswith("observation.state"))
state_values = []
for column in state_cols:
    cell = row[column]
    if isinstance(cell, np.ndarray):
        state_values += cell.flatten().tolist()
    else:
        state_values.append(float(cell))

# Add a leading batch axis of size 1 and move the tensor to the chosen device.
flat_state = torch.tensor(state_values, dtype=torch.float32)
dataset_state = flat_state.unsqueeze(0).to(device)

state_min = dataset_state.min()
state_max = dataset_state.max()
print("\nobservation.state:")
print(f"  Shape: {dataset_state.shape}")
print(f"  Dtype: {dataset_state.dtype}")
print(f"  Range: [{state_min:.4f}, {state_max:.4f}]")
print(f"  First 5 values: {dataset_state[0, :5].tolist()}")

# Discover camera streams by scanning the video directory for
# "observation.images*" sub-folders that contain this episode's mp4.
# NOTE(review): the chunk directory is hard-coded to chunk-000 - confirm the
# dataset layout if EPISODE_INDEX is large.
video_dir = DATASET_DIR / "videos" / "chunk-000"
video_keys = []
dataset_images = {}

if video_dir.exists():
    # iterdir() yields entries in arbitrary order; sorting keeps the camera
    # order (and therefore the plot columns further down) stable between runs.
    for subdir in sorted(video_dir.iterdir()):
        if not (subdir.is_dir() and subdir.name.startswith("observation.images")):
            continue
        video_path = subdir / f"episode_{EPISODE_INDEX:06d}.mp4"
        if not video_path.exists():
            continue
        key = subdir.name

        # The context manager closes the container even if decoding raises.
        with av.open(str(video_path)) as container:
            frame = next(container.decode(video=0), None)
            img = None if frame is None else frame.to_ndarray(format="rgb24")

        if img is None:
            # An empty video has nothing to compare; skip it rather than crash.
            print(f"\n{key}: no decodable frames in {video_path}, skipping")
            continue

        video_keys.append(key)
        # HWC uint8 -> CHW float32 scaled to [0, 1], with a leading batch axis.
        img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0
        dataset_images[key] = img_tensor.unsqueeze(0).to(device)

        print(f"\n{key}:")
        print(f"  Shape: {dataset_images[key].shape}")
        print(f"  Dtype: {dataset_images[key].dtype}")
        print(f"  Range: [{dataset_images[key].min():.4f}, {dataset_images[key].max():.4f}]")

print(f"\nFound {len(video_keys)} video streams: {video_keys}")

def _print_tensor_summary(label, tensor):
    """Print the label followed by the tensor's shape, dtype and value range."""
    print(f"\n{label}:")
    print(f"  Shape: {tensor.shape}")
    print(f"  Dtype: {tensor.dtype}")
    print(f"  Range: [{tensor.min():.4f}, {tensor.max():.4f}]")


# Fetch one observation from the robot and describe it in the same format as
# the dataset observation above.
print("\n" + "=" * 60)
print("LIVE OBSERVATION (from robot)")
print("=" * 60)

live_obs = robot_interface.get_observation(device)

if not live_obs:
    print("‚ùå No observation received from robot")
else:
    # State vector, if the robot reported one.
    if "observation.state" in live_obs:
        live_state = live_obs["observation.state"]
        _print_tensor_summary("observation.state", live_state)
        print(f"  First 5 values: {live_state[0, :5].tolist()}")

    # Only the camera streams that were also found in the dataset.
    for key in video_keys:
        if key not in live_obs:
            continue
        _print_tensor_summary(key, live_obs[key])

    print(f"\n\nAll live obs keys: {list(live_obs.keys())}")

# Side-by-side image check: dataset frames on the top row, live frames below,
# one column per camera stream.
print("\n" + "=" * 60)
print("VISUAL COMPARISON")
print("=" * 60)

if live_obs and video_keys:
    n_images = len(video_keys)
    # squeeze=False keeps axes 2-D (2 x n_images) even for a single camera.
    fig, axes = plt.subplots(
        2, n_images, figsize=(5 * n_images, 10), squeeze=False
    )

    # (row label, mapping that holds the image batch) for the two rows.
    image_sources = (("Dataset", dataset_images), ("Live", live_obs))

    for col, key in enumerate(video_keys):
        short_name = key.split('.')[-1]
        for row_idx, (label, images) in enumerate(image_sources):
            if key not in images:
                continue
            # First batch element, channels moved last for imshow, clipped to [0, 1].
            chw = images[key][0].cpu()
            hwc = np.clip(chw.permute(1, 2, 0).numpy(), 0, 1)
            panel = axes[row_idx, col]
            panel.imshow(hwc)
            panel.set_title(f"{label}: {short_name}")
            panel.axis('off')

    plt.tight_layout()
    plt.show()

## 7. Move Robot to Start Position

Move the robot to the starting position of the episode (first frame's TCP pose).

In [None]:
# Take the first frame of the episode; its state provides the start pose.
iterator = iter(dataloader)
first_batch = next(iterator)

# Only TCP-based action modes have a TCP pose to move to.
tcp_modes = (ActionMode.DELTA_TCP, ActionMode.ABS_TCP, ActionMode.TCP)

if action_translator.action_mode in tcp_modes:
    state = first_batch["observation.state"]
    # Pull the TCP pose columns out of the first (only) batch element.
    tcp_state = cfg.get_tcp_from_state(state[0].cpu().numpy().reshape(1, -1))

    # The robot expects the action to include gripper state as the last two elements
    DEFAULT_GRIPPER_STATE = [0, 0]  # [left_gripper, right_gripper]
    start_action = np.concatenate([tcp_state.flatten(), DEFAULT_GRIPPER_STATE]).astype(np.float32)
    start_action = start_action[None, :]  # Add batch axis

    print("Start TCP position (first 7 = left arm pos+quat):")
    print(f"  {start_action[0, :7]}")

    confirm = input("\nPress Enter to move robot to start position (or 'skip' to skip): ")
    # Strip surrounding whitespace so that e.g. "skip " is still honoured as a
    # skip instead of moving the physical robot.
    if confirm.strip().lower() != 'skip':
        print("Moving robot to start position...")
        # The start pose is an absolute TCP target, so it is always sent as
        # ABS_TCP regardless of the dataset's own action mode.
        # NOTE(review): whether send_action blocks until the motion has
        # finished is not visible here - confirm before relying on the
        # "moved" message below.
        robot_interface.send_action(torch.from_numpy(start_action), ActionMode.ABS_TCP)
        print("‚úÖ Robot moved to start position")
    else:
        print("Skipped moving to start position")
else:
    print(f"Action mode {action_translator.action_mode} - skipping move to start")

## 8. Replay Episode to Robot

⚠️ **This will move the robot!** Make sure the workspace is clear.

This sends the ground truth actions from the recorded episode to the robot.

In [None]:
import time

# Controller mode - CART_WAYPOINT is most stable
CONTROLLER = RobotClient.CART_WAYPOINT

# A non-positive rate would give a non-positive period, i.e. no sleep between
# steps and an unthrottled stream of actions to the robot. Refuse it up front.
if REPLAY_FREQUENCY_HZ <= 0 or SPEED_MULTIPLIER <= 0:
    raise ValueError(
        "REPLAY_FREQUENCY_HZ and SPEED_MULTIPLIER must both be positive, got "
        f"{REPLAY_FREQUENCY_HZ} and {SPEED_MULTIPLIER}"
    )

print("=" * 60)
print("‚ö†Ô∏è  ROBOT REPLAY FROM DATASET")
print("=" * 60)
print(f"Episode: {EPISODE_INDEX}")
print(f"Frames: {len(actions_array)}")
print(f"Frequency: {REPLAY_FREQUENCY_HZ} Hz")
print(f"Speed: {SPEED_MULTIPLIER}x")
print(f"Duration: {len(actions_array) / REPLAY_FREQUENCY_HZ / SPEED_MULTIPLIER:.1f} seconds")
print(f"Action mode: {action_translator.action_mode}")
print(f"Controller: {CONTROLLER}")
print("=" * 60)

# Require an explicit "yes"; anything else cancels.
confirm = input("\nType 'yes' to start replay: ")
if confirm.strip().lower() != 'yes':
    print("Replay cancelled.")
else:
    print("\nüöÄ Starting replay...")

    # Target time per step, shortened or lengthened by the speed multiplier.
    period = 1.0 / (REPLAY_FREQUENCY_HZ * SPEED_MULTIPLIER)

    # Steps without an observation are skipped, so count what is actually sent.
    sent_count = 0

    try:
        for step, action in enumerate(actions_array):
            # Monotonic clock: interval timing must not be affected by
            # wall-clock adjustments.
            start_time = time.monotonic()

            # Get current observation from robot (for action translation)
            observation = robot_interface.get_observation(device)

            if observation:
                # Recorded action as a (1, D) float32 tensor.
                # NOTE(review): this tensor stays on the CPU while the
                # observation was requested for `device` - confirm the
                # translator handles that when device is "cuda".
                action_tensor = torch.tensor(action, dtype=torch.float32).unsqueeze(0)

                # Translate action using the same translator as deployment
                translated_action = action_translator.translate(action_tensor, observation)

                # Print debug info every 10 steps
                if step % 10 == 0:
                    dbg_printer.print(step, observation, translated_action, raw_action=False)

                # Send to robot
                robot_interface.send_action(
                    translated_action,
                    action_translator.action_mode,
                    CONTROLLER,
                )
                sent_count += 1
            else:
                print(f"  ‚ö†Ô∏è No observation at step {step}")

            # Sleep for whatever is left of this step's period (never negative).
            elapsed = time.monotonic() - start_time
            time.sleep(max(0.0, period - elapsed))

        print(f"\n‚úÖ Replay complete! Sent {sent_count} of {len(actions_array)} actions.")

    except KeyboardInterrupt:
        print("\n‚ö†Ô∏è Replay interrupted by user.")
        print(f"   Sent {sent_count} of {len(actions_array)} actions before stopping.")
    except Exception as e:
        # Report in the cell output, then re-raise so the traceback is kept.
        print(f"\n‚ùå Error during replay: {e}")
        raise

## 9. Cleanup

In [None]:
# Close the gRPC channel that was opened in the "Connect to Robot" section.
# `robot_interface` was built on a stub of this channel, so re-run the
# connection cell before using it again.
channel.close()
print("Connection closed.")