Skip to content

Commit

Permalink
modify bugfix (#3667)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinlu1248 authored May 3, 2024
2 parents 1242668 + 9b11910 commit 0d095db
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 23 deletions.
10 changes: 4 additions & 6 deletions sweepai/agents/modify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,12 +1153,12 @@ def modify(
break
else:
logger.error("Max iterations reached")
# breakpoint()
diff_string = ""
for file_name, file_data in modify_files_dict.items():
diff = generate_diff(file_data['original_contents'], file_data['contents'])
if diff:
diff_string += f"\nChanges made to {file_name}:\n{diff}"
# print("\n".join([generate_diff(file_data["original_contents"], file_data["contents"]) for file_name, file_data in modify_files_dict.items()])) # adding this as a useful way to render the diffs
return modify_files_dict


Expand Down Expand Up @@ -1432,7 +1432,8 @@ def handle_function_call(

# Check if the changes are valid
if not error_message:
check_results = get_check_results(file_name, new_file_contents)
is_last_fcr_for_file = False # TODO: check if this is the last fcr for this file
check_results = get_check_results(file_name, new_file_contents, last_fcr_for_file=is_last_fcr_for_file)
check_results_message = check_results.is_worse_than_message(llm_state['initial_check_results'][file_name])
failing_parse = check_results.parse_error_message if not llm_state['initial_check_results'][file_name].parse_error_message else ""
current_diff = generate_diff(
Expand All @@ -1441,7 +1442,6 @@ def handle_function_call(
if failing_parse:
error_message = f"Error: Invalid code changes have been applied. You requested the following changes:\n\n```diff\n{current_diff}\n```\n\nBut it produces invalid code with the following error logs:\n```\n{failing_parse}\n```\n\n" + fix_syntax_prompt
# print(error_message)
# breakpoint()
break
elif check_results_message:
warning_message = check_results_message
Expand Down Expand Up @@ -1483,9 +1483,8 @@ def handle_function_call(
else:
llm_response = f"SUCCESS\n\nThe following changes have been applied:\n\n```diff\n{generate_diff(file_contents, new_file_contents, n=25)}\n```\nThe code changes also yield the following warnings:\n```\n{warning_message}\n```\n\n{linter_warning_prompt.format(current_task=llm_state['current_task'])}"

#dify_files_dict[file_name]['contents'] = new_file_contents
modify_files_dict[file_name]['contents'] = new_file_contents
llm_state["attempt_lazy_change"] = False # no longer attempt lazy change
# breakpoint()
elif llm_state["completed_changes_per_fcr"][current_fcr_index] + 1 < llm_state["changes_per_fcr"][current_fcr_index]:
# Incomplete changes, should use a different prompt realistically
llm_response = f"SUCCESS\n\nThe following changes have been applied:\n\n```diff\n{generate_diff(file_contents, new_file_contents, n=25)}\n```\n{self_review_prompt.format(current_task=llm_state['current_task'])}"
Expand All @@ -1495,7 +1494,6 @@ def handle_function_call(
llm_state["completed_changes_per_fcr"][current_fcr_index] += 1
elif diff_string.count("\n+") + diff_string.count("\n-") > 10:
llm_response = f"SUCCESS\n\nThe following changes have been applied:\n\n```diff\n{generate_diff(file_contents, new_file_contents, n=25)}\n```\n\n{self_review_prompt.format(current_task=llm_state['current_task'])}"
# breakpoint()
modify_files_dict[file_name]['contents'] = new_file_contents
llm_state["attempt_lazy_change"] = False # no longer attempt lazy change
else:
Expand Down
47 changes: 32 additions & 15 deletions sweepai/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,27 @@ def find_deepest_error(node: Node) -> Optional[Node]:
}
"""

pylint_args_non_last_fcr = [
"--disable=C",
"--enable=C0413", # Enable only the check for imports not at the top
"--disable=W0611", # Don't check unused import
"--disable=R",
"--disable=import-error",
"--disable=no-member",
]

# add a comment to all lines which are changed
pylint_args_last_fcr = [
"--disable=C",
"--enable=C0413",
"--enable=W0611", # Check unused import
"--disable=R",
"--disable=import-error",
"--disable=no-member",
]

@file_cache()
def get_pylint_check_results(file_path: str, code: str) -> CheckResults:
def get_pylint_check_results(file_path: str, code: str, last_fcr_for_file=False) -> CheckResults:
logger.debug(f"Running pylint on {file_path}...")
file_hash = uuid.uuid4().hex
new_file = os.path.join("/tmp", file_hash + "_" + os.path.basename(file_path))
Expand All @@ -442,15 +461,10 @@ def get_pylint_check_results(file_path: str, code: str) -> CheckResults:
f.write(code)
pylint_output = StringIO()
reporter = TextReporter(pylint_output)
# this allows us to have a more rigorous check for the last file change request
pylint_args = [new_file] + (pylint_args_last_fcr if last_fcr_for_file else pylint_args_non_last_fcr)
Run(
[
new_file,
"--disable=C",
"--enable=C0413", # Enable only the check for imports not at the top
"--disable=R",
"--disable=import-error",
"--disable=no-member",
],
pylint_args,
reporter=reporter,
exit=False,
)
Expand All @@ -467,14 +481,14 @@ def get_pylint_check_results(file_path: str, code: str) -> CheckResults:
logger.debug("Done running pylint.")
return CheckResults(pylint=error_message if not succeeded else "")

def get_check_results(file_path: str, code: str) -> CheckResults:
def get_check_results(file_path: str, code: str, last_fcr_for_file=False) -> CheckResults:
is_valid, error_message = check_syntax(file_path, code)
if not is_valid:
return CheckResults(parse_error_message=error_message)
ext = file_path.rsplit(".")[-1] # noqa
if ext == "py":
try:
return get_pylint_check_results(file_path, code)
return get_pylint_check_results(file_path, code, last_fcr_for_file=last_fcr_for_file)
except Exception as e:
logger.exception(e)
elif ext in ["js", "jsx", "ts", "tsx"]:
Expand Down Expand Up @@ -751,7 +765,9 @@ def get_function_name(file_name: str, source_code: str, line_number: int):
"""

if __name__ == "__main__":
python_code = """import math
python_code = """\
import math
import pandas
def get_circle_area(radius: float) -> float:
return math.pi * radius ** 2
Expand All @@ -763,6 +779,7 @@ def get_circle_area(radius: float) -> float:
# new_code = """console.log("hello world")"""
# check_results = check_syntax("test.js", new_code)
check_results = get_check_results("test.tsx", code)
import pdb
# pylint: disable=no-member
pdb.set_trace()
check_results = get_check_results("test.py", python_code)
assert check_results.pylint == "" # this should pass
check_results = get_check_results("test.py", python_code, last_fcr_for_file=True)
assert "Unused import pandas" in check_results.pylint # this should warn about unused imports
4 changes: 2 additions & 2 deletions tests/rerun_issue.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def main(
better_stack_prefix: str = "https://logs.betterstack.com/team/199101/tail?rf=now-30m&q=metadata.issue_url%3A",
):
issue_url = issue_url or typer.prompt("Issue URL")
print(f"Fetching issue metdata...")
print("Fetching issue metdata...")
issue_request = fetch_issue_request(issue_url)
wait_for_server(host)
print(f"Sending request...")
print("Sending request...")
response = requests.post(
host,
json=issue_request.dict(),
Expand Down

0 comments on commit 0d095db

Please sign in to comment.