Skip to content

Commit

Permalink
Speed up evaluation by caching task environments as docker images (#317)
Browse files Browse the repository at this point in the history
* cache task environment as docker images with separate tags

* save env vars inside the task image before docker commit, debug timing

* increase docker api timeout to afford long commits

* fix

* fix

* remove timing collection code

* some cleanup

* remove timings storage

* use close func to stop container

* address review comment, type hint
  • Loading branch information
ollmer committed May 27, 2024
1 parent b625105 commit 2a0c164
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 62 deletions.
1 change: 1 addition & 0 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ def get_args(args=None) -> ScriptArguments:
split="dev",
verbose=True,
install_environment=True,
cache_task_images=False,
),
skip_existing=True,
agent=AgentArguments(
Expand Down
94 changes: 67 additions & 27 deletions sweagent/environment/swe_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PROCESS_DONE_MARKER_END,
PROCESS_DONE_MARKER_START,
InvalidGithubURL,
image_exists,
copy_anything_to_container,
copy_file_to_container,
format_trajectory_markdown,
Expand Down Expand Up @@ -56,7 +57,7 @@
class EnvironmentArguments(FrozenSerializable):
"""Configure data sources and setup instructions for the environment in which we solve the tasks.
"""
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file
# Source of issue statement/problem statement. To run over a batch of issues: Path to a data file
# (`json`, `jsonl`) or directory. To run over single issue: github issue url or path to markdown file
# with problem statement or problem statement as text prefixed with `text://`.
data_path: str
Expand All @@ -70,23 +71,24 @@ class EnvironmentArguments(FrozenSerializable):
timeout: int = 35
verbose: bool = False
no_mirror: bool = False
cache_task_images: bool = False
# Custom environment setup. Currently only used when data_path points to a single issue.
# This needs to be either a string pointing to a yaml file (with yaml, yml file extension)
# or a shell script (with sh extension).
# See https://github.com/princeton-nlp/SWE-agent/pull/153 for more information
environment_setup: Optional[str] = None
# Only used when running on single issue. Path to local repository or github repository.
# Only used when running on single issue. Path to local repository or github repository.
repo_path: str = ""



class EnvHook:
def on_init(self):
...

def on_copy_repo_started(self, *, repo_type: str, repo_path: str):
...

def on_install_env_started(self):
...

Expand Down Expand Up @@ -140,6 +142,14 @@ def __init__(self, args: EnvironmentArguments):
self.image_name = args.image_name
self._reset_container()

# Prepare image tag prefix for cached task environments
if self.args.cache_task_images:
logger.info("Task environment caching enabled")
tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__"
assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash."
image_name_without_tag = self.image_name.split(":")[0]
self.cached_image_prefix = f"{image_name_without_tag}:{tag}"

# Set timeout
self.timeout = self.args.timeout
self.idx = 0
Expand All @@ -155,7 +165,7 @@ def _repo_name(self) -> str:
"""Name of the local copy of the repository"""
assert self.record is not None
return self.record["repo"].replace("/", "__")

def _copy_repo(self) -> str:
"""Clone/copy repository/codebase in container
Returns:
Expand Down Expand Up @@ -222,6 +232,21 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->

### Reset Container ###

if self.args.cache_task_images:
cached_image = f"{self.cached_image_prefix}{index}"
if image_exists(cached_image):
logger.info(f"Restore environment from cached image {cached_image}")
self.close() # stop current container
self._init_container(cached_image=cached_image)
self.communicate("export $(xargs </.env)")
envs = self.communicate("env")
logger.debug(f"Environment variables restored from the image:\n{envs}\n")
if apply_test_patch:
self._apply_test_patch()
return None, info
else:
logger.info(f"Cached image {cached_image} not found, rebuilding task environment...")

# Clone repository if not already cloned
self.communicate(input="cd /")
folders = self.communicate(input="ls").split("\n")
Expand Down Expand Up @@ -280,24 +305,34 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) ->
error_msg="Failed to install flake8 (lint library)"
)

# Apply test patch for oracle setting
if self.args.cache_task_images:
envs = self.communicate("env")
logger.debug(f"Environment variables to save:\n{envs}\n")
self.communicate("env >> /.env")
self.container_obj.commit(cached_image)
logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}")

if apply_test_patch:
path_to_patch = "test.patch"
with open(path_to_patch, "w") as f:
f.write(self.record["test_patch"])
subprocess.run(
self._apply_test_patch()
# Write any metadata to info if necessary
return None, info

def _apply_test_patch(self):
"""
Apply test patch for oracle setting
"""
path_to_patch = "test.patch"
with open(path_to_patch, "w") as f:
f.write(self.record["test_patch"])
subprocess.run(
f"docker cp {path_to_patch} {self.container_name}:/root/test.patch",
shell=True,
)
self.communicate_with_handling(
self.communicate_with_handling(
input="git apply /root/test.patch",
error_msg="Failed to apply test patch correctly"
)
os.remove(path_to_patch)


# Write any metadata to info if necessary
return None, info
os.remove(path_to_patch)

def step(self, action: str) -> Tuple[Optional[str], int, bool, dict]:
"""
Expand Down Expand Up @@ -427,25 +462,30 @@ def reset_container(self) -> None:
self.container_obj = None
self._reset_container()

def _init_container(self) -> None:
def _init_container(self, cached_image: Optional[str] = None) -> None:
"""
Handles container initialization. Defines container name and creates it
Handles container initialization. Defines container name and creates it.
If cached_image is provided, it will use that image name instead of the default.
"""
image_name = self.image_name
if cached_image is not None:
image_name = cached_image
logger.info(f"Using cached image: {image_name}")
if self.container_name is None:
process_id = str(os.getpid())
current_time = str(datetime.datetime.now())
unique_string = current_time + process_id
hash_object = hashlib.sha256(unique_string.encode())
# Cannot have colons/slashes in container name, but those are important in image names
# i.e., when we want swe-agent to pull the image from dockerhub
image_name_sanitized = self.image_name.replace("/", "-")
image_name_sanitized = image_name.replace("/", "-")
image_name_sanitized = image_name_sanitized.replace(":", "-")
self.container_name = f"{image_name_sanitized}-{hash_object.hexdigest()[:10]}"
self.container, self.parent_pids = get_container(
self.container_name, self.image_name, persistent=self.persistent
self.container_name, image_name, persistent=self.persistent
)
try:
client = docker.from_env()
client = docker.from_env(timeout=600)
except docker.errors.DockerException as e:
if "Error while fetching server API version" in str(e):
raise RuntimeError(
Expand Down Expand Up @@ -502,7 +542,7 @@ def _communicate_experimental(
)
raise RuntimeError("Failed to communicate with container")

buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration)
buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration)
self.returncode = int(exit_code)
return buffer

Expand Down Expand Up @@ -634,9 +674,9 @@ def run_shell_script(self, script_path: Path, *, location: str) -> None:
elif location == "container":
raise NotImplementedError
raise ValueError(f"Invalid 'location': {location}")

def _run_shell_script_host(self, script_path: Path) -> None:
"""Run shell script file (located on host) in container"""
"""Run shell script file (located on host) in container"""
if not script_path.is_file():
raise FileNotFoundError(f"Script not found at {script_path}")
shell_commands = Path(script_path).read_text().splitlines()
Expand Down Expand Up @@ -840,15 +880,15 @@ def interrupt(self):

def open_pr(self, *, trajectory, _dry_run: bool=False):
"""Create PR to repository
Args:
trajectory: Trajectory of actions taken by the agent
_dry_run: Whether to actually push anything or just simulate it
"""
logger.info("Opening PR")
# todo: have better way of handling this
# Adding random string suffix to avoid name conflicts if we had a previously failed run
issue_url = self.args.data_path
issue_url = self.args.data_path
try:
issue = get_gh_issue_data(issue_url, token=self._github_token)
except InvalidGithubURL as e:
Expand Down Expand Up @@ -913,7 +953,7 @@ def open_pr(self, *, trajectory, _dry_run: bool=False):
timeout_duration=10,
)
body = (
f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) "
f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) "
f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}."
)
body += "\n\n" + format_trajectory_markdown(trajectory)
Expand Down
Loading

0 comments on commit 2a0c164

Please sign in to comment.