diff --git a/.gitignore b/.gitignore index 5557c585..452ec69b 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ aic_utils/aic_isaac/aic_isaaclab/source/aic_task/aic_task/tasks/manager_based/ai # Training checkpoints (too large for git) checkpoints/ +wandb/ diff --git a/aic_example_policies/aic_example_policies/ros/RunRLT.py b/aic_example_policies/aic_example_policies/ros/RunRLT.py index 1d88acd1..3deb5e83 100644 --- a/aic_example_policies/aic_example_policies/ros/RunRLT.py +++ b/aic_example_policies/aic_example_policies/ros/RunRLT.py @@ -139,6 +139,8 @@ def __init__(self, parent_node: Node): ("policy_args.prompt_align", _PHASE_PROMPT_DEFAULTS["align"]), ("policy_args.prompt_insert", _PHASE_PROMPT_DEFAULTS["insert"]), ("policy_args.prompt_verify", _PHASE_PROMPT_DEFAULTS["verify"]), + # Debug: bypass actor and execute VLA reference actions directly. + ("policy_args.vla_only", False), ] for name, default in _params: try: @@ -159,6 +161,10 @@ def __init__(self, parent_node: Node): "insert": parent_node.get_parameter("policy_args.prompt_insert").value, "verify": parent_node.get_parameter("policy_args.prompt_verify").value, } + self._vla_only = parent_node.get_parameter("policy_args.vla_only").value + if self._vla_only: + self.get_logger().info("VLA-only mode: bypassing actor, executing VLA refs directly") + port_pose = list(parent_node.get_parameter("policy_args.port_pose_xyzquat").value) if len(port_pose) == 7: self._port_pos = np.asarray(port_pose[0:3], dtype=np.float64) @@ -385,6 +391,8 @@ def _estimate_phase(self, prop: np.ndarray) -> str: def _maybe_switch_prompt(self, phase: str) -> None: """Call backend.set_instruction when the phase changes.""" + if self._port_pos is None: + return if phase == self._current_phase: return prompt = self._phase_prompts.get(phase) @@ -397,6 +405,43 @@ def _maybe_switch_prompt(self, phase: str) -> None: except Exception as e: self.get_logger().warn(f"set_instruction failed on phase={phase}: {e}") + # Max position delta per control step (m). Prevents large jumps that + # generate torques exceeding Gazebo's effort limits. + _MAX_POS_DELTA: float = 0.005 # 5 mm per step @ 20 Hz = 10 cm/s max + _MAX_ROT_DELTA: float = 0.05 # ~3° per step @ 20 Hz = 60°/s max + + def _clamp_action(self, action: np.ndarray, obs: "Observation") -> np.ndarray: + """Clamp action so the commanded pose is close to the current TCP pose. + + CheatCode works because it interpolates smoothly over 100 steps. + Without clamping, the actor can command a pose far from the current + one → huge joint torques → effort clamped to 0 → arm collapses. + """ + current_pos = np.array([ + obs.controller_state.tcp_pose.position.x, + obs.controller_state.tcp_pose.position.y, + obs.controller_state.tcp_pose.position.z, + ], dtype=np.float32) + + # Clamp position delta + pos_delta = action[0:3] - current_pos + pos_norm = np.linalg.norm(pos_delta) + if pos_norm > self._MAX_POS_DELTA: + action = action.copy() + action[0:3] = current_pos + pos_delta * (self._MAX_POS_DELTA / pos_norm) + + # Clamp rotation delta (in 6D space — limit the norm of the change) + from aic_rlt.vla.xvla_wrapper import quat_to_rot6d + q = obs.controller_state.tcp_pose.orientation + current_rot6d = quat_to_rot6d(np.array([q.x, q.y, q.z, q.w], dtype=np.float32)) + rot_delta = action[3:9] - current_rot6d + rot_norm = np.linalg.norm(rot_delta) + if rot_norm > self._MAX_ROT_DELTA: + action = action.copy() if not action.flags.writeable else action + action[3:9] = current_rot6d + rot_delta * (self._MAX_ROT_DELTA / rot_norm) + + return action + def _action_to_motion_update(self, action: np.ndarray) -> MotionUpdate: """Convert a 9-dim action [xyz, r1, r2] to a MotionUpdate. @@ -417,11 +462,12 @@ def _action_to_motion_update(self, action: np.ndarray) -> MotionUpdate: z=float(quat[2]), w=float(quat[3]), ), ) + # Match CheatCode defaults (policy.py:set_pose_target) motion_update.target_stiffness = np.diag( - [85.0, 85.0, 85.0, 50.0, 50.0, 50.0] + [90.0, 90.0, 90.0, 50.0, 50.0, 50.0] ).flatten().tolist() motion_update.target_damping = np.diag( - [75.0, 75.0, 75.0, 20.0, 20.0, 20.0] + [50.0, 50.0, 50.0, 20.0, 20.0, 20.0] ).flatten().tolist() motion_update.feedforward_wrench_at_tip = Wrench( force=Vector3(x=0.0, y=0.0, z=0.0), @@ -471,14 +517,19 @@ def insert_cable( self._maybe_switch_prompt(self._estimate_phase(prop_np)) z_rl, prop, ref_chunk = self._encode_rl_state(obs) - with torch.no_grad(): - action_chunk_t = self.actor.get_mean(z_rl, prop, ref_chunk) # (1, C, D) - action_chunk = action_chunk_t.squeeze(0).cpu().numpy() # (C, D) + if self._vla_only: + # Execute VLA reference actions directly (debug mode) + action_chunk = ref_chunk.squeeze(0).cpu().numpy() # (C, D) + else: + with torch.no_grad(): + action_chunk_t = self.actor.get_mean(z_rl, prop, ref_chunk) # (1, C, D) + action_chunk = action_chunk_t.squeeze(0).cpu().numpy() # (C, D) chunk_step = 0 action = action_chunk[chunk_step] chunk_step += 1 + action = self._clamp_action(action, obs) move_robot(motion_update=self._action_to_motion_update(action)) send_feedback(f"RLT step {chunk_step}/{self.CHUNK_LENGTH}") diff --git a/aic_utils/aic_rlt/aic_rlt/trainer.py b/aic_utils/aic_rlt/aic_rlt/trainer.py index fd8c6070..289211be 100644 --- a/aic_utils/aic_rlt/aic_rlt/trainer.py +++ b/aic_utils/aic_rlt/aic_rlt/trainer.py @@ -155,7 +155,7 @@ class RLTConfig: # TD3 target network EMA coefficient tau: float = 0.005 # BC regularizer coefficient β (equation (5)) - bc_coeff: float = 5.0 + bc_coeff: float = 0.0 # TD3 target policy noise target_policy_noise: float = 0.2 target_noise_clip: float = 0.5 diff --git a/aic_utils/aic_rlt/scripts/train.py b/aic_utils/aic_rlt/scripts/train.py index 7bf3ec40..9bc6e9ca 100644 --- a/aic_utils/aic_rlt/scripts/train.py +++ b/aic_utils/aic_rlt/scripts/train.py @@ -212,7 +212,7 @@ def parse_args(): parser.add_argument("--load_checkpoint", type=str, default="") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") # Hyperparameter overrides - parser.add_argument("--bc_coeff", type=float, default=5.0) + parser.add_argument("--bc_coeff", type=float, default=0.0) parser.add_argument("--n_warmup_steps", type=int, default=2000) parser.add_argument("--total_env_steps", type=int, default=50000) parser.add_argument("--hidden_dims", type=int, nargs="+", default=[256, 256])