Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ aic_utils/aic_isaac/aic_isaaclab/source/aic_task/aic_task/tasks/manager_based/ai

# Training checkpoints (too large for git)
checkpoints/
wandb/
61 changes: 56 additions & 5 deletions aic_example_policies/aic_example_policies/ros/RunRLT.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def __init__(self, parent_node: Node):
("policy_args.prompt_align", _PHASE_PROMPT_DEFAULTS["align"]),
("policy_args.prompt_insert", _PHASE_PROMPT_DEFAULTS["insert"]),
("policy_args.prompt_verify", _PHASE_PROMPT_DEFAULTS["verify"]),
# Debug: bypass actor and execute VLA reference actions directly.
("policy_args.vla_only", False),
]
for name, default in _params:
try:
Expand All @@ -159,6 +161,10 @@ def __init__(self, parent_node: Node):
"insert": parent_node.get_parameter("policy_args.prompt_insert").value,
"verify": parent_node.get_parameter("policy_args.prompt_verify").value,
}
self._vla_only = parent_node.get_parameter("policy_args.vla_only").value
if self._vla_only:
self.get_logger().info("VLA-only mode: bypassing actor, executing VLA refs directly")

port_pose = list(parent_node.get_parameter("policy_args.port_pose_xyzquat").value)
if len(port_pose) == 7:
self._port_pos = np.asarray(port_pose[0:3], dtype=np.float64)
Expand Down Expand Up @@ -385,6 +391,8 @@ def _estimate_phase(self, prop: np.ndarray) -> str:

def _maybe_switch_prompt(self, phase: str) -> None:
"""Call backend.set_instruction when the phase changes."""
if self._port_pos is None:
return
if phase == self._current_phase:
return
prompt = self._phase_prompts.get(phase)
Expand All @@ -397,6 +405,43 @@ def _maybe_switch_prompt(self, phase: str) -> None:
except Exception as e:
self.get_logger().warn(f"set_instruction failed on phase={phase}: {e}")

# Max position delta per control step (m). Prevents large jumps that
# generate torques exceeding Gazebo's effort limits.
_MAX_POS_DELTA: float = 0.005 # 5 mm per step @ 20 Hz = 10 cm/s max
_MAX_ROT_DELTA: float = 0.05 # ~3° per step @ 20 Hz = 60°/s max

def _clamp_action(self, action: np.ndarray, obs: "Observation") -> np.ndarray:
"""Clamp action so the commanded pose is close to the current TCP pose.

CheatCode works because it interpolates smoothly over 100 steps.
Without clamping, the actor can command a pose far from the current
one → huge joint torques → effort clamped to 0 → arm collapses.
"""
current_pos = np.array([
obs.controller_state.tcp_pose.position.x,
obs.controller_state.tcp_pose.position.y,
obs.controller_state.tcp_pose.position.z,
], dtype=np.float32)

# Clamp position delta
pos_delta = action[0:3] - current_pos
pos_norm = np.linalg.norm(pos_delta)
if pos_norm > self._MAX_POS_DELTA:
action = action.copy()
action[0:3] = current_pos + pos_delta * (self._MAX_POS_DELTA / pos_norm)

# Clamp rotation delta (in 6D space — limit the norm of the change)
from aic_rlt.vla.xvla_wrapper import quat_to_rot6d
q = obs.controller_state.tcp_pose.orientation
current_rot6d = quat_to_rot6d(np.array([q.x, q.y, q.z, q.w], dtype=np.float32))
rot_delta = action[3:9] - current_rot6d
rot_norm = np.linalg.norm(rot_delta)
if rot_norm > self._MAX_ROT_DELTA:
action = action.copy() if not action.flags.writeable else action
action[3:9] = current_rot6d + rot_delta * (self._MAX_ROT_DELTA / rot_norm)

return action

def _action_to_motion_update(self, action: np.ndarray) -> MotionUpdate:
"""Convert a 9-dim action [xyz, r1, r2] to a MotionUpdate.

Expand All @@ -417,11 +462,12 @@ def _action_to_motion_update(self, action: np.ndarray) -> MotionUpdate:
z=float(quat[2]), w=float(quat[3]),
),
)
# Match CheatCode defaults (policy.py:set_pose_target)
motion_update.target_stiffness = np.diag(
[85.0, 85.0, 85.0, 50.0, 50.0, 50.0]
[90.0, 90.0, 90.0, 50.0, 50.0, 50.0]
).flatten().tolist()
motion_update.target_damping = np.diag(
[75.0, 75.0, 75.0, 20.0, 20.0, 20.0]
[50.0, 50.0, 50.0, 20.0, 20.0, 20.0]
).flatten().tolist()
motion_update.feedforward_wrench_at_tip = Wrench(
force=Vector3(x=0.0, y=0.0, z=0.0),
Expand Down Expand Up @@ -471,14 +517,19 @@ def insert_cable(
self._maybe_switch_prompt(self._estimate_phase(prop_np))

z_rl, prop, ref_chunk = self._encode_rl_state(obs)
with torch.no_grad():
action_chunk_t = self.actor.get_mean(z_rl, prop, ref_chunk) # (1, C, D)
action_chunk = action_chunk_t.squeeze(0).cpu().numpy() # (C, D)
if self._vla_only:
# Execute VLA reference actions directly (debug mode)
action_chunk = ref_chunk.squeeze(0).cpu().numpy() # (C, D)
else:
with torch.no_grad():
action_chunk_t = self.actor.get_mean(z_rl, prop, ref_chunk) # (1, C, D)
action_chunk = action_chunk_t.squeeze(0).cpu().numpy() # (C, D)
chunk_step = 0

action = action_chunk[chunk_step]
chunk_step += 1

action = self._clamp_action(action, obs)
move_robot(motion_update=self._action_to_motion_update(action))
send_feedback(f"RLT step {chunk_step}/{self.CHUNK_LENGTH}")

Expand Down
2 changes: 1 addition & 1 deletion aic_utils/aic_rlt/aic_rlt/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class RLTConfig:
# TD3 target network EMA coefficient
tau: float = 0.005
# BC regularizer coefficient β (equation (5))
bc_coeff: float = 5.0
bc_coeff: float = 0.0
# TD3 target policy noise
target_policy_noise: float = 0.2
target_noise_clip: float = 0.5
Expand Down
2 changes: 1 addition & 1 deletion aic_utils/aic_rlt/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def parse_args():
parser.add_argument("--load_checkpoint", type=str, default="")
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
# Hyperparameter overrides
parser.add_argument("--bc_coeff", type=float, default=5.0)
parser.add_argument("--bc_coeff", type=float, default=0.0)
parser.add_argument("--n_warmup_steps", type=int, default=2000)
parser.add_argument("--total_env_steps", type=int, default=50000)
parser.add_argument("--hidden_dims", type=int, nargs="+", default=[256, 256])
Expand Down