Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up evaluation by caching task environments as docker images #317

Merged
merged 13 commits into from
May 27, 2024
1 change: 1 addition & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ def get_args(args=None) -> ScriptArguments:
split="dev",
verbose=True,
install_environment=True,
cache_task_images=False,
),
skip_existing=True,
agent=AgentArguments(
Expand Down
93 changes: 66 additions & 27 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PROCESS_DONE_MARKER_END,
PROCESS_DONE_MARKER_START,
InvalidGithubURL,
image_exists,
copy_anything_to_container,
copy_file_to_container,
format_trajectory_markdown,
Expand Down Expand Up @@ -56,7 +57,7 @@
class EnvironmentArguments(FrozenSerializable):
"""Configure data sources and setup instructions for th environment in which we solve the tasks.
"""
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file
# (`json`, `jsonl`) or directory. To run over single issue: github issue url or path to markdown file
# with problem statement or problem statement as text prefixed with `text://`.
data_path: str
Expand All @@ -68,26 +69,27 @@ class EnvironmentArguments(FrozenSerializable):
timeout: int = 35
verbose: bool = False
no_mirror: bool = False
cache_task_images: bool = False
# Custom environment setup. Currently only used when data_path points to a single issue.
# This needs to be either a string pointing to a yaml file (with yaml, yml file extension)
# or a shell script (with sh extension).
# See https://github.com/princeton-nlp/SWE-agent/pull/153 for more information
environment_setup: Optional[str] = None
# Only used when running on single issue. Path to local repository or github repository.
# Only used when running on single issue. Path to local repository or github repository.
repo_path: str = ""



class EnvHook:
def on_init(self):
...

def on_copy_repo_started(self, *, repo_type: str, repo_path: str):
...

def on_install_env_started(self):
...


class SWEEnv(gym.Env):
"""Gym environment for SWE-bench. This class should handle all communication with the docker container."""
Expand Down Expand Up @@ -135,6 +137,14 @@ def __init__(self, args: EnvironmentArguments):
self.image_name = args.image_name
self._reset_container()

# Prepare image tag prefix for cached task environments
if self.args.cache_task_images:
logger.info("Task environment caching enabled")
tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Data path can now be all kinds of things, including the full text of the problem statement. I could push some code to work around these things.

Though I also wonder if data_path is really what we should make this depend on. Perhaps we could rather intercept the actual setup stages and hash the setup config or something like that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some dataset fingerprints would be even better, I agree. Also, there is a limit of 128 characters for the docker image tag, so this would help stay within the limit with any dataset. I will try to implement the idea with the config hash.

assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash."
image_name_without_tag = self.image_name.split(":")[0]
self.cached_image_prefix = f"{image_name_without_tag}:{tag}"

# Set timeout
self.timeout = self.args.timeout
self.idx = 0
Expand All @@ -150,7 +160,7 @@ def _repo_name(self) -> str:
"""Name of the local copy of the repository"""
assert self.record is not None
return self.record["repo"].replace("/", "__")

def _copy_repo(self) -> str:
"""Clone/copy repository/codebase in container
Returns:
Expand Down Expand Up @@ -217,6 +227,21 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->

### Reset Container ###

if self.args.cache_task_images:
cached_image = f"{self.cached_image_prefix}{index}"
if image_exists(cached_image):
logger.info(f"Restore environment from cached image {cached_image}")
self.close() # stop current container
self._init_container(cached_image=cached_image)
self.communicate("export $(xargs </.env)")
envs = self.communicate("env")
logger.debug(f"Environment variables restored from the image:\n{envs}\n")
if apply_test_patch:
self._apply_test_patch()
return None, info
else:
logger.info(f"Cached image {cached_image} not found, rebuilding task environment...")

# Clone repository if not already cloned
self.communicate(input="cd /")
folders = self.communicate(input="ls").split("\n")
Expand Down Expand Up @@ -275,24 +300,34 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
error_msg="Failed to install flake8 (lint library)"
)

# Apply test patch for oracle setting
if self.args.cache_task_images:
envs = self.communicate("env")
logger.debug(f"Environment variables to save:\n{envs}\n")
self.communicate("env >> /.env")
self.container_obj.commit(cached_image)
logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}")

if apply_test_patch:
path_to_patch = "test.patch"
with open(path_to_patch, "w") as f:
f.write(self.record["test_patch"])
subprocess.run(
self._apply_test_patch()
# Write any metadata to info if necessary
return None, info

def _apply_test_patch(self):
"""
Apply test patch for oracle setting
"""
path_to_patch = "test.patch"
with open(path_to_patch, "w") as f:
f.write(self.record["test_patch"])
subprocess.run(
f"docker cp {path_to_patch} {self.container_name}:/root/test.patch",
shell=True,
)
self.communicate_with_handling(
self.communicate_with_handling(
input="git apply /root/test.patch",
error_msg="Failed to apply test patch correctly"
)
os.remove(path_to_patch)


# Write any metadata to info if necessary
return None, info
os.remove(path_to_patch)

def step(self, action: str) -> Tuple[Optional[str], int, bool, dict]:
"""
Expand Down Expand Up @@ -420,25 +455,29 @@ def reset_container(self) -> None:
self.container_obj = None
self._reset_container()

def _init_container(self) -> None:
def _init_container(self, cached_image=None) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's put type hints

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

"""
Handles container initialization. Defines container name and creates it
"""
image_name = self.image_name
if cached_image is not None:
image_name = cached_image
logger.info(f"Using cached image: {image_name}")
if self.container_name is None:
process_id = str(os.getpid())
current_time = str(datetime.datetime.now())
unique_string = current_time + process_id
hash_object = hashlib.sha256(unique_string.encode())
# Cannot have colons/slashes in container name, but those are important in image names
# i.e., when we want swe-agent to pull the image from dockerhub
image_name_sanitized = self.image_name.replace("/", "-")
image_name_sanitized = image_name.replace("/", "-")
image_name_sanitized = image_name_sanitized.replace(":", "-")
self.container_name = f"{image_name_sanitized}-{hash_object.hexdigest()[:10]}"
self.container, self.parent_pids = get_container(
self.container_name, self.image_name, persistent=self.persistent
self.container_name, image_name, persistent=self.persistent
)
try:
client = docker.from_env()
client = docker.from_env(timeout=600)
except docker.errors.DockerException as e:
if "Error while fetching server API version" in str(e):
raise RuntimeError(
Expand Down Expand Up @@ -490,7 +529,7 @@ def _communicate_experimental(
)
raise RuntimeError("Failed to communicate with container")

buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration)
buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration)
self.returncode = int(exit_code)
return buffer

Expand Down Expand Up @@ -622,9 +661,9 @@ def run_shell_script(self, script_path: Path, *, location: str) -> None:
elif location == "container":
raise NotImplementedError
raise ValueError(f"Invalid 'location': {location}")

def _run_shell_script_host(self, script_path: Path) -> None:
"""Run shell script file (located on host) in container"""
"""Run shell script file (located on host) in container"""
if not script_path.is_file():
raise FileNotFoundError(f"Script not found at {script_path}")
shell_commands = Path(script_path).read_text().splitlines()
Expand Down Expand Up @@ -828,15 +867,15 @@ def interrupt(self):

def open_pr(self, *, trajectory, _dry_run: bool=False):
"""Create PR to repository

Args:
trajectory: Trajectory of actions taken by the agent
_dry_run: Whether to actually push anything or just simulate it
"""
logger.info("Opening PR")
# todo: have better way of handling this
# Adding random string suffix to avoid name conflicts if we had a previously failed run
issue_url = self.args.data_path
issue_url = self.args.data_path
try:
issue = get_gh_issue_data(issue_url, token=self._github_token)
except InvalidGithubURL as e:
Expand Down Expand Up @@ -901,7 +940,7 @@ def open_pr(self, *, trajectory, _dry_run: bool=False):
timeout_duration=10,
)
body = (
f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) "
f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) "
f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}."
)
body += "\n\n" + format_trajectory_markdown(trajectory)
Expand Down
Loading
Loading