Skip to content

Commit

Permalink
Improved task resets. Resolved #14. Resolves #24.
Browse files Browse the repository at this point in the history
  • Loading branch information
stepjam committed Jan 20, 2020
1 parent 3990baf commit 10953da
Show file tree
Hide file tree
Showing 10 changed files with 17 additions and 8 deletions.
1 change: 1 addition & 0 deletions rlbench/backend/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def init_episode(self, index: int, randomly_place: bool=True,
break
except (BoundaryError, WaypointError) as e:
self._active_task.cleanup_()
self._active_task.restore_state(self._inital_task_state)
attempts += 1
if attempts >= max_attempts:
raise e
Expand Down
2 changes: 1 addition & 1 deletion rlbench/tasks/close_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def init_episode(self, index: int) -> List[str]:
% target_color_name]

def variation_count(self) -> int:
return 2 * len(colors)
return len(colors)

def cleanup(self) -> None:
self.conditions = [NothingGrasped(self.robot.gripper)]
Expand Down
2 changes: 1 addition & 1 deletion rlbench/tasks/empty_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def init_episode(self, index: int) -> List[str]:
% target_color_name]

def variation_count(self) -> int:
return 2*len(colors)
return len(colors)

def cleanup(self) -> None:
[ob.remove() for ob in self.bin_objects if ob.still_exists()]
Expand Down
2 changes: 1 addition & 1 deletion rlbench/tasks/light_bulb_in.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def init_episode(self, index: int) -> List[str]:
% target_color_name]

def variation_count(self) -> int:
return 2 * len(colors)
return len(colors)

def step(self) -> None:
if DetectedCondition(self.bulbs[self._variation_index % 2],
Expand Down
2 changes: 1 addition & 1 deletion rlbench/tasks/light_bulb_out.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def init_episode(self, index: int) -> List[str]:
% target_color_name]

def variation_count(self) -> int:
return 2 * len(colors)
return len(colors)

def step(self) -> None:
if DetectedCondition(self.bulb, ProximitySensor('lamp_detector'),
Expand Down
2 changes: 1 addition & 1 deletion rlbench/tasks/open_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def init_episode(self, index: int) -> List[str]:
'table' % target_color_name]

def variation_count(self) -> int:
return 2 * len(colors)
return len(colors)

def cleanup(self) -> None:
self.conditions = [NothingGrasped(self.robot.gripper)]
Expand Down
4 changes: 3 additions & 1 deletion rlbench/tasks/push_buttons.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rlbench.backend.conditions import JointCondition, ConditionSet

MAX_TARGET_BUTTONS = 3
MAX_VARIATIONS = 50

# button top plate and wrapper will be be red before task completion
# and be changed to cyan upon success of task, so colors list used to randomly vary colors of
Expand Down Expand Up @@ -139,7 +140,8 @@ def init_episode(self, index: int) -> List[str]:
return [rtn0, rtn1, rtn2]

def variation_count(self) -> int:
return len(color_permutations) * MAX_TARGET_BUTTONS
return np.minimum(
len(color_permutations) * MAX_TARGET_BUTTONS, MAX_VARIATIONS)

def step(self) -> None:
for i in range(len(self.target_buttons)):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Z: Patch version (e.g. small changes to tasks, bug fixes, etc)

setup(name='rlbench',
version='1.0.3',
version='1.0.4',
description='RLBench',
author='Stephen James',
author_email='slj12@ic.ac.uk',
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_run_task_validator(self):
active_task = task_class(sim, robot)
try:
task_smoke(active_task, scene, variation=-1,
test_demos=False)
test_demos=False, max_variations=9999)
except Exception as e:
sim.stop()
raise e
Expand Down
6 changes: 6 additions & 0 deletions tools/task_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import argparse

DEMO_ATTEMPTS = 5
MAX_VARIATIONS = 100


class TaskValidationError(Exception):
Expand All @@ -35,6 +36,11 @@ def task_smoke(task: Task, scene: Scene, variation=-1, demos=4, success=0.50,
raise TaskValidationError(
"The method 'variation_count' should return a number > 0.")

if variation_count > MAX_VARIATIONS:
raise TaskValidationError(
"This task had %d variations. Currently the limit is set to %d" %
(variation_count, MAX_VARIATIONS))

# Base rotation bounds
base_pos, base_ori = task.base_rotation_bounds()
if len(base_pos) != 3 or len(base_ori) != 3:
Expand Down

0 comments on commit 10953da

Please sign in to comment.