Skip to content

Commit

Permalink
is_success consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Jun 10, 2024
1 parent 902502f commit 0d17e18
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _sample_object(self) -> Tuple[np.ndarray, np.ndarray]:
object_rotation = np.zeros(3)
return object_position, object_rotation

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
d = angle_distance(achieved_goal, desired_goal)
return np.array(d < self.distance_threshold, dtype=bool)

Expand Down
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/pick_and_place.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _sample_object(self) -> np.ndarray:
object_position += noise
return object_position

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
d = distance(achieved_goal, desired_goal)
return np.array(d < self.distance_threshold, dtype=bool)

Expand Down
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _sample_object(self) -> np.ndarray:
object_position += noise
return object_position

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
d = distance(achieved_goal, desired_goal)
return np.array(d < self.distance_threshold, dtype=bool)

Expand Down
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/reach.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _sample_goal(self) -> np.ndarray:
goal = self.np_random.uniform(self.goal_range_low, self.goal_range_high)
return goal

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
d = distance(achieved_goal, desired_goal)
return np.array(d < self.distance_threshold, dtype=bool)

Expand Down
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/slide.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _sample_object(self) -> np.ndarray:
object_position += noise
return object_position

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
d = distance(achieved_goal, desired_goal)
return np.array(d < self.distance_threshold, dtype=bool)

Expand Down
2 changes: 1 addition & 1 deletion panda_gym/envs/tasks/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _sample_objects(self) -> Tuple[np.ndarray, np.ndarray]:
# if distance(object1_position, object2_position) > 0.1:
return object1_position, object2_position

def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray) -> np.ndarray:
def is_success(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, info: Dict[str, Any] = {}) -> np.ndarray:
# must be vectorized !!
d = distance(achieved_goal, desired_goal)
return np.array((d < self.distance_threshold), dtype=bool)
Expand Down

0 comments on commit 0d17e18

Please sign in to comment.