From 75b5f36f091a3b820ace50dfe703abede38ccb9d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=83=AD=E8=AF=A3?=
Date: Fri, 1 May 2026 22:19:08 +0800
Subject: [PATCH] fix(knowledge): fix general parser oversized chunks breaking
 LightRAG indexing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

naive_merge does not guarantee that output chunks stay within the token
limit: when a single line exceeds chunk_token_num it emits an oversized
chunk, and LightRAG fails with "Chunk token length 3140 exceeds
chunk_token_size 1200".

- nlp.py: add a shared hard_split_by_token_limit hard-split helper
- general.py: add _ensure_chunk_token_limit as a fallback guard
- laws.py: remove the local duplicate and reuse the nlp version (DRY)
---
 .../knowledge/chunking/ragflow_like/nlp.py    |  27 +++
 .../chunking/ragflow_like/parsers/general.py  |  21 +-
 .../chunking/ragflow_like/parsers/laws.py     |  28 +--
 .../test/unit/test_chunking_token_limit.py    | 185 ++++++++++++++++++
 4 files changed, 233 insertions(+), 28 deletions(-)
 create mode 100644 backend/test/unit/test_chunking_token_limit.py

diff --git a/backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py b/backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py
index fb93a6ecc..58b0d3f35 100644
--- a/backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py
+++ b/backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py
@@ -55,6 +55,33 @@ def count_tokens(text: str) -> int:
     return max(1, len(parts)) if text.strip() else 0
 
 
+def hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
+    """Hard-split text at the token limit; safety net applied after naive_merge."""
+    token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text or ""))
+    if not token_iter:
+        cleaned = (text or "").strip()
+        return [cleaned] if cleaned else []
+
+    chunks: list[str] = []
+    start = 0
+    index = 0
+    max_tokens = max(int(chunk_token_num or 0), 1)
+
+    while index < len(token_iter):
+        end_index = min(index + max_tokens, len(token_iter)) - 1
+        end = token_iter[end_index].end()
+        piece = text[start:end].strip()
+        if piece:
+            chunks.append(piece)
+        start = end
+        index = end_index + 1
+
+    tail = text[start:].strip()
+    if tail:
+        chunks.append(tail)
+    return chunks
+
+
 def random_choices(arr: list[str], k: int) -> list[str]:
     if not arr:
         return []
diff --git a/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/general.py b/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/general.py
index 635450854..2ef9e621b 100644
--- a/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/general.py
+++ b/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/general.py
@@ -30,6 +30,24 @@ def _iter_sections(markdown_content: str, delimiter: str) -> list[tuple[str, str
     return sections
 
 
+def _ensure_chunk_token_limit(chunks: list[str], chunk_token_num: int) -> list[str]:
+    """Cap output chunks at the token limit: hard-split any that run over."""
+    max_tokens = int(chunk_token_num or 0)
+    if max_tokens <= 0:
+        return [c.strip() for c in chunks if c and c.strip()]
+
+    protected: list[str] = []
+    for chunk in chunks:
+        cleaned = (chunk or "").strip()
+        if not cleaned:
+            continue
+        if nlp.count_tokens(cleaned) <= max_tokens:
+            protected.append(cleaned)
+        else:
+            protected.extend(nlp.hard_split_by_token_limit(cleaned, max_tokens))
+    return protected
+
+
 def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None = None) -> list[str]:
     parser_config = parser_config or {}
 
@@ -38,9 +56,10 @@ def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None =
     overlapped_percent = int(parser_config.get("overlapped_percent", 0) or 0)
 
     sections = _iter_sections(markdown_content, delimiter)
-    return nlp.naive_merge(
+    chunks = nlp.naive_merge(
         sections,
         chunk_token_num=chunk_token_num,
         delimiter=delimiter,
         overlapped_percent=overlapped_percent,
     )
+    return _ensure_chunk_token_limit(chunks, chunk_token_num)
diff --git a/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/laws.py b/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/laws.py
index 25b1aaada..882f56bb4 100644
--- a/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/laws.py
+++ b/backend/package/yuxi/knowledge/chunking/ragflow_like/parsers/laws.py
@@ -84,32 +84,6 @@ def _docx_heading_tree(markdown_content: str) -> list[str]:
     return [element for element in root.get_tree() if element]
 
 
-def _hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
-    token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text or ""))
-    if not token_iter:
-        cleaned = (text or "").strip()
-        return [cleaned] if cleaned else []
-
-    chunks: list[str] = []
-    start = 0
-    index = 0
-    max_tokens = max(int(chunk_token_num or 0), 1)
-
-    while index < len(token_iter):
-        end_index = min(index + max_tokens, len(token_iter)) - 1
-        end = token_iter[end_index].end()
-        piece = text[start:end].strip()
-        if piece:
-            chunks.append(piece)
-        start = end
-        index = end_index + 1
-
-    tail = text[start:].strip()
-    if tail:
-        chunks.append(tail)
-    return chunks
-
-
 def _ensure_chunk_token_limit(
     chunks: list[str], chunk_token_num: int, delimiter: str, overlapped_percent: int
 ) -> list[str]:
@@ -161,7 +135,7 @@ def _ensure_chunk_token_limit(
         if nlp.count_tokens(text) <= max_tokens:
             protected.append(text)
         else:
-            protected.extend(_hard_split_by_token_limit(text, max_tokens))
+            protected.extend(nlp.hard_split_by_token_limit(text, max_tokens))
     return [chunk for chunk in protected if chunk.strip()]
 
 
diff --git a/backend/test/unit/test_chunking_token_limit.py b/backend/test/unit/test_chunking_token_limit.py
new file mode 100644
index 000000000..de700e6d5
--- /dev/null
+++ b/backend/test/unit/test_chunking_token_limit.py
@@ -0,0 +1,185 @@
+"""Chunk token-limit guard tests: the general parser's hard split of oversized chunks.
+
+nlp.py / general.py depend only on re and the stdlib; sys.modules stubs bypass
+the heavy yuxi dependency chain (langchain / pydantic / .env config, etc.) so
+these stay pure unit tests. sys.modules is restored afterwards to avoid polluting other tests.
+"""
+
+import importlib.util
+import sys
+import types
+from pathlib import Path
+
+import pytest
+
+_PKG = Path(__file__).resolve().parents[2] / "package"
+
+_STUB_NAMES = [
+    "yuxi",
+    "yuxi.knowledge",
+    "yuxi.knowledge.chunking",
+    "yuxi.knowledge.chunking.ragflow_like",
+    "yuxi.knowledge.chunking.ragflow_like.parsers",
+    "yuxi.knowledge.chunking.ragflow_like.nlp",
+    "yuxi.knowledge.chunking.ragflow_like.parsers.general",
+]
+
+# Injected at runtime by the _isolated_modules fixture
+nlp = None  # type: ignore[assignment]
+general = None  # type: ignore[assignment]
+
+
+@pytest.fixture(autouse=True, scope="module")
+def _isolated_modules():
+    """Load nlp/general at module scope; clean up sys.modules afterwards to avoid polluting other tests."""
+    saved = {name: sys.modules.get(name) for name in _STUB_NAMES}
+
+    for name in _STUB_NAMES[:5]:
+        sys.modules.setdefault(name, types.ModuleType(name))
+
+    def _load(name: str, rel: str):
+        spec = importlib.util.spec_from_file_location(name, _PKG / rel)
+        mod = importlib.util.module_from_spec(spec)
+        sys.modules[name] = mod
+        spec.loader.exec_module(mod)  # type: ignore[union-attr]
+        return mod
+
+    _nlp = _load(
+        "yuxi.knowledge.chunking.ragflow_like.nlp",
+        "yuxi/knowledge/chunking/ragflow_like/nlp.py",
+    )
+    sys.modules["yuxi.knowledge.chunking.ragflow_like"].nlp = _nlp  # type: ignore[attr-defined]
+
+    _general = _load(
+        "yuxi.knowledge.chunking.ragflow_like.parsers.general",
+        "yuxi/knowledge/chunking/ragflow_like/parsers/general.py",
+    )
+
+    # Expose module-level names for the test cases
+    global nlp, general  # noqa: PLW0603
+    nlp = _nlp
+    general = _general
+
+    yield
+
+    # Cleanup: restore the original state
+    for name in _STUB_NAMES:
+        if saved[name] is None:
+            sys.modules.pop(name, None)
+        else:
+            sys.modules[name] = saved[name]
+
+
+# ── nlp.hard_split_by_token_limit ──────────────────────────────────
+
+
+class TestHardSplitByTokenLimit:
+    def test_short_text_unchanged(self):
+        text = "这是一段短文本"
+        result = nlp.hard_split_by_token_limit(text, 512)
+        assert result == [text]
+
+    def test_splits_long_chinese_text(self):
+        text = "测试内容" * 300  # ~600 CJK tokens
+        result = nlp.hard_split_by_token_limit(text, 512)
+        assert len(result) > 1
+        for chunk in result:
+            assert nlp.count_tokens(chunk) <= 512
+
+    def test_splits_long_english_text(self):
+        text = "hello world " * 1000  # ~2000 word tokens
+        result = nlp.hard_split_by_token_limit(text, 512)
+        assert len(result) > 1
+        for chunk in result:
+            assert nlp.count_tokens(chunk) <= 512
+
+    def test_empty_text_returns_empty(self):
+        assert nlp.hard_split_by_token_limit("", 512) == []
+
+    def test_whitespace_only_returns_empty(self):
+        assert nlp.hard_split_by_token_limit(" \n\t ", 512) == []
+
+    def test_zero_limit_floors_to_one(self):
+        text = "a b c"  # 3 separate word tokens
+        result = nlp.hard_split_by_token_limit(text, 0)
+        # max_tokens = max(0, 1) = 1, so each token becomes its own chunk
+        assert len(result) == 3
+
+    def test_punctuation_only_text(self):
+        text = ",。!?"
+        result = nlp.hard_split_by_token_limit(text, 512)
+        assert result == [",。!?"]
+
+
+# ── general._ensure_chunk_token_limit ──────────────────────────────
+
+
+class TestEnsureChunkTokenLimit:
+    def test_all_chunks_within_limit_pass_through(self):
+        chunks = ["短文本一", "短文本二", "短文本三"]
+        result = general._ensure_chunk_token_limit(chunks, 512)
+        assert result == ["短文本一", "短文本二", "短文本三"]
+
+    def test_oversized_chunk_gets_split(self):
+        long_text = "内容" * 300  # ~600 CJK tokens
+        chunks = ["短文本", long_text, "短文本二"]
+        result = general._ensure_chunk_token_limit(chunks, 512)
+        assert result[0] == "短文本"
+        assert result[-1] == "短文本二"
+        middle_chunks = result[1:-1]
+        assert len(middle_chunks) > 1
+        for chunk in middle_chunks:
+            assert nlp.count_tokens(chunk) <= 512
+
+    def test_empty_chunks_filtered(self):
+        chunks = ["有效文本", "", " ", "另一段"]
+        result = general._ensure_chunk_token_limit(chunks, 512)
+        assert result == ["有效文本", "另一段"]
+
+    def test_zero_limit_returns_stripped(self):
+        chunks = [" 文本一 ", "文本二"]
+        result = general._ensure_chunk_token_limit(chunks, 0)
+        assert result == ["文本一", "文本二"]
+
+
+# ── general.chunk_markdown integration ─────────────────────────────
+
+
+class TestGeneralChunkMarkdown:
+    def test_normal_document_chunks_within_limit(self):
+        doc = "# 标题\n\n第一段内容\n\n第二段内容\n\n第三段内容"
+        chunks = general.chunk_markdown(doc, {"chunk_token_num": 512})
+        assert len(chunks) > 0
+        for chunk in chunks:
+            assert nlp.count_tokens(chunk) <= 512
+
+    def test_oversized_single_line_gets_split(self):
+        long_line = "运维知识" * 800  # ~3200 CJK tokens
+        doc = f"# 运维知识库\n\n{long_line}"
+        chunks = general.chunk_markdown(doc, {"chunk_token_num": 512})
+        assert len(chunks) > 1
+        for chunk in chunks:
+            assert nlp.count_tokens(chunk) <= 512
+
+    def test_empty_document_returns_empty(self):
+        assert general.chunk_markdown("", {"chunk_token_num": 512}) == []
+
+    def test_default_config_uses_512(self):
+        doc = "测试\n" * 200
+        chunks = general.chunk_markdown(doc)
+        for chunk in chunks:
+            assert nlp.count_tokens(chunk) <= 512
+
+
+# ── laws parser regression ─────────────────────────────────────────
+
+
+class TestLawsParserRegression:
+    """Verify that nlp.hard_split_by_token_limit can be called by the laws parser as before."""
+
+    def test_hard_split_produces_same_result(self):
+        text = "法规内容" * 300
+        result = nlp.hard_split_by_token_limit(text, 512)
+        assert len(result) > 1
+        for chunk in result:
+            assert nlp.count_tokens(chunk) <= 512
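
Reviewer note (not part of the commit): a minimal usage sketch of the behaviour this patch is meant to guarantee, assuming yuxi is importable (e.g. backend/package on PYTHONPATH); the 512 limit and the sample text are illustrative only.

    # Illustrative sketch only -- mirrors the regression covered by the tests above.
    from yuxi.knowledge.chunking.ragflow_like import nlp
    from yuxi.knowledge.chunking.ragflow_like.parsers import general

    # Before this patch, a single line of ~3200 CJK tokens passed through
    # naive_merge untouched and LightRAG rejected the oversized chunk.
    oversized_line = "运维知识" * 800
    chunks = general.chunk_markdown(f"# 知识库\n\n{oversized_line}", {"chunk_token_num": 512})
    assert all(nlp.count_tokens(c) <= 512 for c in chunks)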