# 🤖 Deploy Policy from Dataset Observations

This notebook runs the trained policy using **observations from a training episode** and sends the predicted actions to the robot.

This is useful for:
1. Testing if the policy produces reasonable actions from known observations
2. Debugging deployment issues without needing live sensor data
3. Replaying a demonstration with learned actions

⚠️ **Warning**: This will move the physical robot! Ensure the workspace is clear.

## 1. Configuration

Set paths and deployment parameters.

In [None]:
from pathlib import Path

# TODO: Set these paths
CHECKPOINT_DIR = Path("outputs/train/006200/pretrained_model")

# Set to a path (str or Path) to use a specific dataset and skip the search below.
DATASET_DIR_OVERRIDE = None

# Dataset path - otherwise try these locations in order; the first existing one wins
possible_dataset_paths = [
    Path("/data/single_stack_demo"),  # JupyterHub absolute
    Path("../../data/single_stack_demo"),  # JupyterHub relative
    Path("../data/single_stack_demo"),  # Local
    Path("../data/lerobot_output"),  # Local alternative
]

DATASET_DIR = Path(DATASET_DIR_OVERRIDE) if DATASET_DIR_OVERRIDE is not None else None
if DATASET_DIR is not None:
    print(f"✅ Using DATASET_DIR_OVERRIDE: {DATASET_DIR}")
else:
    for p in possible_dataset_paths:
        if p.exists():
            DATASET_DIR = p
            print(f"✅ Found dataset at: {p}")
            break
        print(f"❌ Not found: {p}")

# Every later cell builds paths from DATASET_DIR. Failing here gives a clear message
# instead of an obscure `NoneType / str` TypeError several cells later.
if DATASET_DIR is None or not DATASET_DIR.exists():
    searched = [DATASET_DIR] if DATASET_DIR is not None else possible_dataset_paths
    raise FileNotFoundError(
        "Dataset not found (checked: "
        + ", ".join(str(p) for p in searched)
        + "). Please set DATASET_DIR_OVERRIDE to the dataset location."
    )

# Robot connection
SERVER_ENDPOINT = "<robot_ip_address>:50051"  # TODO: Set robot IP
if "<" in SERVER_ENDPOINT:
    # Non-fatal: the dry-run sections work without a robot.
    print("⚠️ SERVER_ENDPOINT still contains a placeholder; set the robot IP before section 6.")

# Episode to replay
EPISODE_INDEX = 0

# Inference frequency (Hz) - controls how fast actions are sent.
# Section 8 uses 1 / INFERENCE_FREQUENCY_HZ as the control period, so it must be positive.
INFERENCE_FREQUENCY_HZ = 10.0
assert INFERENCE_FREQUENCY_HZ > 0, "INFERENCE_FREQUENCY_HZ must be positive"

print(f"\nCheckpoint: {CHECKPOINT_DIR}")
print(f"Dataset: {DATASET_DIR}")
print(f"Episode: {EPISODE_INDEX}")
print(f"Robot server: {SERVER_ENDPOINT}")
print(f"Frequency: {INFERENCE_FREQUENCY_HZ} Hz")

## 2. Load Policy and Dataset

In [None]:
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from example_policies.robot_deploy.deploy_core.policy_loader import load_policy

# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device: " + device)

# Restore weights + config from the checkpoint, move the policy to the chosen
# device and switch it to inference mode (disables dropout etc.).
policy, cfg = load_policy(CHECKPOINT_DIR)
cfg.device = device
policy.to(device)
policy.eval()
print("‚úÖ Policy loaded")

# A local dataset is addressed by its path, used both as repo_id and as root.
dataset = LeRobotDataset(repo_id=str(DATASET_DIR), root=DATASET_DIR)
print(f"‚úÖ Dataset loaded: {len(dataset)} frames")

# List the episode indices recorded in the dataset metadata.
all_episodes = sorted(dataset.meta.episodes)
print(f"Available episodes: {all_episodes}")

## 3. Prepare Episode Data

Get the frame indices for the selected episode. This step only reads metadata; no frame data is loaded yet, which keeps memory usage low.

In [None]:
import json
import pandas as pd

def to_device_batch(batch: dict, device: torch.device, non_blocking: bool = True):
    """Return a new dict with every tensor value of ``batch`` moved to ``device``.

    Non-tensor values are copied over unchanged (same objects); key order is preserved.

    Args:
        batch: Mapping of names to tensors or arbitrary values.
        device: Target device (a ``torch.device`` or a device string such as ``"cuda"``).
        non_blocking: Forwarded to ``Tensor.to``.
    """
    return {
        name: (value.to(device, non_blocking=non_blocking) if torch.is_tensor(value) else value)
        for name, value in batch.items()
    }

# Load per-episode metadata from meta/episodes.jsonl (one JSON object per line).
# Only small metadata is read here; no frame data is loaded.
episodes_json = DATASET_DIR / "meta" / "episodes.jsonl"
with open(episodes_json, "r", encoding="utf-8") as f:
    # Skip blank lines (e.g. a trailing newline) instead of failing inside json.loads.
    episodes = [json.loads(line) for line in f if line.strip()]

# A negative or too-large index would otherwise silently pick the wrong episode
# (Python negative indexing) or fail with a bare IndexError.
if not 0 <= EPISODE_INDEX < len(episodes):
    raise IndexError(
        f"EPISODE_INDEX={EPISODE_INDEX} is out of range: {episodes_json} lists {len(episodes)} episodes"
    )

# Episodes are looked up by their position in the file. When a record carries its own
# "episode_index", make sure the two agree so we don't replay the wrong episode.
episode_info = episodes[EPISODE_INDEX]
recorded_index = episode_info.get("episode_index", EPISODE_INDEX)
assert recorded_index == EPISODE_INDEX, (
    f"episodes.jsonl entry #{EPISODE_INDEX} describes episode {recorded_index}; "
    "the file is not ordered by episode index"
)
episode_length = episode_info["length"]

# Dataset-wide index of this episode's first frame = total length of all preceding episodes
episode_start_idx = sum(ep["length"] for ep in episodes[:EPISODE_INDEX])

# List of dataset indices for this episode
episode_indices = list(range(episode_start_idx, episode_start_idx + episode_length))

# Verify the parquet file exists.
# NOTE(review): this path (like the rest of the notebook) assumes every episode is stored
# under chunk-000 — confirm for datasets with many episodes.
parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
assert parquet_path.exists(), f"Parquet file not found: {parquet_path}"

# Verify video files exist (missing cameras are reported, not fatal, at this stage)
video_keys = ["observation.images.rgb_static", "observation.images.rgb_left", "observation.images.rgb_right"]
for key in video_keys:
    video_path = DATASET_DIR / f"videos/chunk-000/{key}/episode_{EPISODE_INDEX:06d}.mp4"
    if video_path.exists():
        print(f"✅ Video: {key}")
    else:
        print(f"⚠️ Video not found: {key}")

print(f"\n✅ Episode {EPISODE_INDEX}: {episode_length} frames")
print(f"   Dataset indices: {episode_start_idx} to {episode_start_idx + episode_length - 1}")

## 4. Preview Actions (Dry Run) - Optional

Run inference on a few frames to verify the policy works. Skip this section if you want to go straight to deployment.

In [None]:
import cv2
import numpy as np
import av
from example_policies.robot_deploy.deploy_core.action_translator import ActionTranslator

# Only test on the first N frames to save memory
NUM_TEST_FRAMES = 5


def _flatten_columns(row: pd.Series, columns: list) -> np.ndarray:
    """Concatenate the given cells of one parquet row into a flat float32 vector.

    Handles both scalar cells (one column per dimension) and array-valued cells
    (one column holding a whole vector), in the order given by ``columns``.
    """
    if not columns:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate([np.asarray(row[c], dtype=np.float32).ravel() for c in columns])


# Prepare action translator
action_translator = ActionTranslator(cfg)

print("Running quick test inference...")
print(f"Action mode: {action_translator.action_mode}")
print(f"Testing on first {NUM_TEST_FRAMES} frames...")

# Load parquet data
parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
df = pd.read_parquet(parquet_path)

# Column lists are identical for every row, so compute them once outside the loop.
state_cols = sorted(c for c in df.columns if c.startswith("observation.state"))
action_cols = sorted(c for c in df.columns if c.startswith("action"))

# Video files, decoded with PyAV (supports AV1)
video_keys = ["observation.images.rgb_static", "observation.images.rgb_left", "observation.images.rgb_right"]
video_paths = {
    key: DATASET_DIR / f"videos/chunk-000/{key}/episode_{EPISODE_INDEX:06d}.mp4"
    for key in video_keys
}

video_containers = {}
video_streams = {}
try:
    for key, path in video_paths.items():
        if path.exists():
            container = av.open(str(path))
            video_containers[key] = container
            video_streams[key] = container.decode(video=0)
            print(f"✅ Opened: {key}")

    # Reset policy (clears any internal action/observation queues)
    policy.reset()

    with torch.inference_mode():
        for i in range(min(NUM_TEST_FRAMES, len(df))):
            row = df.iloc[i]

            # State from parquet (scalar or array-valued columns)
            state = torch.as_tensor(_flatten_columns(row, state_cols)).unsqueeze(0).to(device)
            obs = {"observation.state": state}

            # Next frame of each video: RGB uint8 (H, W, C) -> float (1, C, H, W) in [0, 1]
            for key, stream in video_streams.items():
                frame = next(stream, None)
                if frame is None:
                    print(f"  ⚠️ {key}: ran out of frames")
                    continue
                img = frame.to_ndarray(format="rgb24")
                img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0
                obs[key] = img_tensor.unsqueeze(0).to(device)

            # Policy prediction
            raw_action = policy.select_action(obs)

            # Ground-truth action; flattened like the state because the action may be
            # stored as a single array-valued column (float() on it would fail).
            gt_action = _flatten_columns(row, action_cols)

            print(f"\nFrame {i}:")
            print(f"  GT action[0:3]:   {gt_action[:3].tolist()}")
            print(f"  Pred action[0:3]: {raw_action[0, :3].tolist()}")
finally:
    # Always release the decoders, even if inference raised.
    for container in video_containers.values():
        container.close()

print("\n✅ Test complete! Policy is working.")

## 5. (Optional) Visualize Ground Truth Actions

Preview what the ground truth actions look like for this episode.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Reload the episode's parquet so this cell does not depend on section 4 having run.
parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
df = pd.read_parquet(parquet_path)
if df.empty:
    raise ValueError(f"Episode {EPISODE_INDEX} has no frames in {parquet_path}")

# Action columns may be scalar (one column per dimension) or array-valued (one column
# holding the whole action vector). Flatten both into a (T, D) float array and build
# exactly one label per dimension.
action_cols = sorted(c for c in df.columns if c.startswith("action"))
action_blocks = []
action_names = []
for c in action_cols:
    block = np.stack([np.asarray(v, dtype=np.float32).ravel() for v in df[c]])  # (T, k)
    action_blocks.append(block)
    short_name = c.replace("action.", "")
    if block.shape[1] == 1:
        action_names.append(short_name)
    else:
        action_names.extend(f"{short_name}[{j}]" for j in range(block.shape[1]))
if not action_blocks:
    raise ValueError(f"No action columns found in {parquet_path}")
gt_stack = np.hstack(action_blocks)

# Time axis: use the recorded per-frame timestamps when the parquet has them,
# otherwise fall back to a nominal frame rate.
FALLBACK_FPS = 10.0
if "timestamp" in df.columns:
    times = df["timestamp"].to_numpy(dtype=np.float64)
    times = times - times[0]
else:
    times = np.arange(len(gt_stack)) / FALLBACK_FPS

D = gt_stack.shape[1]

# squeeze=False always yields a 2-D array of axes, so D == 1 needs no special case.
fig, axes = plt.subplots(D, 1, figsize=(14, 2 * D), sharex=True, squeeze=False)
axes = axes[:, 0]

for d, ax in enumerate(axes):
    ax.plot(times, gt_stack[:, d], label="Ground Truth", alpha=0.8)
    ax.set_ylabel(action_names[d], fontsize=8)
    ax.grid(True, alpha=0.3)

axes[0].set_title(f"Episode {EPISODE_INDEX}: Ground Truth Actions")
axes[-1].set_xlabel("Time (s)")
plt.tight_layout()
plt.show()

print(f"✅ Visualized {len(gt_stack)} frames, {D} action dimensions")

## 6. Connect to Robot

Establish connection to the robot service.

In [None]:
import grpc
from example_policies.robot_deploy.robot_io.robot_service import robot_service_pb2_grpc
from example_policies.robot_deploy.robot_io.robot_interface import RobotInterface
from example_policies.robot_deploy.robot_io.robot_client import RobotClient

# Open a plaintext gRPC channel to the robot service and wrap it in the
# project's RobotInterface, which uses the policy config.
print(f"Connecting to robot at {SERVER_ENDPOINT}...")
channel = grpc.insecure_channel(SERVER_ENDPOINT)
stub = robot_service_pb2_grpc.RobotServiceStub(channel)
robot_interface = RobotInterface(stub, cfg)

# Creating the channel alone does not prove the server is reachable;
# requesting one observation snapshot does.
try:
    obs = robot_interface.get_observation(device, show=False)
    print("‚úÖ Connected to robot!" if obs else "‚ö†Ô∏è Connected but no observation received")
except Exception as e:
    print(f"‚ùå Connection failed: {e}")

## 7. (Optional) Move Robot to Home Position

In [None]:
# Ask the robot service to move to its home pose and echo the service reply.
# Any failure is reported instead of interrupting the notebook.
try:
    response = robot_interface.move_home()
    print("‚úÖ Robot homing command sent")
    print(f"   Response: {response}")
except Exception as e:
    print(f"‚ùå Homing failed: {e}")

## 7.5 Debug: Compare Dataset vs Live Observations

Compare the format of observations from the dataset vs live robot to identify mismatches.

In [None]:
import av
import numpy as np
import matplotlib.pyplot as plt

from collections.abc import Mapping


def _print_tensor_summary(name, tensor, *, first_values=False, per_channel=False):
    """Print shape, dtype and value range of ``tensor``, plus optional extras.

    Range and means are computed on a float copy so integer tensors (e.g. uint8
    images) do not fail in ``mean``. ``per_channel`` summarises ``tensor[0]`` over
    its last two dims — assumes a (batch, C, H, W) layout.
    """
    as_float = tensor.detach().float()
    print(f"\n{name}:")
    print(f"  Shape: {tensor.shape}")
    print(f"  Dtype: {tensor.dtype}")
    print(f"  Range: [{as_float.min().item():.4f}, {as_float.max().item():.4f}]")
    if first_values:
        print(f"  First 5 values: {tensor.detach().flatten()[:5].tolist()}")
    if per_channel:
        print(f"  Mean per channel (RGB): {as_float[0].mean(dim=(1, 2)).tolist()}")


def _get_stat(stats, stat_name):
    """Fetch a statistic (e.g. "mean"/"std") from a dict-like or attribute-style stats object."""
    if isinstance(stats, Mapping):
        return stats.get(stat_name)
    return getattr(stats, stat_name, None)


# ---- Dataset observation: first frame of the episode --------------------------------
print("=" * 60)
print("DATASET OBSERVATION (from parquet + video)")
print("=" * 60)

parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
df = pd.read_parquet(parquet_path)
row = df.iloc[0]

# State: concatenate scalar and array-valued observation.state* columns
state_cols = sorted(c for c in df.columns if c.startswith("observation.state"))
state_values = np.concatenate([np.asarray(row[c], dtype=np.float32).ravel() for c in state_cols])
dataset_state = torch.as_tensor(state_values).unsqueeze(0).to(device)
_print_tensor_summary("observation.state", dataset_state, first_values=True)

# Images: first decoded frame of each video, as float (1, C, H, W) in [0, 1]
video_keys = ["observation.images.rgb_static", "observation.images.rgb_left", "observation.images.rgb_right"]
dataset_images = {}
for key in video_keys:
    video_path = DATASET_DIR / f"videos/chunk-000/{key}/episode_{EPISODE_INDEX:06d}.mp4"
    if not video_path.exists():
        continue
    # The context manager closes the container even if decoding fails.
    with av.open(str(video_path)) as container:
        img = next(container.decode(video=0)).to_ndarray(format="rgb24")
    img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0
    dataset_images[key] = img_tensor.unsqueeze(0).to(device)
    _print_tensor_summary(key, dataset_images[key], per_channel=True)

# ---- Live observation from the robot ------------------------------------------------
print("\n" + "=" * 60)
print("LIVE OBSERVATION (from robot)")
print("=" * 60)

live_obs = robot_interface.get_observation(device, show=False)

if live_obs:
    if "observation.state" in live_obs:
        _print_tensor_summary("observation.state", live_obs["observation.state"], first_values=True)
    else:
        print("\n⚠️ observation.state NOT FOUND in live obs")
        print(f"  Available keys: {list(live_obs.keys())}")

    for key in video_keys:
        if key in live_obs:
            _print_tensor_summary(key, live_obs[key], per_channel=True)
        else:
            print(f"\n⚠️ {key} NOT FOUND in live obs")

    print(f"\n\nAll live obs keys: {list(live_obs.keys())}")
else:
    print("❌ No observation received from robot")

# ---- Visual comparison: dataset (top row) vs live (bottom row) ----------------------
print("\n" + "=" * 60)
print("VISUAL COMPARISON")
print("=" * 60)

if live_obs:
    n_cams = len(video_keys)
    fig, axes = plt.subplots(2, n_cams, figsize=(5 * n_cams, 10), squeeze=False)
    for i, key in enumerate(video_keys):
        cam_name = key.split(".")[-1]
        axes[0, i].axis("off")
        axes[1, i].axis("off")
        if key in dataset_images:
            axes[0, i].imshow(dataset_images[key][0].cpu().permute(1, 2, 0).numpy())
            axes[0, i].set_title(f"Dataset: {cam_name}")
        if key in live_obs:
            live_np = live_obs[key][0].detach().cpu().float().permute(1, 2, 0).numpy()
            # Clip to the valid display range; out-of-range values are reported above.
            axes[1, i].imshow(np.clip(live_np, 0, 1))
            axes[1, i].set_title(f"Live: {cam_name}")
    plt.tight_layout()
    plt.show()

# ---- Policy input features ----------------------------------------------------------
print("\n" + "=" * 60)
print("POLICY EXPECTED INPUT FEATURES")
print("=" * 60)

print(f"\nImage features expected by policy: {cfg.image_features}")
print("\nAll input features:")
for key, feature in cfg.input_features.items():
    print(f"  {key}: shape={feature.shape}")

# ---- Normalization stats ------------------------------------------------------------
print("\n" + "=" * 60)
print("NORMALIZATION STATS (from training)")
print("=" * 60)

normalizer = getattr(policy, "normalize_inputs", None)
norm_stats = getattr(normalizer, "stats", None)
printed_stats = False
if norm_stats is not None:
    # Per-key stats may be plain dicts ({"mean": ..., "std": ...}) or attribute objects;
    # handle both instead of only attribute access.
    printed_stats = True
    for key, stats in norm_stats.items():
        print(f"\n{key}:")
        for stat_name, label in (("mean", "Mean"), ("std", "Std")):
            value = _get_stat(stats, stat_name)
            if value is not None:
                print(f"  {label} (first 5): {torch.as_tensor(value).detach().flatten()[:5].tolist()}")
elif isinstance(normalizer, torch.nn.Module):
    # NOTE(review): when `stats` is None the values may live only in the module's
    # registered parameters/buffers (e.g. after loading a checkpoint) — show those.
    for name, value in [*normalizer.named_parameters(), *normalizer.named_buffers()]:
        printed_stats = True
        print(f"  {name} (first 5): {value.detach().flatten()[:5].tolist()}")
if not printed_stats:
    print("Could not access normalization stats (stats is None or not available)")

# ---- Shape / range comparison -------------------------------------------------------
print("\n" + "=" * 60)
print("COMPARISON SUMMARY")
print("=" * 60)

if live_obs:
    if "observation.state" in live_obs:
        live_state_shape = live_obs["observation.state"].shape
        print(f"\nState shape match: {dataset_state.shape == live_state_shape}")
        if dataset_state.shape != live_state_shape:
            print(f"  Dataset: {dataset_state.shape}, Live: {live_state_shape}")

    for key in video_keys:
        if key in live_obs and key in dataset_images:
            # Compare in float so integer live images do not break .mean()
            ds_img = dataset_images[key].float()
            live_img = live_obs[key].float()
            shape_match = ds_img.shape == live_img.shape
            ds_max, live_max = ds_img.max().item(), live_img.max().item()
            ds_mean, live_mean = ds_img.mean().item(), live_img.mean().item()

            print(f"\n{key}:")
            print(f"  Shape match: {shape_match}")
            if not shape_match:
                print(f"    Dataset: {ds_img.shape}, Live: {live_img.shape}")
            print(f"  Max diff: {abs(ds_max - live_max):.4f}")
            print(f"  Mean diff: {abs(ds_mean - live_mean):.4f}")
            print(f"    Dataset max: {ds_max:.4f}, Live max: {live_max:.4f}")
            print(f"    Dataset mean: {ds_mean:.4f}, Live mean: {live_mean:.4f}")

## 8. Deploy: Send Actions to Robot

⚠️ **This will move the robot!** Make sure the workspace is clear.

This cell replays the episode using observations from the dataset and sends the policy's predicted actions to the robot.

In [None]:
import time
import gc
import av
import numpy as np
import pandas as pd
from example_policies.robot_deploy.deploy_core.action_translator import ActionTranslator
from example_policies.utils.action_order import ActionMode

# Controller mode
CONTROLLER = RobotClient.CART_WAYPOINT  # Most stable


def _build_dataset_observation(row, state_cols, video_streams):
    """Build a policy observation from one parquet row plus the next frame of each video.

    Returns ``(obs, exhausted_key)``. ``exhausted_key`` names the first video stream that
    ran out of frames; in that case ``obs`` is incomplete and must not be sent to the
    policy. Otherwise ``exhausted_key`` is None.
    """
    state_values = np.concatenate([np.asarray(row[c], dtype=np.float32).ravel() for c in state_cols])
    obs = {"observation.state": torch.as_tensor(state_values).unsqueeze(0).to(device)}
    for key, stream in video_streams.items():
        frame = next(stream, None)
        if frame is None:
            return obs, key
        # RGB uint8 (H, W, C) -> float (1, C, H, W) in [0, 1]
        img = frame.to_ndarray(format="rgb24")
        img_tensor = torch.tensor(img, dtype=torch.float32).permute(2, 0, 1) / 255.0
        obs[key] = img_tensor.unsqueeze(0).to(device)
    return obs, None


print("="*60)
print("⚠️  ROBOT DEPLOYMENT FROM DATASET OBSERVATIONS")
print("="*60)
print(f"Episode: {EPISODE_INDEX}")
print(f"Frames: {len(episode_indices)}")
print(f"Frequency: {INFERENCE_FREQUENCY_HZ} Hz")
print(f"Controller: {CONTROLLER}")
print("="*60)
print("\nThis uses RECORDED observations from the dataset as policy input,")
print("but sends the PREDICTED actions to the robot.")

# Load parquet data for state (small, can stay in memory)
print("\nLoading episode state data from parquet...")
parquet_path = DATASET_DIR / f"data/chunk-000/episode_{EPISODE_INDEX:06d}.parquet"
df = pd.read_parquet(parquet_path)
print(f"✅ Loaded {len(df)} frames from parquet")

# The state column list is identical for every row; compute it once.
state_cols = sorted(c for c in df.columns if c.startswith("observation.state"))

# Get video paths
video_keys = ["observation.images.rgb_static", "observation.images.rgb_left", "observation.images.rgb_right"]
video_paths = {
    key: DATASET_DIR / f"videos/chunk-000/{key}/episode_{EPISODE_INDEX:06d}.mp4"
    for key in video_keys
}

video_containers = {}
video_streams = {}
sent_count = 0
# The outer try/finally closes every opened container on all paths, including
# cancellation and an interrupt while waiting at the confirmation prompt.
try:
    # Open video files using PyAV (supports AV1)
    print("Opening video files...")
    for key, path in video_paths.items():
        if path.exists():
            container = av.open(str(path))
            video_containers[key] = container
            video_streams[key] = container.decode(video=0)
            print(f"  ✅ {key}")
        else:
            print(f"  ❌ {key} not found")

    # Do not move the robot if a camera the policy expects has no recorded video.
    missing_cameras = [k for k in cfg.image_features if k not in video_streams]
    if missing_cameras:
        print(f"\n❌ No recorded video for camera(s) the policy expects: {missing_cameras}")
        print("Deployment cancelled.")
    elif input("\nType 'yes' to start deployment: ").strip().lower() != "yes":
        print("Deployment cancelled.")
    else:
        print("\n🚀 Starting deployment...")

        # Reset policy and action translator
        policy.reset()
        action_translator = ActionTranslator(cfg)

        period = 1.0 / INFERENCE_FREQUENCY_HZ

        try:
            for i in range(len(df)):
                # perf_counter is monotonic, so wall-clock adjustments cannot skew the period.
                start_time = time.perf_counter()

                obs, exhausted_key = _build_dataset_observation(df.iloc[i], state_cols, video_streams)
                if exhausted_key is not None:
                    # Never feed an incomplete observation to the policy on a physical robot.
                    print(f"  ⚠️ {exhausted_key}: ran out of frames at {i}; stopping deployment.")
                    break

                with torch.inference_mode():
                    raw_action = policy.select_action(obs)
                    robot_action = action_translator.translate(raw_action, obs)

                robot_interface.send_action(
                    robot_action,
                    action_translator.action_mode,
                    CONTROLLER
                )
                sent_count += 1

                # Drop per-frame tensors so they are not kept alive as notebook globals.
                del obs, raw_action, robot_action

                if i % 10 == 0:
                    print(f"  Frame {i+1}/{len(df)}")

                # Wait for the next control cycle
                elapsed = time.perf_counter() - start_time
                time.sleep(max(0.0, period - elapsed))

            if sent_count == len(df):
                print(f"\n✅ Deployment complete! Sent {sent_count} actions from dataset observations.")
            else:
                print(f"\n⚠️ Deployment stopped early: sent {sent_count}/{len(df)} actions.")

        except KeyboardInterrupt:
            print(f"\n⚠️ Deployment interrupted by user after {sent_count} actions.")
        except Exception as e:
            print(f"\n❌ Error during deployment: {e}")
            raise
finally:
    for container in video_containers.values():
        container.close()
    gc.collect()

## 9. Cleanup

In [None]:
# Close the gRPC channel opened in section 6. The lookup is guarded so that running
# this cell when section 6 was skipped does not raise NameError.
_channel = globals().get("channel")
if _channel is not None:
    _channel.close()
    print("Connection closed.")
else:
    print("No open robot connection to close.")