# Deploy Trained Diffusion Policy

This notebook loads a trained diffusion policy from a checkpoint and rolls it out in the simulation environment.

In [8]:
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.configs.types import FeatureType
from lerobot.common.datasets.factory import resolve_delta_timestamps
from mujoco_env.y_env import SimpleEnv
import torch
from PIL import Image
import torchvision
import os
import numpy as np

## Load Policy

In [5]:
# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")
    

Using device: cuda


In [6]:

# Load dataset metadata and split its features into policy inputs and outputs.
print("Loading dataset metadata...")
dataset_metadata = LeRobotDatasetMetadata("omy_pnp", root='./demo_data')
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}
# The wrist image is excluded from the policy inputs. Pop with a default so a
# dataset that does not contain this feature does not abort the cell with KeyError.
input_features.pop("observation.wrist_image", None)

# Configure the diffusion policy.
print("Configuring diffusion policy...")
cfg = DiffusionConfig(
    input_features=input_features,
    output_features=output_features,
    horizon=8,  # Must match the training configuration
    n_action_steps=8,
)

# Get delta timestamps for action chunking.
# Not read again in this cell; kept in case another cell relies on the name.
delta_timestamps = resolve_delta_timestamps(cfg, dataset_metadata)

# Build a fresh policy from the configuration, then overwrite its parameters
# with the trained checkpoint.
ckpt_dir = './ckpt/diffusion_y'
print(f"Loading trained policy from checkpoint: {ckpt_dir}")
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)

# Candidate weight files, in order of preference.
weight_filenames = ("model.safetensors", "diffusion_pytorch_model.bin")
weights_path = next(
    (
        candidate
        for candidate in (os.path.join(ckpt_dir, name) for name in weight_filenames)
        if os.path.exists(candidate)
    ),
    None,
)

# Fail loudly when no checkpoint is present: continuing would silently roll out
# a randomly initialised policy.
if weights_path is None:
    raise FileNotFoundError(
        f"Weights file not found in {ckpt_dir!r}. Tried: {', '.join(weight_filenames)}"
    )

print(f"Loading weights from {weights_path}")
if weights_path.endswith('.safetensors'):
    from safetensors.torch import load_file
    state_dict = load_file(weights_path)
else:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .bin file cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)

policy.load_state_dict(state_dict)
print("Successfully loaded model weights")

policy.to(device)
# Last expression of the cell: switches to eval mode and displays the module summary.
policy.eval()




Loading dataset metadata...
Configuring diffusion policy...
Loading trained policy from checkpoint: ./ckpt/diffusion_y
Loading weights from ./ckpt/diffusion_y/model.safetensors
Successfully loaded model weights


DiffusionPolicy(
  (normalize_inputs): Normalize(
    (buffer_observation_image): ParameterDict(
        (mean): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
        (std): Parameter containing: [torch.cuda.FloatTensor of size 3x1x1 (cuda:0)]
    )
    (buffer_observation_state): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 6 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 6 (cuda:0)]
    )
  )
  (normalize_targets): Normalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (unnormalize_outputs): Unnormalize(
    (buffer_action): ParameterDict(
        (max): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
        (min): Parameter containing: [torch.cuda.FloatTensor of size 7 (cuda:0)]
    )
  )
  (diffusion): Diff

## Load Environment

In [7]:
# SimpleEnv comes from the notebook's import cell; it is not re-imported here.
xml_path = './asset/example_scene_y.xml'
# Build the simulation environment from the scene XML, configured with
# action_type='joint_angle'.
PnPEnv = SimpleEnv(xml_path, action_type='joint_angle')

/home/sjyoon/workspaces/lerobot-mujoco-tutorial/asset/example_scene_y.xml
['agentview', 'topview', 'sideview', 'egocentric']
name:[Tabletop] dt:[0.002] HZ:[500]
n_qpos:[24] n_qvel:[22] n_qacc:[22] n_ctrl:[10]

n_body:[21]
 [0/21] [world] mass:[0.00]kg
 [1/21] [front_object_table] mass:[1.00]kg
 [2/21] [camera] mass:[0.00]kg
 [3/21] [camera2] mass:[0.00]kg
 [4/21] [camera3] mass:[0.00]kg
 [5/21] [link1] mass:[2.06]kg
 [6/21] [link2] mass:[3.68]kg
 [7/21] [link3] mass:[2.39]kg
 [8/21] [link4] mass:[1.40]kg
 [9/21] [link5] mass:[1.40]kg
 [10/21] [link6] mass:[0.65]kg
 [11/21] [camera_center] mass:[0.00]kg
 [12/21] [tcp_link] mass:[0.32]kg
 [13/21] [rh_p12_rn_r1] mass:[0.07]kg
 [14/21] [rh_p12_rn_r2] mass:[0.02]kg
 [15/21] [rh_p12_rn_l1] mass:[0.07]kg
 [16/21] [rh_p12_rn_l2] mass:[0.02]kg
 [17/21] [body_obj_mug_5] mass:[0.00]kg
 [18/21] [object_mug_5] mass:[0.08]kg
 [19/21] [body_obj_plate_11] mass:[0.00]kg
 [20/21] [object_plate_11] mass:[0.10]kg
body_total_mass:[13.27]kg

n_geom:[83]
geo

## Roll-Out Your Policy

In [None]:
CONTROL_HZ = 20    # rate (Hz) at which the policy is queried and an action is applied
ACTION_DIM = 7     # length of the zero-action vector used as a fallback
MAX_STEPS = 1000   # step limit: 1000 steps at 20 Hz = 50 seconds


def _to_action_vector(policy_output, action_dim=ACTION_DIM):
    """Convert whatever ``policy.select_action`` returned into a NumPy action.

    Args:
        policy_output: A ``torch.Tensor``, or a list/tuple whose first element
            is a tensor or array-like. Any other type falls back to zeros.
        action_dim: Length of the zero vector used by the fallbacks.

    Returns:
        np.ndarray: The first action of the first batch element for tensors
        with three or more dimensions, the first row for 2-D tensors, the
        tensor itself for 1-D tensors. A 0-d tensor is written into element 0
        of a zero vector. A 0-d result from a list/tuple, or an unsupported
        type, yields an all-zero vector.
    """
    if isinstance(policy_output, torch.Tensor):
        output = policy_output.detach().cpu()
        if output.ndim >= 3:        # [batch, horizon, action_dim]
            action = output[0, 0].numpy()
        elif output.ndim == 2:      # [batch, action_dim]
            action = output[0].numpy()
        elif output.ndim == 1:      # [action_dim]
            action = output.numpy()
        else:                       # scalar: store it in the first component
            action = np.zeros(action_dim)
            action[0] = output.numpy()
    elif isinstance(policy_output, (list, tuple)):
        first = policy_output[0]
        if isinstance(first, torch.Tensor):
            action = first.detach().cpu().numpy()
        else:
            action = np.array(first)
    else:
        print(f"Unexpected action type: {type(policy_output)}")
        # Default to a zero action.
        action = np.zeros(action_dim)

    # A scalar action cannot be sent to the environment; expand it to a zero vector.
    if action.ndim == 0:
        action = np.zeros(action_dim)
    return action


step = 0
PnPEnv.reset(seed=0)
policy.reset()
policy.eval()
# NOTE(review): save_image and action_trajectory are never read in this cell;
# they are kept so any other cell that expects these names still finds them.
save_image = True
img_transform = torchvision.transforms.ToTensor()

while PnPEnv.env.is_viewer_alive():
    PnPEnv.step_env()
    if PnPEnv.env.loop_every(HZ=CONTROL_HZ):
        # Check for task completion before acting; on success start a new episode.
        if PnPEnv.check_success():
            print('Success!')
            # Reset environment and policy
            policy.reset()
            PnPEnv.reset(seed=0)
            step = 0
            save_image = False
            action_trajectory = []
            continue

        # Get current state and images (the wrist image is not used by this policy input).
        state = PnPEnv.get_ee_pose()
        image, wrist_image = PnPEnv.grab_image()

        # Resize the image to 256x256 and convert it to a tensor.
        image = img_transform(Image.fromarray(image).resize((256, 256)))

        # Prepare a batch-of-one observation for the policy.
        # np.asarray + as_tensor avoids building a tensor from a Python list
        # wrapping an array, which is a slow conversion path in PyTorch.
        state_tensor = torch.as_tensor(np.asarray(state), dtype=torch.float32).unsqueeze(0)
        data = {
            'observation.state': state_tensor.to(device),
            'observation.image': image.unsqueeze(0).to(device),
            'task': ['Put mug cup on the plate'],
            'timestamp': torch.tensor([step / CONTROL_HZ], dtype=torch.float32).to(device)
        }

        # Get action from policy; gradients are not needed for inference.
        with torch.no_grad():
            action_traj = policy.select_action(data)

        # Handle the output according to its type and shape.
        action = _to_action_vector(action_traj)

        print(f"Action shape: {action.shape}, Action values: {action}")

        # Execute action in environment. Failures are reported and the loop
        # continues (best effort), so a single bad step does not end the rollout.
        try:
            PnPEnv.step(action)
            PnPEnv.render()
        except Exception as e:
            print(f"Error executing action: {e}")
            print(f"Action: {action}, Type: {type(action)}, Shape: {action.shape if hasattr(action, 'shape') else 'Unknown'}")

        step += 1

        # Check for success after action
        if PnPEnv.check_success():
            print('Task completed successfully!')
            break

        # Stop once the step limit is reached.
        if step >= MAX_STEPS:
            print('Time limit exceeded')
            break


DONE INITIALIZATION


NameError: name 'check_shape' is not defined