Skip to content

Commit 66423b8

Browse files
authored
fix: resolve TypeError when LLM returns invalid JSON across all retries (#488) (#490)
- Rename the shadowed `max_retries` variable to `llm_max_retries` and move the config
  resolution outside the retry loop. The old code captured `range(2)` as the loop bound
  but overwrote `max_retries` inside the loop, so the retry comparisons used a different
  value than the loop bound — causing `continue` on the final iteration, exhausting the
  loop, and reaching `raise last_error` while `last_error` was still None → TypeError.
- Add a fallback `raise RuntimeError(...)` after the retry loop so that if `last_error`
  is None a descriptive error is raised instead of None.
- Add unit tests covering non-dict JSON responses with various retry counts.
1 parent abbf874 commit 66423b8

File tree

2 files changed

+159
-17
lines changed

2 files changed

+159
-17
lines changed

hindsight-api/hindsight_api/engine/retain/fact_extraction.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -944,16 +944,15 @@ async def _extract_facts_from_chunk(
944944
user_message = _build_user_message(chunk, chunk_index, total_chunks, event_date, context, metadata)
945945

946946
# Retry logic for JSON validation errors
947-
max_retries = 2
948-
last_error = None
947+
# Use retain-specific overrides if set, otherwise fall back to global LLM config
948+
llm_max_retries = (
949+
config.retain_llm_max_retries if config.retain_llm_max_retries is not None else config.llm_max_retries
950+
)
951+
last_error: Exception | None = None
949952

950953
usage = TokenUsage() # Track cumulative usage across retries
951-
for attempt in range(max_retries):
954+
for attempt in range(llm_max_retries):
952955
try:
953-
# Use retain-specific overrides if set, otherwise fall back to global LLM config
954-
max_retries = (
955-
config.retain_llm_max_retries if config.retain_llm_max_retries is not None else config.llm_max_retries
956-
)
957956
initial_backoff = (
958957
config.retain_llm_initial_backoff
959958
if config.retain_llm_initial_backoff is not None
@@ -969,7 +968,7 @@ async def _extract_facts_from_chunk(
969968
scope="retain_extract_facts",
970969
temperature=0.1,
971970
max_completion_tokens=config.retain_max_completion_tokens,
972-
max_retries=max_retries,
971+
max_retries=llm_max_retries,
973972
initial_backoff=initial_backoff,
974973
max_backoff=max_backoff,
975974
skip_validation=True, # Get raw JSON, we'll validate leniently
@@ -983,14 +982,14 @@ async def _extract_facts_from_chunk(
983982

984983
# Handle malformed LLM responses
985984
if not isinstance(extraction_response_json, dict):
986-
if attempt < max_retries - 1:
985+
if attempt < llm_max_retries - 1:
987986
logger.warning(
988-
f"LLM returned non-dict JSON on attempt {attempt + 1}/{max_retries}: {type(extraction_response_json).__name__}. Retrying..."
987+
f"LLM returned non-dict JSON on attempt {attempt + 1}/{llm_max_retries}: {type(extraction_response_json).__name__}. Retrying..."
989988
)
990989
continue
991990
else:
992991
logger.warning(
993-
f"LLM returned non-dict JSON after {max_retries} attempts: {type(extraction_response_json).__name__}. "
992+
f"LLM returned non-dict JSON after {llm_max_retries} attempts: {type(extraction_response_json).__name__}. "
994993
f"Raw: {str(extraction_response_json)[:500]}"
995994
)
996995
return [], usage
@@ -1206,9 +1205,9 @@ def get_value(field_name):
12061205
continue
12071206

12081207
# If we got malformed facts and haven't exhausted retries, try again
1209-
if has_malformed_facts and len(chunk_facts) < len(raw_facts) * 0.8 and attempt < max_retries - 1:
1208+
if has_malformed_facts and len(chunk_facts) < len(raw_facts) * 0.8 and attempt < llm_max_retries - 1:
12101209
logger.warning(
1211-
f"Got {len(raw_facts) - len(chunk_facts)} malformed facts out of {len(raw_facts)} on attempt {attempt + 1}/{max_retries}. Retrying..."
1210+
f"Got {len(raw_facts) - len(chunk_facts)} malformed facts out of {len(raw_facts)} on attempt {attempt + 1}/{llm_max_retries}. Retrying..."
12121211
)
12131212
continue
12141213

@@ -1241,16 +1240,18 @@ def get_value(field_name):
12411240

12421241
if "json_validate_failed" in str(e):
12431242
logger.warning(
1244-
f" [1.3.{chunk_index + 1}] Attempt {attempt + 1}/{max_retries} failed with JSON validation error: {e}"
1243+
f" [1.3.{chunk_index + 1}] Attempt {attempt + 1}/{llm_max_retries} failed with JSON validation error: {e}"
12451244
)
1246-
if attempt < max_retries - 1:
1245+
if attempt < llm_max_retries - 1:
12471246
logger.info(f" [1.3.{chunk_index + 1}] Retrying...")
12481247
continue
12491248
# If it's not a JSON validation error or we're out of retries, re-raise
12501249
raise
12511250

1252-
# If we exhausted all retries, raise the last error
1253-
raise last_error
1251+
# If we exhausted all retries, raise the last error or a descriptive fallback
1252+
if last_error is not None:
1253+
raise last_error
1254+
raise RuntimeError(f"Fact extraction failed after {llm_max_retries} attempts: LLM did not return valid JSON")
12541255

12551256

12561257
async def _extract_facts_with_auto_split(
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
"""
2+
Unit tests for fact extraction retry logic.
3+
4+
Tests the fix for the TypeError when LLM returns invalid JSON across all retries.
5+
Previously, `raise last_error` would raise None (TypeError) because last_error was
6+
only set in the BadRequestError handler, not when the LLM returned non-dict JSON.
7+
"""
8+
9+
from datetime import datetime, timezone
10+
from unittest.mock import AsyncMock, MagicMock, patch
11+
12+
import pytest
13+
14+
15+
def _make_config(llm_max_retries: int = 3, retain_llm_max_retries: int | None = None):
    """Create a mocked HindsightConfig exposing only the fields fact extraction reads."""
    from hindsight_api.config import HindsightConfig

    config = MagicMock(spec=HindsightConfig)
    # Every config attribute _extract_facts_from_chunk touches, with inert values.
    attrs = {
        "retain_llm_max_retries": retain_llm_max_retries,
        "llm_max_retries": llm_max_retries,
        "retain_llm_initial_backoff": None,
        "llm_initial_backoff": 0.0,
        "retain_llm_max_backoff": None,
        "llm_max_backoff": 0.0,
        "retain_max_completion_tokens": 8192,
        "retain_extraction_mode": "concise",
        "retain_extract_causal_links": False,
        "retain_mission": None,
    }
    for name, value in attrs.items():
        setattr(config, name, value)
    return config
31+
32+
33+
def _make_llm_config(mock_response):
    """Create a mocked LLMProvider whose async `call` always yields *mock_response*."""
    from hindsight_api.engine.llm_wrapper import LLMProvider

    # Cumulative usage accumulation (usage + other) becomes a no-op returning itself.
    usage = MagicMock()
    usage.__add__ = lambda self, other: self

    provider = MagicMock(spec=LLMProvider)
    provider.provider = "mock"
    provider.call = AsyncMock(return_value=(mock_response, usage))
    return provider
43+
44+
45+
@pytest.mark.asyncio
async def test_non_dict_json_all_retries_returns_empty():
    """Non-dict JSON on every attempt must yield [] rather than a TypeError.

    Regression test for the shadowed-retry-count bug: the loop bound was a
    hardcoded range(2) while the comparisons used config.llm_max_retries
    (default 10), so on the last iteration (attempt=1) `attempt < 10 - 1` was
    True, the code hit `continue`, the loop exhausted, and `raise last_error`
    raised None ('exceptions must derive from BaseException').
    """
    from hindsight_api.engine.retain.fact_extraction import _extract_facts_from_chunk

    # llm_max_retries=3 differs from the old hardcoded 2, so the bug triggers.
    cfg = _make_config(llm_max_retries=3, retain_llm_max_retries=None)

    # The provider always answers with a list, which is not a valid dict payload.
    provider = _make_llm_config(mock_response=[{"invalid": "response"}])

    prompt_patch = patch(
        "hindsight_api.engine.retain.fact_extraction._build_extraction_prompt_and_schema",
        return_value=("system prompt", MagicMock()),
    )
    with prompt_patch:
        facts, usage = await _extract_facts_from_chunk(
            chunk="Alice visited Paris in 2023.",
            chunk_index=0,
            total_chunks=1,
            event_date=datetime(2023, 1, 1, tzinfo=timezone.utc),
            context="travel notes",
            llm_config=provider,
            config=cfg,
            agent_name="test-agent",
        )

    assert facts == []
80+
81+
82+
@pytest.mark.asyncio
async def test_non_dict_json_with_default_max_retries_returns_empty():
    """Same scenario, but with llm_max_retries=10 matching the real default config.

    Under the old code the loop ran range(2) while comparisons checked against
    10, so every iteration continued, the loop exhausted, and None was raised
    as an exception (TypeError).
    """
    from hindsight_api.engine.retain.fact_extraction import _extract_facts_from_chunk

    cfg = _make_config(llm_max_retries=10, retain_llm_max_retries=None)
    # A plain string is another non-dict shape the LLM might return.
    provider = _make_llm_config(mock_response="not a dict at all")

    prompt_patch = patch(
        "hindsight_api.engine.retain.fact_extraction._build_extraction_prompt_and_schema",
        return_value=("system prompt", MagicMock()),
    )
    with prompt_patch:
        facts, usage = await _extract_facts_from_chunk(
            chunk="Some text.",
            chunk_index=0,
            total_chunks=1,
            event_date=datetime(2023, 6, 1, tzinfo=timezone.utc),
            context="",
            llm_config=provider,
            config=cfg,
            agent_name="agent",
        )

    assert facts == []
110+
111+
112+
@pytest.mark.asyncio
async def test_retain_llm_max_retries_overrides_global():
    """A set retain_llm_max_retries drives both the loop range and all comparisons.

    With the shadowing bug fixed, a single resolved value is used everywhere,
    so the LLM is called exactly retain_llm_max_retries times.
    """
    from hindsight_api.engine.retain.fact_extraction import _extract_facts_from_chunk

    # retain_llm_max_retries=5 should take precedence over llm_max_retries=10.
    cfg = _make_config(llm_max_retries=10, retain_llm_max_retries=5)
    # An integer is yet another non-dict response shape.
    provider = _make_llm_config(mock_response=42)

    prompt_patch = patch(
        "hindsight_api.engine.retain.fact_extraction._build_extraction_prompt_and_schema",
        return_value=("system prompt", MagicMock()),
    )
    with prompt_patch:
        facts, usage = await _extract_facts_from_chunk(
            chunk="Bob likes Python.",
            chunk_index=0,
            total_chunks=1,
            event_date=datetime(2024, 1, 1, tzinfo=timezone.utc),
            context="",
            llm_config=provider,
            config=cfg,
            agent_name="agent",
        )

    assert facts == []
    # Exactly one LLM call per attempt, for retain_llm_max_retries attempts.
    assert provider.call.call_count == 5

0 commit comments

Comments
 (0)