Speed up evaluation by caching task environments as docker images #317
Merged
Changes from 11 commits
13 commits
f644f60 cache task environment as docker images with separate tags (ollmer)
815e12c save env vars inside the task image before docker commit, debug timing (ollmer)
e75d824 increase docker api timeout to afford long commits (ollmer)
438a2aa fix (ollmer)
388643d fix (ollmer)
ab59b47 Merge pull request #1 from princeton-nlp/main (ollmer)
e53dea8 Merge remote-tracking branch 'origin/main' into cached_task_environments (ollmer)
9556c4a remove timing collection code (ollmer)
c115007 some cleanup (ollmer)
bd42b26 remove timings storage (ollmer)
fef3d32 use close func to stop container (ollmer)
ffba574 Merge remote-tracking branch 'upstream/main' into cached_task_environ… (ollmer)
3d28971 address review comment, type hint (ollmer)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ | |
PROCESS_DONE_MARKER_END, | ||
PROCESS_DONE_MARKER_START, | ||
InvalidGithubURL, | ||
image_exists, | ||
copy_anything_to_container, | ||
copy_file_to_container, | ||
format_trajectory_markdown, | ||
|
@@ -56,7 +57,7 @@ | |
class EnvironmentArguments(FrozenSerializable): | ||
"""Configure data sources and setup instructions for the environment in which we solve the tasks. | ||
""" | ||
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file | ||
# (`json`, `jsonl`) or directory. To run over single issue: github issue url or path to markdown file | ||
# with problem statement or problem statement as text prefixed with `text://`. | ||
data_path: str | ||
|
@@ -68,26 +69,27 @@ class EnvironmentArguments(FrozenSerializable): | |
timeout: int = 35 | ||
verbose: bool = False | ||
no_mirror: bool = False | ||
cache_task_images: bool = False | ||
# Custom environment setup. Currently only used when data_path points to a single issue. | ||
# This needs to be either a string pointing to a yaml file (with yaml, yml file extension) | ||
# or a shell script (with sh extension). | ||
# See https://github.com/princeton-nlp/SWE-agent/pull/153 for more information | ||
environment_setup: Optional[str] = None | ||
# Only used when running on single issue. Path to local repository or github repository. | ||
repo_path: str = "" | ||
|
||
|
||
|
||
class EnvHook: | ||
def on_init(self): | ||
... | ||
|
||
def on_copy_repo_started(self, *, repo_type: str, repo_path: str): | ||
... | ||
|
||
def on_install_env_started(self): | ||
... | ||
|
||
|
||
class SWEEnv(gym.Env): | ||
"""Gym environment for SWE-bench. This class should handle all communication with the docker container.""" | ||
|
@@ -135,6 +137,14 @@ def __init__(self, args: EnvironmentArguments): | |
self.image_name = args.image_name | ||
self._reset_container() | ||
|
||
# Prepare image tag prefix for cached task environments | ||
if self.args.cache_task_images: | ||
logger.info("Task environment caching enabled") | ||
tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__" | ||
assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash." | ||
image_name_without_tag = self.image_name.split(":")[0] | ||
self.cached_image_prefix = f"{image_name_without_tag}:{tag}" | ||
|
||
# Set timeout | ||
self.timeout = self.args.timeout | ||
self.idx = 0 | ||
|
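The tag prefix in the hunk above replaces only `/` in `data_path`, and the length assertion runs before the task index is appended. Docker tags are limited to 128 characters drawn from letters, digits, `_`, `.` and `-`, and may not start with `.` or `-`. A minimal sketch of a stricter tag builder follows; `make_cached_image_name` is a hypothetical helper for illustration and is not part of this PR.

```python
import re
from typing import Optional

# Docker tag rules: at most 128 chars from [A-Za-z0-9_.-], not starting with '.' or '-'.
MAX_TAG_LEN = 128
_INVALID_TAG_CHARS = re.compile(r"[^A-Za-z0-9_.-]")


def make_cached_image_name(
    image_name: str,
    data_path: str,
    split: str,
    base_commit: Optional[str],
    index: int,
) -> str:
    """Hypothetical helper: build '<image>:<tag>' for one cached task environment."""
    raw_tag = f"{data_path}__{split}__{base_commit or 'head'}__{index}"
    # Replace every character Docker does not allow in a tag, not only '/'.
    tag = _INVALID_TAG_CHARS.sub("_", raw_tag).lstrip(".-")
    # Check the length of the complete tag, including the task index.
    if len(tag) > MAX_TAG_LEN:
        raise ValueError(f"Cached image tag is {len(tag)} characters long (limit {MAX_TAG_LEN})")
    # Same tag stripping as in the PR; assumes the image name has no registry port.
    image_name_without_tag = image_name.split(":")[0]
    return f"{image_name_without_tag}:{tag}"
```

Raising `ValueError` instead of using `assert` keeps the check active when Python runs with `-O`.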
@@ -150,7 +160,7 @@ def _repo_name(self) -> str: | |
"""Name of the local copy of the repository""" | ||
assert self.record is not None | ||
return self.record["repo"].replace("/", "__") | ||
|
||
def _copy_repo(self) -> str: | ||
"""Clone/copy repository/codebase in container | ||
Returns: | ||
|
@@ -217,6 +227,21 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> | |
|
||
### Reset Container ### | ||
|
||
if self.args.cache_task_images: | ||
cached_image = f"{self.cached_image_prefix}{index}" | ||
if image_exists(cached_image): | ||
logger.info(f"Restore environment from cached image {cached_image}") | ||
self.close() # stop current container | ||
self._init_container(cached_image=cached_image) | ||
self.communicate("export $(xargs </.env)") | ||
envs = self.communicate("env") | ||
logger.debug(f"Environment variables restored from the image:\n{envs}\n") | ||
if apply_test_patch: | ||
self._apply_test_patch() | ||
return None, info | ||
else: | ||
logger.info(f"Cached image {cached_image} not found, rebuilding task environment...") | ||
|
||
# Clone repository if not already cloned | ||
self.communicate(input="cd /") | ||
folders = self.communicate(input="ls").split("\n") | ||
|
@@ -275,24 +300,34 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> | |
error_msg="Failed to install flake8 (lint library)" | ||
) | ||
|
||
# Apply test patch for oracle setting | ||
if self.args.cache_task_images: | ||
envs = self.communicate("env") | ||
logger.debug(f"Environment variables to save:\n{envs}\n") | ||
self.communicate("env >> /.env") | ||
self.container_obj.commit(cached_image) | ||
logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") | ||
|
||
if apply_test_patch: | ||
self._apply_test_patch() | ||
# Write any metadata to info if necessary | ||
return None, info | ||
|
||
def _apply_test_patch(self): | ||
""" | ||
Apply test patch for oracle setting | ||
""" | ||
path_to_patch = "test.patch" | ||
with open(path_to_patch, "w") as f: | ||
f.write(self.record["test_patch"]) | ||
subprocess.run( | ||
f"docker cp {path_to_patch} {self.container_name}:/root/test.patch", | ||
shell=True, | ||
) | ||
self.communicate_with_handling( | ||
input="git apply /root/test.patch", | ||
error_msg="Failed to apply test patch correctly" | ||
) | ||
os.remove(path_to_patch) | ||
|
||
def step(self, action: str) -> Tuple[Optional[str], int, bool, dict]: | ||
""" | ||
|
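The PR saves the environment with `env >> /.env` and restores it with `export $(xargs </.env)`. Because the command substitution is word-split by the shell, a value containing spaces would likely be cut at the first space. One possible alternative, sketched here under the assumption that the dump is post-processed on the host, is to turn the `env` output into quoted `export` statements. `env_dump_to_exports` is a hypothetical name and not part of this PR.

```python
import re
import shlex
from typing import List

# Shell variable names: a letter or underscore followed by letters, digits, underscores.
_VAR_NAME = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def env_dump_to_exports(env_output: str) -> List[str]:
    """Hypothetical helper: convert `env` output into shell-safe export statements."""
    exports = []
    for line in env_output.splitlines():
        key, sep, value = line.partition("=")
        # Skip blank lines and anything that is not a KEY=value pair,
        # such as continuation lines of multi-line values.
        if not sep or not _VAR_NAME.match(key):
            continue
        # shlex.quote protects spaces and shell metacharacters in the value.
        exports.append(f"export {key}={shlex.quote(value)}")
    return exports
```

The resulting lines could be written to a file in the container and loaded with `source`, which avoids word splitting altogether.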
@@ -420,25 +455,29 @@ def reset_container(self) -> None: | |
self.container_obj = None | ||
self._reset_container() | ||
|
||
def _init_container(self) -> None: | ||
def _init_container(self, cached_image=None) -> None: | ||
Review comment: let's put type hints
Reply: fixed
|
||
""" | ||
Handles container initialization. Defines container name and creates it | ||
""" | ||
image_name = self.image_name | ||
if cached_image is not None: | ||
image_name = cached_image | ||
logger.info(f"Using cached image: {image_name}") | ||
if self.container_name is None: | ||
process_id = str(os.getpid()) | ||
current_time = str(datetime.datetime.now()) | ||
unique_string = current_time + process_id | ||
hash_object = hashlib.sha256(unique_string.encode()) | ||
# Cannot have colons/slashes in container name, but those are important in image names | ||
# i.e., when we want swe-agent to pull the image from dockerhub | ||
image_name_sanitized = self.image_name.replace("/", "-") | ||
image_name_sanitized = image_name.replace("/", "-") | ||
image_name_sanitized = image_name_sanitized.replace(":", "-") | ||
self.container_name = f"{image_name_sanitized}-{hash_object.hexdigest()[:10]}" | ||
self.container, self.parent_pids = get_container( | ||
self.container_name, self.image_name, persistent=self.persistent | ||
self.container_name, image_name, persistent=self.persistent | ||
) | ||
try: | ||
client = docker.from_env() | ||
client = docker.from_env(timeout=600) | ||
except docker.errors.DockerException as e: | ||
if "Error while fetching server API version" in str(e): | ||
raise RuntimeError( | ||
|
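With this change the container name is derived from whichever image is actually started, so a cached image tag ends up in the name. The naming logic from `_init_container` can be sketched as a standalone function; `make_container_name` is a hypothetical name used only for illustration.

```python
import datetime
import hashlib
import os


def make_container_name(image_name: str) -> str:
    """Hypothetical standalone version of the naming scheme in _init_container."""
    # The current time plus the process id gives a string that differs between runs.
    unique_string = str(datetime.datetime.now()) + str(os.getpid())
    digest = hashlib.sha256(unique_string.encode()).hexdigest()[:10]
    # Container names cannot contain '/' or ':', but image references usually do.
    sanitized = image_name.replace("/", "-").replace(":", "-")
    return f"{sanitized}-{digest}"
```

For example, an image reference such as `sweagent/swe-agent:data_dev__0` yields a name starting with `sweagent-swe-agent-data_dev__0-` followed by ten hex characters.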
@@ -490,7 +529,7 @@ def _communicate_experimental( | |
) | ||
raise RuntimeError("Failed to communicate with container") | ||
|
||
buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration) | ||
self.returncode = int(exit_code) | ||
return buffer | ||
|
||
|
@@ -622,9 +661,9 @@ def run_shell_script(self, script_path: Path, *, location: str) -> None: | |
elif location == "container": | ||
raise NotImplementedError | ||
raise ValueError(f"Invalid 'location': {location}") | ||
|
||
def _run_shell_script_host(self, script_path: Path) -> None: | ||
"""Run shell script file (located on host) in container""" | ||
if not script_path.is_file(): | ||
raise FileNotFoundError(f"Script not found at {script_path}") | ||
shell_commands = Path(script_path).read_text().splitlines() | ||
|
@@ -828,15 +867,15 @@ def interrupt(self): | |
|
||
def open_pr(self, *, trajectory, _dry_run: bool=False): | ||
"""Create PR to repository | ||
|
||
Args: | ||
trajectory: Trajectory of actions taken by the agent | ||
_dry_run: Whether to actually push anything or just simulate it | ||
""" | ||
logger.info("Opening PR") | ||
# todo: have better way of handling this | ||
# Adding random string suffix to avoid name conflicts if we had a previously failed run | ||
issue_url = self.args.data_path | ||
try: | ||
issue = get_gh_issue_data(issue_url, token=self._github_token) | ||
except InvalidGithubURL as e: | ||
|
@@ -901,7 +940,7 @@ def open_pr(self, *, trajectory, _dry_run: bool=False): | |
timeout_duration=10, | ||
) | ||
body = ( | ||
f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) " | ||
f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}." | ||
) | ||
body += "\n\n" + format_trajectory_markdown(trajectory) | ||
|
Review comment: Data path can now be all kinds of things, including the full text of the problem statement. I could push some code to work around these things.
Though I also wonder if `data_path` is really what we should make this depend on. Perhaps we could rather intercept the actual setup stages and hash the setup config or something like that?

Reply: Some dataset fingerprints would be even better, I agree. Also, there is a limit of 128 characters for the docker image tag, so this would help stay within the limit with any dataset. I will try to implement the idea with the config hash.
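The config-hash idea discussed above could look roughly like the following. This is only a sketch of one possible approach, not the implementation that was eventually merged. The helper names, the `task-env-` tag prefix, and the example config keys are all hypothetical, and the setup configuration is assumed to be JSON-serializable.

```python
import hashlib
import json
from typing import Any, Mapping


def setup_fingerprint(setup_config: Mapping[str, Any], length: int = 16) -> str:
    """Hypothetical helper: short, stable hash of a task's setup configuration."""
    # sort_keys makes the serialization independent of dict insertion order.
    canonical = json.dumps(setup_config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:length]


def cached_image_name(image_name: str, setup_config: Mapping[str, Any]) -> str:
    """Hypothetical helper: image reference whose tag depends only on the setup config."""
    image_name_without_tag = image_name.split(":")[0]
    # The tag has a fixed length, so it stays well below Docker's 128-character limit.
    return f"{image_name_without_tag}:task-env-{setup_fingerprint(setup_config)}"


# Example with made-up config keys:
config = {"repo": "owner/project", "base_commit": "abc123", "install": "pip install -e ."}
print(cached_image_name("sweagent/swe-agent:latest", config))
```

Because the tag is a fixed-length hex digest, it no longer depends on how long `data_path` is or which characters it contains.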