diff --git a/python-ecosystem/inference-orchestrator/src/service/command/command_service.py b/python-ecosystem/inference-orchestrator/src/service/command/command_service.py index 9d81a5e2..f9d60f8c 100644 --- a/python-ecosystem/inference-orchestrator/src/service/command/command_service.py +++ b/python-ecosystem/inference-orchestrator/src/service/command/command_service.py @@ -717,6 +717,11 @@ async def _execute_summarize( # Intermediate text output final_result = item + else: + extracted = self._extract_agent_item_text(item) + if extracted is not None: + final_result = extracted + self._emit_event(event_callback, { "type": "progress", "step": self.MAX_STEPS_SUMMARIZE, @@ -724,49 +729,130 @@ async def _execute_summarize( "message": f"Summarization completed ({step_count} tool calls)" }) - # Process the result - if isinstance(final_result, SummarizeOutput): - logger.info("Successfully received structured summarize output") - return { - "summary": final_result.summary, - "diagram": final_result.diagram, - "diagramType": final_result.diagramType - } - elif isinstance(final_result, str) and final_result: - # Fallback: parse string result - logger.debug(f"Summarize raw result (first 500 chars): {final_result[:500] if final_result else 'None'}") - parsed = self._parse_json_response(final_result) - if parsed: - logger.info("Successfully parsed JSON response for summarize") - return { - "summary": parsed.get("summary", ""), - "diagram": parsed.get("diagram", ""), - "diagramType": parsed.get("diagramType", "MERMAID" if supports_mermaid else "ASCII") - } - else: - # Try regex extraction as last resort - extracted = self._extract_summary_field_fallback(final_result) - if extracted: - logger.warning("Used regex fallback to extract summary field") - return { - "summary": extracted, - "diagram": "", - "diagramType": "MERMAID" if supports_mermaid else "ASCII" - } - - logger.warning(f"JSON parsing failed for summarize, using raw result") - return { - "summary": final_result or "Failed to generate summary", - "diagram": "", - "diagramType": "MERMAID" if supports_mermaid else "ASCII" - } - else: - return {"error": "AI service returned an empty summary"} + result = self._coerce_summarize_final_result(final_result, supports_mermaid) + if "error" not in result: + return result + + logger.warning("Summarize streaming produced an empty final summary; retrying without output_schema") + self._emit_event(event_callback, { + "type": "status", + "state": "retrying", + "message": "Retrying summary generation" + }) + + raw_result = await self._run_agent_with_heartbeat( + agent=agent, + prompt=prompt, + event_callback=event_callback, + max_steps=self.MAX_STEPS_SUMMARIZE + ) + result = self._coerce_summarize_final_result(raw_result, supports_mermaid) + if "error" not in result: + return result + + logger.warning("Summarize agent retry also produced an empty summary; trying direct LLM fallback") + direct_response = await llm.ainvoke( + prompt + + "\n\nIf tool calls are unavailable, summarize from the context already provided. " + "Return a JSON object with non-empty 'summary', 'diagram', and 'diagramType' fields." + ) + return self._coerce_summarize_final_result(direct_response, supports_mermaid) except Exception as e: - logger.error(f"Summarize agent error: {e}", exc_info=True) - sanitized_msg = create_user_friendly_error(e) - return {"error": sanitized_msg} + logger.warning(f"Summarize streaming failed, retrying without output_schema: {e}", exc_info=True) + self._emit_event(event_callback, { + "type": "status", + "state": "retrying", + "message": "Retrying summary generation" + }) + try: + raw_result = await self._run_agent_with_heartbeat( + agent=agent, + prompt=prompt, + event_callback=event_callback, + max_steps=self.MAX_STEPS_SUMMARIZE + ) + result = self._coerce_summarize_final_result(raw_result, supports_mermaid) + if "error" not in result: + return result + + direct_response = await llm.ainvoke( + prompt + + "\n\nIf tool calls are unavailable, summarize from the context already provided. " + "Return a JSON object with non-empty 'summary', 'diagram', and 'diagramType' fields." + ) + return self._coerce_summarize_final_result(direct_response, supports_mermaid) + except Exception as retry_error: + logger.error(f"Summarize agent error: {retry_error}", exc_info=True) + sanitized_msg = create_user_friendly_error(retry_error) + return {"error": sanitized_msg} + + def _coerce_summarize_final_result( + self, + final_result: Any, + supports_mermaid: bool + ) -> Dict[str, Any]: + """Convert structured, dict, message, or text agent output into a summary dict.""" + diagram_type = "MERMAID" if supports_mermaid else "ASCII" + + if isinstance(final_result, SummarizeOutput): + logger.info("Successfully received structured summarize output") + return self._summary_or_empty_error( + summary=final_result.summary, + diagram=final_result.diagram, + diagram_type=final_result.diagramType or diagram_type, + ) + + if isinstance(final_result, dict) and "summary" in final_result: + return self._summary_or_empty_error( + summary=final_result.get("summary"), + diagram=final_result.get("diagram"), + diagram_type=final_result.get("diagramType") or diagram_type, + ) + + text = self._extract_agent_item_text(final_result) + if not self._has_usable_text(text): + return {"error": "AI service returned an empty summary"} + + logger.debug(f"Summarize raw result (first 500 chars): {str(text)[:500] if text else 'None'}") + parsed = self._parse_json_response(str(text)) + if parsed: + logger.info("Successfully parsed JSON response for summarize") + return self._summary_or_empty_error( + summary=parsed.get("summary"), + diagram=parsed.get("diagram"), + diagram_type=parsed.get("diagramType") or diagram_type, + ) + + extracted = self._extract_summary_field_fallback(str(text)) + if extracted: + logger.warning("Used regex fallback to extract summary field") + return self._summary_or_empty_error( + summary=extracted, + diagram="", + diagram_type=diagram_type, + ) + + logger.warning("JSON parsing failed for summarize, using raw result") + return self._summary_or_empty_error( + summary=str(text), + diagram="", + diagram_type=diagram_type, + ) + + def _summary_or_empty_error( + self, + summary: Any, + diagram: Any, + diagram_type: Any + ) -> Dict[str, Any]: + if not self._has_usable_text(summary): + return {"error": "AI service returned an empty summary"} + return { + "summary": str(summary), + "diagram": self._string_or_empty(diagram), + "diagramType": str(diagram_type or "ASCII"), + } def _extract_summary_field_fallback(self, text: str) -> Optional[str]: """ @@ -858,6 +944,11 @@ async def _execute_ask( # Intermediate text output final_result = item + else: + extracted = self._extract_agent_item_text(item) + if extracted is not None: + final_result = extracted + self._emit_event(event_callback, { "type": "progress", "step": self.MAX_STEPS_ASK, @@ -865,25 +956,119 @@ async def _execute_ask( "message": f"Completed ({step_count} tool calls)" }) - # Process the result - if isinstance(final_result, AskOutput): - logger.info("Successfully received structured ask output") - return {"answer": final_result.answer} - elif isinstance(final_result, str) and final_result: - # Fallback: parse string result - parsed = self._parse_json_response(final_result) - if parsed and "answer" in parsed: - return {"answer": parsed["answer"]} - else: - return {"answer": final_result} - else: - return {"error": "AI service returned an empty answer"} + result = self._coerce_ask_final_result(final_result) + if "error" not in result: + return result + + logger.warning("Ask streaming produced an empty final answer; retrying without output_schema") + self._emit_event(event_callback, { + "type": "status", + "state": "retrying", + "message": "Retrying answer generation" + }) + + raw_result = await self._run_agent_with_heartbeat( + agent=agent, + prompt=prompt, + event_callback=event_callback, + max_steps=self.MAX_STEPS_ASK + ) + result = self._coerce_ask_final_result(raw_result) + if "error" not in result: + return result + + logger.warning("Ask agent retry also produced an empty answer; trying direct LLM fallback") + direct_response = await llm.ainvoke( + prompt + + "\n\nIf tool calls are unavailable, answer from the context already provided. " + "Return a JSON object with a non-empty 'answer' field." + ) + return self._coerce_ask_final_result(direct_response) except Exception as e: logger.error(f"Ask agent error: {e}", exc_info=True) sanitized_msg = create_user_friendly_error(e) return {"error": sanitized_msg} + def _coerce_ask_final_result(self, final_result: Any) -> Dict[str, Any]: + """Convert structured, dict, message, or text agent output into an answer dict.""" + if isinstance(final_result, AskOutput): + logger.info("Successfully received structured ask output") + return self._answer_or_empty_error(final_result.answer) + + if isinstance(final_result, dict) and "answer" in final_result: + return self._answer_or_empty_error(final_result.get("answer")) + + text = self._extract_agent_item_text(final_result) + if not self._has_usable_text(text): + return {"error": "AI service returned an empty answer"} + + parsed = self._parse_json_response(str(text)) + if parsed and "answer" in parsed: + return self._answer_or_empty_error(parsed.get("answer")) + + return {"answer": str(text)} + + def _answer_or_empty_error(self, answer: Any) -> Dict[str, Any]: + if not self._has_usable_text(answer): + return {"error": "AI service returned an empty answer"} + return {"answer": str(answer)} + + def _extract_agent_item_text(self, item: Any) -> Optional[str]: + """Extract final text from common LangChain/mcp_use stream item shapes.""" + if item is None: + return None + + if isinstance(item, str): + return item + + if isinstance(item, dict): + for key in ("answer", "output", "final_output", "response", "result", "content", "text"): + if key in item: + return self._extract_agent_item_text(item.get(key)) + + messages = item.get("messages") + if isinstance(messages, list) and messages: + return self._extract_agent_item_text(messages[-1]) + + return None + + if hasattr(item, "content"): + return self._coerce_text_content(getattr(item, "content")) + + if hasattr(item, "model_dump"): + try: + dumped = item.model_dump() + if isinstance(dumped, dict): + return self._extract_agent_item_text(dumped) + except Exception: + return None + + return None + + def _coerce_text_content(self, content: Any) -> str: + """Convert provider content blocks to plain text.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict): + text = block.get("text") or block.get("content") + if text is not None: + parts.append(str(text)) + elif hasattr(block, "text"): + parts.append(str(block.text)) + return "".join(parts) + if isinstance(content, dict): + text = content.get("text") or content.get("content") + return "" if text is None else str(text) + return str(content) + async def _run_agent_with_heartbeat( self, agent: MCPAgent, diff --git a/python-ecosystem/inference-orchestrator/tests/test_command_service.py b/python-ecosystem/inference-orchestrator/tests/test_command_service.py index f98397da..153e2224 100644 --- a/python-ecosystem/inference-orchestrator/tests/test_command_service.py +++ b/python-ecosystem/inference-orchestrator/tests/test_command_service.py @@ -7,7 +7,7 @@ """ import pytest import json -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch from service.command.command_service import CommandService @@ -335,6 +335,106 @@ def test_defaults_missing_diagram_fields(self, service): } +class TestExecuteSummarize: + class FakeAgent: + def __init__(self, stream_items=None, run_result=None, stream_error=None): + self.stream_items = stream_items or [] + self.run_result = run_result + self.stream_error = stream_error + self.run_called = False + + async def stream(self, *_args, **_kwargs): + if self.stream_error: + raise self.stream_error + for item in self.stream_items: + yield item + + async def run(self, *_args, **_kwargs): + self.run_called = True + return self.run_result + + @pytest.mark.asyncio(loop_scope="function") + async def test_extracts_dict_stream_summary(self, service): + message = MagicMock() + message.content = '{"summary": "PR summary", "diagram": "", "diagramType": "ASCII"}' + agent = self.FakeAgent(stream_items=[{"messages": [message]}]) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_summarize( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + supports_mermaid=False, + event_callback=None, + ) + + assert result == { + "summary": "PR summary", + "diagram": "", + "diagramType": "ASCII", + } + assert agent.run_called is False + + @pytest.mark.asyncio(loop_scope="function") + async def test_retries_agent_run_when_stream_summary_is_empty(self, service): + agent = self.FakeAgent( + stream_items=[{"summary": ""}], + run_result='{"summary": "Fallback summary", "diagram": "", "diagramType": "ASCII"}', + ) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_summarize( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + supports_mermaid=False, + event_callback=None, + ) + + assert result["summary"] == "Fallback summary" + assert agent.run_called is True + + @pytest.mark.asyncio(loop_scope="function") + async def test_retries_agent_run_when_stream_raises_provider_error(self, service): + agent = self.FakeAgent( + stream_error=Exception("The AI provider rejected the request"), + run_result='{"summary": "Fallback summary", "diagram": "", "diagramType": "ASCII"}', + ) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_summarize( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + supports_mermaid=False, + event_callback=None, + ) + + assert result["summary"] == "Fallback summary" + assert agent.run_called is True + + @pytest.mark.asyncio(loop_scope="function") + async def test_uses_direct_llm_when_agent_outputs_empty_summary_sentinels(self, service): + agent = self.FakeAgent(stream_items=["null"], run_result="No output generated") + response = MagicMock() + response.content = '{"summary": "Direct fallback summary.", "diagram": "", "diagramType": "ASCII"}' + llm = MagicMock() + llm.ainvoke = AsyncMock(return_value=response) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_summarize( + llm=llm, + client=MagicMock(), + prompt="prompt", + supports_mermaid=False, + event_callback=None, + ) + + assert result["summary"] == "Direct fallback summary." + assert agent.run_called is True + llm.ainvoke.assert_awaited_once() + + class TestNormalizeAskResult: def test_preserves_provider_error(self, service): result = service._normalize_ask_result({"error": "provider failed"}) @@ -354,6 +454,95 @@ def test_accepts_answer(self, service): assert result == {"answer": "The PR updates auth handling."} +class TestExecuteAsk: + class FakeAgent: + def __init__(self, stream_items=None, run_result=None): + self.stream_items = stream_items or [] + self.run_result = run_result + self.run_called = False + + async def stream(self, *_args, **_kwargs): + for item in self.stream_items: + yield item + + async def run(self, *_args, **_kwargs): + self.run_called = True + return self.run_result + + @pytest.mark.asyncio(loop_scope="function") + async def test_extracts_dict_stream_answer(self, service): + message = MagicMock() + message.content = '{"answer": "The PR updates auth handling."}' + agent = self.FakeAgent(stream_items=[{"messages": [message]}]) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_ask( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + event_callback=None, + ) + + assert result == {"answer": "The PR updates auth handling."} + assert agent.run_called is False + + @pytest.mark.asyncio(loop_scope="function") + async def test_retries_agent_run_when_stream_is_empty(self, service): + agent = self.FakeAgent( + stream_items=[], + run_result='{"answer": "Fallback answer from non-structured run."}', + ) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_ask( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + event_callback=None, + ) + + assert result == {"answer": "Fallback answer from non-structured run."} + assert agent.run_called is True + + @pytest.mark.asyncio(loop_scope="function") + async def test_retries_agent_run_when_stream_answer_is_empty(self, service): + agent = self.FakeAgent( + stream_items=[{"answer": ""}], + run_result='{"answer": "Fallback answer after empty structured output."}', + ) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_ask( + llm=MagicMock(), + client=MagicMock(), + prompt="prompt", + event_callback=None, + ) + + assert result == {"answer": "Fallback answer after empty structured output."} + assert agent.run_called is True + + @pytest.mark.asyncio(loop_scope="function") + async def test_uses_direct_llm_when_agent_outputs_empty_sentinels(self, service): + agent = self.FakeAgent(stream_items=["null"], run_result="No output generated") + response = MagicMock() + response.content = '{"answer": "Direct fallback answer."}' + llm = MagicMock() + llm.ainvoke = AsyncMock(return_value=response) + + with patch("service.command.command_service.MCPAgent", return_value=agent): + result = await service._execute_ask( + llm=llm, + client=MagicMock(), + prompt="prompt", + event_callback=None, + ) + + assert result == {"answer": "Direct fallback answer."} + assert agent.run_called is True + llm.ainvoke.assert_awaited_once() + + # ── _create_mcp_client ─────────────────────────────────────────── class TestCreateMcpClient: