From f644f600d742ac7c802250527139eeb68a036c64 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Sun, 28 Apr 2024 22:17:27 +0200 Subject: [PATCH 01/10] cache task environment as docker images with separate tags --- run.py | 1 + sweagent/environment/swe_env.py | 58 +++++++++++++++++------- sweagent/environment/utils.py | 79 +++++++++++++++++++-------------- 3 files changed, 90 insertions(+), 48 deletions(-) diff --git a/run.py b/run.py index 905ef5c9b..0a7ef1280 100644 --- a/run.py +++ b/run.py @@ -447,6 +447,7 @@ def get_args(args=None) -> ScriptArguments: split="dev", verbose=True, install_environment=True, + cache_task_images=False, ), skip_existing=True, agent=AgentArguments( diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index b4a3594da..2dfc38b06 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -22,6 +22,7 @@ PROCESS_DONE_MARKER_END, PROCESS_DONE_MARKER_START, InvalidGithubURL, + image_exists, copy_anything_to_container, copy_file_to_container, format_trajectory_markdown, @@ -56,7 +57,7 @@ class EnvironmentArguments(FrozenSerializable): """Configure data sources and setup instructions for th environment in which we solve the tasks. """ - # Source of issue statement/problem statement. To run over a batch of issues: Path to a data file + # Source of issue statement/problem statement. To run over a batch of issues: Path to a data file # (`json`, `jsonl`) or directory. To run over single issue: github issue url or path to markdown file # with problem statement. data_path: str @@ -68,12 +69,13 @@ class EnvironmentArguments(FrozenSerializable): timeout: int = 35 verbose: bool = False no_mirror: bool = False + cache_task_images: bool = False # Custom environment setup. Currently only used when data_path points to a single issue. # This needs to be either a string pointing to a yaml file (with yaml, yml file extension) # or a shell script (with sh extension). # See https://github.com/princeton-nlp/SWE-agent/pull/153 for more information environment_setup: Optional[str] = None - # Only used when running on single issue. Path to local repository or github repository. + # Only used when running on single issue. Path to local repository or github repository. repo_path: str = "" @@ -81,13 +83,13 @@ class EnvironmentArguments(FrozenSerializable): class EnvHook: def on_init(self): ... - + def on_copy_repo_started(self, *, repo_type: str, repo_path: str): ... - + def on_install_env_started(self): ... - + class SWEEnv(gym.Env): """Gym environment for SWE-bench. This class should handle all communication with the docker container.""" @@ -135,6 +137,14 @@ def __init__(self, args: EnvironmentArguments): self.image_name = args.image_name self._reset_container() + # Prepare image tag prefix for cached task environments + if self.args.cache_task_images: + logger.info("Task environment caching enabled") + tag = f"{self.args.data_path.replace('/', '_')}__{self.args.split}__{self.args.base_commit or 'head'}__" + assert len(tag) < 128, f"Cached image tag {tag} too long, probably due to long data path or base commit hash." + image_name_without_tag = self.image_name.split(":")[0] + self.cached_image_prefix = f"{image_name_without_tag}:{tag}" + # Set timeout self.timeout = self.args.timeout self.idx = 0 @@ -150,7 +160,7 @@ def _repo_name(self) -> str: """Name of the local copy of the repository""" assert self.record is not None return self.record["repo"].replace("/", "__") - + def _copy_repo(self) -> str: """Clone/copy repository/codebase in container Returns: @@ -216,6 +226,15 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> self.reward = None ### Reset Container ### + cached_image = f"{self.cached_image_prefix}{index}" + if self.args.cache_task_images and cached_image is not None: + if image_exists(cached_image): + logger.info(f"Restore environment from cached image {cached_image}") + self.stop_container() # stop current container + self._init_container(cached_image=cached_image) + return None, info + else: + logger.info(f"Cached image {cached_image} not found, rebuilding task environment...") # Clone repository if not already cloned self.communicate(input="cd /") @@ -290,6 +309,8 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> ) os.remove(path_to_patch) + self.container_obj.commit(cached_image) + logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") # Write any metadata to info if necessary return None, info @@ -383,6 +404,9 @@ def close(self): raise except: pass + self.stop_container() + + def stop_container(self): assert self.container is not None assert self.container_obj is not None self.container.terminate() @@ -420,10 +444,14 @@ def reset_container(self) -> None: self.container_obj = None self._reset_container() - def _init_container(self) -> None: + def _init_container(self, cached_image=None) -> None: """ Handles container initialization. Defines container name and creates it """ + image_name = self.image_name + if cached_image is not None: + image_name = cached_image + logger.info(f"Using cached image: {image_name}") if self.container_name is None: process_id = str(os.getpid()) current_time = str(datetime.datetime.now()) @@ -431,11 +459,11 @@ def _init_container(self) -> None: hash_object = hashlib.sha256(unique_string.encode()) # Cannot have colons/slashes in container name, but those are important in image names # i.e., when we want swe-agent to pull the image from dockerhub - image_name_sanitized = self.image_name.replace("/", "-") + image_name_sanitized = image_name.replace("/", "-") image_name_sanitized = image_name_sanitized.replace(":", "-") self.container_name = f"{image_name_sanitized}-{hash_object.hexdigest()[:10]}" self.container, self.parent_pids = get_container( - self.container_name, self.image_name, persistent=self.persistent + self.container_name, image_name, persistent=self.persistent ) try: client = docker.from_env() @@ -490,7 +518,7 @@ def _communicate_experimental( ) raise RuntimeError("Failed to communicate with container") - buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration) + buffer, exit_code = read_with_timeout_experimental(self.container, timeout_duration) self.returncode = int(exit_code) return buffer @@ -622,9 +650,9 @@ def run_shell_script(self, script_path: Path, *, location: str) -> None: elif location == "container": raise NotImplementedError raise ValueError(f"Invalid 'location': {location}") - + def _run_shell_script_host(self, script_path: Path) -> None: - """Run shell script file (located on host) in container""" + """Run shell script file (located on host) in container""" if not script_path.is_file(): raise FileNotFoundError(f"Script not found at {script_path}") shell_commands = Path(script_path).read_text().splitlines() @@ -828,7 +856,7 @@ def interrupt(self): def open_pr(self, *, trajectory, _dry_run: bool=False): """Create PR to repository - + Args: trajectory: Trajectory of actions taken by the agent _dry_run: Whether to actually push anything or just simulate it @@ -836,7 +864,7 @@ def open_pr(self, *, trajectory, _dry_run: bool=False): logger.info("Opening PR") # todo: have better way of handling this # Adding random string suffix to avoid name conflicts if we had a previously failed run - issue_url = self.args.data_path + issue_url = self.args.data_path try: issue = get_gh_issue_data(issue_url, token=self._github_token) except InvalidGithubURL as e: @@ -901,7 +929,7 @@ def open_pr(self, *, trajectory, _dry_run: bool=False): timeout_duration=10, ) body = ( - f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) " + f"This is a PR opened by AI tool [SWE Agent](https://github.com/princeton-nlp/SWE-agent/) " f"to close [#{issue.number}]({issue_url}) ({issue.title}).\n\nCloses #{issue.number}." ) body += "\n\n" + format_trajectory_markdown(trajectory) diff --git a/sweagent/environment/utils.py b/sweagent/environment/utils.py index a4be13ffe..580ed1c7f 100644 --- a/sweagent/environment/utils.py +++ b/sweagent/environment/utils.py @@ -102,7 +102,7 @@ def copy_file_to_container(container, contents, container_path): def copy_anything_to_container(container, host_path: str, container_path: str) -> None: """Copy files or directories from host to container - + Note: Will need to set ownership on the copied files in the container. """ if not Path(host_path).exists(): @@ -349,12 +349,34 @@ def get_container(ctr_name: str, image_name: str, persistent: bool = False) -> T Returns: Container object """ - # Let's first check that the image exists and give some better error messages + if not image_exists(image_name): + msg = ( + f"Image {image_name} not found. Please ensure it is built and available. " + "Please double-check that you followed all installation/setup instructions from the " + "readme." + ) + raise RuntimeError(msg) + + if persistent: + return _get_persistent_container(ctr_name, image_name) + else: + return _get_non_persistent_container(ctr_name, image_name) + + +def image_exists(image_name): + """ + Check that the image exists and give some better error messages. + + Arguments: + image_name (str): Name of image + Returns: + bool: True if image exists + """ try: client = docker.from_env() except docker.errors.DockerException as e: docker_not_runnnig = any(( - "connection aborted" in str(e).lower(), + "connection aborted" in str(e).lower(), "connection refused" in str(e).lower(), "error while fetching server api version" in str(e).lower(), )) @@ -370,25 +392,16 @@ def get_container(ctr_name: str, image_name: str, persistent: bool = False) -> T raise filterred_images = client.images.list(filters={'reference': image_name}) if len(filterred_images) == 0: - msg = ( - f"Image {image_name} not found. Please ensure it is built and available. " - "Please double-check that you followed all installation/setup instructions from the " - "readme." - ) - raise RuntimeError(msg) + return False elif len(filterred_images) > 1: - logger.warning(f"Multiple images found for {image_name}, that's weird.") - attrs = filterred_images[0].attrs + RuntimeError(f"Multiple images found for {image_name}, that's weird.") + attrs = filterred_images[0].attrs if attrs is not None: logger.info( f"Found image {image_name} with tags: {attrs['RepoTags']}, created: {attrs['Created']} " f"for {attrs['Os']} {attrs['Architecture']}." ) - - if persistent: - return _get_persistent_container(ctr_name, image_name) - else: - return _get_non_persistent_container(ctr_name, image_name) + return True def get_commit(api: GhApi, owner: str, repo: str, base_commit: str = None): @@ -446,7 +459,7 @@ def get_problem_statement_from_github_issue(owner: str, repo: str, issue_number: class InstanceBuilder: def __init__(self, token: Optional[str] = None): - """This helper class is used to build the data for an instance object, + """This helper class is used to build the data for an instance object, retrieving problem statements from github issues or local files and setting repo paths from github urls or local paths. """ @@ -460,14 +473,14 @@ def set_problem_statement_from_gh_issue(self, issue_url: str): self.args["problem_statement"] = get_problem_statement_from_github_issue(owner, repo, issue_number, token=self.token) self.args["instance_id"] = f"{owner}__{repo}-i{issue_number}" self.args["problem_statement_source"] = "online" - + def set_problem_statement_from_file(self, file_path: str): self.args["problem_statement"] = Path(file_path).read_text() self.args["instance_id"] = hashlib.sha256(self.args["problem_statement"].encode()).hexdigest()[:6] self.args["problem_statement_source"] = "local" def set_problem_statement(self, data_path: str ): - """Get problem statement for a single instance from a github issue url or a + """Get problem statement for a single instance from a github issue url or a path to a markdown or text file. """ if is_github_issue_url(data_path): @@ -477,7 +490,7 @@ def set_problem_statement(self, data_path: str ): else: msg = f"Not sure how to get problem statement from {data_path=}." raise ValueError(msg) - + def set_repo_info_from_gh_url(self, url: str, base_commit: Optional[str] = None): owner, repo = parse_gh_repo_url(url) self.args["repo"] = f"{owner}/{repo}" @@ -488,7 +501,7 @@ def set_repo_info_from_gh_url(self, url: str, base_commit: Optional[str] = None) api = GhApi(token=self.token) self.args["base_commit"] = get_commit(api, owner, repo, base_commit).sha self.args["version"] = self.args["base_commit"][:7] - + def set_repo_info_from_local_path(self, path: str, base_commit: Optional[str] = None): self.args["repo"] = str(Path(path).resolve()) self.args["repo_type"] = "local" @@ -505,7 +518,7 @@ def set_repo_info_from_local_path(self, path: str, base_commit: Optional[str] = raise ValueError(msg) self.args["base_commit"] = repo.head.object.hexsha self.args["version"] = self.args["base_commit"][:7] - + def set_repo_info(self, repo: str, base_commit: Optional[str] = None): if is_github_repo_url(repo): self.set_repo_info_from_gh_url(repo, base_commit=base_commit) @@ -513,10 +526,10 @@ def set_repo_info(self, repo: str, base_commit: Optional[str] = None): self.set_repo_info_from_local_path(repo, base_commit=base_commit) else: raise ValueError(f"Could not determine repo path from {repo=}.") - + def set_from_dict(self, instance_dict: Dict[str, Any]): self.args |= instance_dict - + def set_missing_fields(self): # todo: This field is only needed while swe_env is using some questionable logic # to determine whether to clone from a mirror or not. This should be removed in the future. @@ -524,9 +537,9 @@ def set_missing_fields(self): # 'online' (loaded from github issue or similar) or 'local' (loaded from local file) if "problem_statement_source" not in self.args: self.args["problem_statement_source"] = "swe-bench" - if "repo_type" not in self.args: + if "repo_type" not in self.args: self.args["repo_type"] = "github" - + def validate(self): required_fields = [ "problem_statement", @@ -544,17 +557,17 @@ def validate(self): raise ValueError(f"Invalid repo type: {self.args['repo_type']=}") if self.args["repo_type"] == "github" and self.args["repo"].count("/") != 1: raise ValueError(f"Invalid repo format for {self.args['repo_type']=}: {self.args['repo']=}") - + def build(self) -> Dict[str, Any]: self.set_missing_fields() self.validate() return self.args - + def get_instances( - file_path: str, - base_commit: Optional[str] = None, - split: Optional[str] = None, + file_path: str, + base_commit: Optional[str] = None, + split: Optional[str] = None, token: Optional[str] = None, *, repo_path: str = "", @@ -603,7 +616,7 @@ def postproc_instance_list(instances): raise ValueError(f"Could not determine repo path from {file_path=}, {repo_path=}") return [ib.build()] - + if base_commit is not None: raise ValueError("base_commit must be None if data_path is not a github issue url") @@ -682,6 +695,6 @@ def format_trajectory_markdown(trajectory: List[Dict[str, str]]): suffix = [ "", "", - ] + ] return "\n".join(prefix) + "\n\n---\n\n".join(steps) + "\n".join(suffix) From 815e12cd9b15e5b73fdc174413841736f0eb4d67 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 29 Apr 2024 00:12:19 +0200 Subject: [PATCH 02/10] save env vars inside the task image before docker commit, debug timing --- sweagent/environment/swe_env.py | 51 ++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 2dfc38b06..e832c07db 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -136,6 +136,7 @@ def __init__(self, args: EnvironmentArguments): # Establish connection with execution container self.image_name = args.image_name self._reset_container() + self._timings = [] # Prepare image tag prefix for cached task environments if self.args.cache_task_images: @@ -155,6 +156,14 @@ def add_hook(self, hook: EnvHook): hook.on_init() self.hooks.append(hook) + def update_timings(self, dt: float): + self._timings.append(dt) + cached = "_cached" if self.args.cache_task_images else "" + fname = f"env_timings_{self.args.data_path.replace('/', '_')}_{self.args.split}{cached}.yaml" + with open(fname, "w") as f: + yaml.dump(self._timings, f) + logger.debug(f"Total time spent in the environment preparation: {sum(self._timings):.1f} sec") + @property def _repo_name(self) -> str: """Name of the local copy of the repository""" @@ -212,6 +221,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> observation (`str`) - output from container info (`dict`) - additional information (e.g. debugging information) """ + dt = time.monotonic() info = {} info["commit_sha"] = self.commit_sha @@ -232,6 +242,12 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> logger.info(f"Restore environment from cached image {cached_image}") self.stop_container() # stop current container self._init_container(cached_image=cached_image) + self.communicate("export $(xargs error_msg="Failed to install flake8 (lint library)" ) - # Apply test patch for oracle setting + envs = self.communicate("env") + logger.debug(f"Environment variables to save:\n{envs}\n") + self.communicate("env >> /.env") + self.container_obj.commit(cached_image) + logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") + + self.update_timings(time.monotonic() - dt) + if apply_test_patch: - path_to_patch = "test.patch" - with open(path_to_patch, "w") as f: - f.write(self.record["test_patch"]) - subprocess.run( + self._apply_test_patch() + # Write any metadata to info if necessary + return None, info + + def _apply_test_patch(self): + """ + Apply test patch for oracle setting + """ + path_to_patch = "test.patch" + with open(path_to_patch, "w") as f: + f.write(self.record["test_patch"]) + subprocess.run( f"docker cp {path_to_patch} {self.container_name}:/root/test.patch", shell=True, ) - self.communicate_with_handling( + self.communicate_with_handling( input="git apply /root/test.patch", error_msg="Failed to apply test patch correctly" ) - os.remove(path_to_patch) - - self.container_obj.commit(cached_image) - logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") - - # Write any metadata to info if necessary - return None, info + os.remove(path_to_patch) def step(self, action: str) -> Tuple[Optional[str], int, bool, dict]: """ From e75d8243e06ff0c5573b323810b1eb92df1f4dce Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 29 Apr 2024 11:44:24 +0200 Subject: [PATCH 03/10] increase docker api timeout to afford long commits --- sweagent/environment/swe_env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index e832c07db..b00f8fa88 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -491,7 +491,7 @@ def _init_container(self, cached_image=None) -> None: self.container_name, image_name, persistent=self.persistent ) try: - client = docker.from_env() + client = docker.from_env(timeout=600) except docker.errors.DockerException as e: if "Error while fetching server API version" in str(e): raise RuntimeError( From 438a2aa9b598f9eaf29328f245eb3a9247407e38 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 29 Apr 2024 23:20:42 +0200 Subject: [PATCH 04/10] fix --- sweagent/environment/swe_env.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index b00f8fa88..2df1ac312 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -236,9 +236,10 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> self.reward = None ### Reset Container ### - cached_image = f"{self.cached_image_prefix}{index}" - if self.args.cache_task_images and cached_image is not None: - if image_exists(cached_image): + + if self.args.cache_task_images: + cached_image = f"{self.cached_image_prefix}{index}" + if cached_image is not None and image_exists(cached_image): logger.info(f"Restore environment from cached image {cached_image}") self.stop_container() # stop current container self._init_container(cached_image=cached_image) From 388643d892fe543a6cc0523b5028d2658243407c Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 29 Apr 2024 23:25:05 +0200 Subject: [PATCH 05/10] fix --- sweagent/environment/swe_env.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 2df1ac312..6feaceffc 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -311,11 +311,12 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> error_msg="Failed to install flake8 (lint library)" ) - envs = self.communicate("env") - logger.debug(f"Environment variables to save:\n{envs}\n") - self.communicate("env >> /.env") - self.container_obj.commit(cached_image) - logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") + if self.args.cache_task_images: + envs = self.communicate("env") + logger.debug(f"Environment variables to save:\n{envs}\n") + self.communicate("env >> /.env") + self.container_obj.commit(cached_image) + logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") self.update_timings(time.monotonic() - dt) From 9556c4ac21f50bf33e0de543e96126802b8e8a0d Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 6 May 2024 22:08:06 +0200 Subject: [PATCH 06/10] remove timing collection code --- sweagent/environment/swe_env.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index aa0a6cb55..9556e99b7 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -156,14 +156,6 @@ def add_hook(self, hook: EnvHook): hook.on_init() self.hooks.append(hook) - def update_timings(self, dt: float): - self._timings.append(dt) - cached = "_cached" if self.args.cache_task_images else "" - fname = f"env_timings_{self.args.data_path.replace('/', '_')}_{self.args.split}{cached}.yaml" - with open(fname, "w") as f: - yaml.dump(self._timings, f) - logger.debug(f"Total time spent in the environment preparation: {sum(self._timings):.1f} sec") - @property def _repo_name(self) -> str: """Name of the local copy of the repository""" @@ -221,7 +213,6 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> observation (`str`) - output from container info (`dict`) - additional information (e.g. debugging information) """ - dt = time.monotonic() info = {} info["commit_sha"] = self.commit_sha @@ -246,7 +237,6 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> self.communicate("export $(xargs self.container_obj.commit(cached_image) logger.info(f"Container with environment {self.container_obj.id} cached as image {cached_image}") - self.update_timings(time.monotonic() - dt) - if apply_test_patch: self._apply_test_patch() # Write any metadata to info if necessary From c115007ac74f73e3b2873c160fa9b18bb3499da9 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 6 May 2024 22:23:53 +0200 Subject: [PATCH 07/10] some cleanup --- sweagent/environment/swe_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 9556e99b7..0b69fc1d3 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -230,13 +230,13 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> if self.args.cache_task_images: cached_image = f"{self.cached_image_prefix}{index}" - if cached_image is not None and image_exists(cached_image): + if image_exists(cached_image): logger.info(f"Restore environment from cached image {cached_image}") self.stop_container() # stop current container self._init_container(cached_image=cached_image) self.communicate("export $(xargs Date: Mon, 6 May 2024 22:29:15 +0200 Subject: [PATCH 08/10] remove timings storage --- sweagent/environment/swe_env.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 0b69fc1d3..ca818ced1 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -136,7 +136,6 @@ def __init__(self, args: EnvironmentArguments): # Establish connection with execution container self.image_name = args.image_name self._reset_container() - self._timings = [] # Prepare image tag prefix for cached task environments if self.args.cache_task_images: From fef3d32be7c85974c8094e3c55c22da4056ae348 Mon Sep 17 00:00:00 2001 From: Oleh Shliazhko Date: Mon, 6 May 2024 22:35:57 +0200 Subject: [PATCH 09/10] use close func to stop container --- sweagent/environment/swe_env.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index ca818ced1..15a05a465 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -231,7 +231,7 @@ def reset(self, index: Optional[int] = None, apply_test_patch: bool = False) -> cached_image = f"{self.cached_image_prefix}{index}" if image_exists(cached_image): logger.info(f"Restore environment from cached image {cached_image}") - self.stop_container() # stop current container + self.close() # stop current container self._init_container(cached_image=cached_image) self.communicate("export $(xargs Date: Mon, 13 May 2024 23:36:57 +0200 Subject: [PATCH 10/10] address review comment, type hint --- sweagent/environment/swe_env.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sweagent/environment/swe_env.py b/sweagent/environment/swe_env.py index 65d641547..b8eb381da 100644 --- a/sweagent/environment/swe_env.py +++ b/sweagent/environment/swe_env.py @@ -462,9 +462,10 @@ def reset_container(self) -> None: self.container_obj = None self._reset_container() - def _init_container(self, cached_image=None) -> None: + def _init_container(self, cached_image: Optional[str] = None) -> None: """ - Handles container initialization. Defines container name and creates it + Handles container initialization. Defines container name and creates it. + If cached_image is provided, it will use that image name instead of the default. """ image_name = self.image_name if cached_image is not None: