Skip to content

Commit 273cf1b

Browse files
authored
fix: last_code_generated retrieval logic in multi-turn conversations (#1784)
* Fix last_code_generated retrieval logic in multi-turn conversations * refine comments
1 parent fcb78a9 commit 273cf1b

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

pandasai/core/code_generation/base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,16 @@ def generate_code(self, prompt: BasePrompt) -> str:
3131

3232
# Generate the code
3333
code = self._context.config.llm.generate_code(prompt, self._context)
34+
# Store the original generated code (for logging purposes)
3435
self._context.last_code_generated = code
3536
self._context.logger.log(f"Code Generated:\n{code}")
3637

37-
return self.validate_and_clean_code(code)
38+
# Validate and clean the code
39+
cleaned_code = self.validate_and_clean_code(code)
40+
# Update with the final cleaned code (for subsequent processing and multi-turn conversations)
41+
self._context.last_code_generated = cleaned_code
42+
43+
return cleaned_code
3844

3945
except Exception as e:
4046
error_message = f"An error occurred during code generation: {e}"

pandasai/core/prompts/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
def get_chat_prompt_for_sql(context: AgentState) -> BasePrompt:
2020
return GeneratePythonCodeWithSQLPrompt(
2121
context=context,
22-
last_code_generated=context.get("last_code_generated"),
22+
last_code_generated=context.last_code_generated,
2323
output_type=context.output_type,
2424
)
2525

tests/unit_tests/agent/test_agent.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,25 @@ def test_handle_exception(self, agent):
542542
# Verify the error was logged
543543
mock_logger.log.assert_called_once()
544544
assert "Processing failed with error" in mock_logger.log.call_args[0][0]
545+
546+
def test_last_code_generated_retrieval(self, agent: Agent):
547+
"""Test that last_code_generated is correctly retrieved in get_chat_prompt_for_sql."""
548+
# Set last_code_generated
549+
test_code = "print('Test code')"
550+
agent._state.last_code_generated = test_code
551+
552+
# 使用 get_chat_prompt_for_sql 获取提示
553+
from pandasai.core.prompts import get_chat_prompt_for_sql
554+
555+
prompt = get_chat_prompt_for_sql(agent._state)
556+
557+
# 验证提示中使用了正确的 last_code_generated
558+
assert prompt.props["last_code_generated"] == test_code
559+
560+
# 验证不是从 intermediate_values 中获取的
561+
agent._state.add("last_code_generated", "Wrong code")
562+
prompt = get_chat_prompt_for_sql(agent._state)
563+
564+
# 应该仍然使用 last_code_generated 属性,而不是 intermediate_values 中的值
565+
assert prompt.props["last_code_generated"] == test_code
566+
assert prompt.props["last_code_generated"] != "Wrong code"

0 commit comments

Comments
 (0)