27 changes: 27 additions & 0 deletions backend/package/yuxi/knowledge/chunking/ragflow_like/nlp.py
@@ -55,6 +55,33 @@ def count_tokens(text: str) -> int:
return max(1, len(parts)) if text.strip() else 0


def hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
"""将文本按 token 上限硬切,用于 naive_merge 之后的兜底保护。"""
token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[一-鿿]", text or ""))
if not token_iter:
cleaned = (text or "").strip()
return [cleaned] if cleaned else []

chunks: list[str] = []
start = 0
index = 0
max_tokens = max(int(chunk_token_num or 0), 1)

while index < len(token_iter):
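        # Take the last token that fits in this window and cut at its end;
        # any text between it and the next token rolls into the next piece.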
end_index = min(index + max_tokens, len(token_iter)) - 1
end = token_iter[end_index].end()
piece = text[start:end].strip()
if piece:
chunks.append(piece)
start = end
index = end_index + 1

tail = text[start:].strip()
if tail:
chunks.append(tail)
return chunks


def random_choices(arr: list[str], k: int) -> list[str]:
if not arr:
return []
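A quick sanity check of the new nlp helper (illustrative only, not part of the diff; the import path follows the package layout above):

from yuxi.knowledge.chunking.ragflow_like import nlp

text = "hello world " * 1000  # ~2000 word tokens
pieces = nlp.hard_split_by_token_limit(text, 512)

# With this splitting scheme we expect ceil(2000 / 512) == 4 pieces,
# each within the ceiling.
assert len(pieces) == 4
assert all(nlp.count_tokens(p) <= 512 for p in pieces)

The hunks that follow belong to the general markdown parser — the tests further down load it from yuxi/knowledge/chunking/ragflow_like/parsers/general.py.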
@@ -30,6 +30,24 @@ def _iter_sections(markdown_content: str, delimiter: str) -> list[tuple[str, str
return sections


def _ensure_chunk_token_limit(chunks: list[str], chunk_token_num: int) -> list[str]:
"""对输出 chunk 做 token 上限保护:超长的直接硬切。"""
max_tokens = int(chunk_token_num or 0)
if max_tokens <= 0:
return [c.strip() for c in chunks if c and c.strip()]

protected: list[str] = []
for chunk in chunks:
cleaned = (chunk or "").strip()
if not cleaned:
continue
if nlp.count_tokens(cleaned) <= max_tokens:
protected.append(cleaned)
else:
protected.extend(nlp.hard_split_by_token_limit(cleaned, max_tokens))
return protected


def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None = None) -> list[str]:
parser_config = parser_config or {}

@@ -38,9 +56,10 @@ def chunk_markdown(markdown_content: str, parser_config: dict[str, Any] | None =
overlapped_percent = int(parser_config.get("overlapped_percent", 0) or 0)

sections = _iter_sections(markdown_content, delimiter)
return nlp.naive_merge(
chunks = nlp.naive_merge(
sections,
chunk_token_num=chunk_token_num,
delimiter=delimiter,
overlapped_percent=overlapped_percent,
)
return _ensure_chunk_token_limit(chunks, chunk_token_num)
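For illustration, a hedged sketch of exercising the parser's new guarantee end to end (the module path is taken from the test file's load path below; 512 mirrors the default chunk_token_num asserted in the tests):

from yuxi.knowledge.chunking.ragflow_like import nlp
from yuxi.knowledge.chunking.ragflow_like.parsers import general

# One section far above the ceiling: naive_merge alone could pass it
# through oversized; _ensure_chunk_token_limit now hard-splits it.
doc = "# Title\n\n" + "word " * 3000
chunks = general.chunk_markdown(doc, {"chunk_token_num": 512})

assert len(chunks) > 1
assert all(nlp.count_tokens(c) <= 512 for c in chunks)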
Expand Up @@ -84,32 +84,6 @@ def _docx_heading_tree(markdown_content: str) -> list[str]:
return [element for element in root.get_tree() if element]


def _hard_split_by_token_limit(text: str, chunk_token_num: int) -> list[str]:
token_iter = list(re.finditer(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]", text or ""))
if not token_iter:
cleaned = (text or "").strip()
return [cleaned] if cleaned else []

chunks: list[str] = []
start = 0
index = 0
max_tokens = max(int(chunk_token_num or 0), 1)

while index < len(token_iter):
end_index = min(index + max_tokens, len(token_iter)) - 1
end = token_iter[end_index].end()
piece = text[start:end].strip()
if piece:
chunks.append(piece)
start = end
index = end_index + 1

tail = text[start:].strip()
if tail:
chunks.append(tail)
return chunks


def _ensure_chunk_token_limit(
chunks: list[str], chunk_token_num: int, delimiter: str, overlapped_percent: int
) -> list[str]:
@@ -161,7 +135,7 @@ def _ensure_chunk_token_limit(
if nlp.count_tokens(text) <= max_tokens:
protected.append(text)
else:
protected.extend(_hard_split_by_token_limit(text, max_tokens))
protected.extend(nlp.hard_split_by_token_limit(text, max_tokens))

return [chunk for chunk in protected if chunk.strip()]

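Both the removed private copy and the shared helper rely on the same token model: one CJK character, or one run of [A-Za-z0-9_], counts as a single token. A hedged illustration:

import re

# Same pattern as nlp.hard_split_by_token_limit above; [一-鿿] in nlp.py
# is the literal-character spelling of [\u4e00-\u9fff].
TOKEN_RE = re.compile(r"[A-Za-z0-9_]+|[\u4e00-\u9fff]")

print(TOKEN_RE.findall("RAGFlow 分块测试 v2"))
# -> ['RAGFlow', '分', '块', '测', '试', 'v2']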
185 changes: 185 additions & 0 deletions backend/test/unit/test_chunking_token_limit.py
@@ -0,0 +1,185 @@
"""测试分块 token 上限保护:general parser 对超长 chunk 的硬切分。

nlp.py / general.py 只依赖 re 和标准库,用 sys.modules 占位绕过 yuxi 包的
重依赖链(langchain / pydantic / .env 配置等),实现纯单元测试。
跑完后清理 sys.modules,避免污染其他测试。
"""

import importlib.util
import sys
import types
from pathlib import Path

import pytest

_PKG = Path(__file__).resolve().parents[2] / "package"

_STUB_NAMES = [
"yuxi",
"yuxi.knowledge",
"yuxi.knowledge.chunking",
"yuxi.knowledge.chunking.ragflow_like",
"yuxi.knowledge.chunking.ragflow_like.parsers",
"yuxi.knowledge.chunking.ragflow_like.nlp",
"yuxi.knowledge.chunking.ragflow_like.parsers.general",
]

# Injected at runtime by the _isolated_modules fixture
nlp = None # type: ignore[assignment]
general = None # type: ignore[assignment]


@pytest.fixture(autouse=True, scope="module")
def _isolated_modules():
"""在模块级加载 nlp/general,跑完后清理 sys.modules 避免污染其他测试。"""
saved = {name: sys.modules.get(name) for name in _STUB_NAMES}

for name in _STUB_NAMES[:5]:
sys.modules.setdefault(name, types.ModuleType(name))

def _load(name: str, rel: str):
spec = importlib.util.spec_from_file_location(name, _PKG / rel)
mod = importlib.util.module_from_spec(spec)
sys.modules[name] = mod
spec.loader.exec_module(mod) # type: ignore[union-attr]
return mod

_nlp = _load(
"yuxi.knowledge.chunking.ragflow_like.nlp",
"yuxi/knowledge/chunking/ragflow_like/nlp.py",
)
sys.modules["yuxi.knowledge.chunking.ragflow_like"].nlp = _nlp # type: ignore[attr-defined]

_general = _load(
"yuxi.knowledge.chunking.ragflow_like.parsers.general",
"yuxi/knowledge/chunking/ragflow_like/parsers/general.py",
)

    # Expose module-level references for the test cases to use
global nlp, general # noqa: PLW0603
nlp = _nlp
general = _general

yield

    # Cleanup: restore the original state
for name in _STUB_NAMES:
if saved[name] is None:
sys.modules.pop(name, None)
else:
sys.modules[name] = saved[name]


# ── nlp.hard_split_by_token_limit ──────────────────────────────────


class TestHardSplitByTokenLimit:
def test_short_text_unchanged(self):
text = "这是一段短文本"
result = nlp.hard_split_by_token_limit(text, 512)
assert result == [text]

def test_splits_long_chinese_text(self):
text = "测试内容" * 300 # ~600 CJK tokens
result = nlp.hard_split_by_token_limit(text, 512)
assert len(result) > 1
for chunk in result:
assert nlp.count_tokens(chunk) <= 512

def test_splits_long_english_text(self):
text = "hello world " * 1000 # ~2000 word tokens
result = nlp.hard_split_by_token_limit(text, 512)
assert len(result) > 1
for chunk in result:
assert nlp.count_tokens(chunk) <= 512

def test_empty_text_returns_empty(self):
assert nlp.hard_split_by_token_limit("", 512) == []

def test_whitespace_only_returns_empty(self):
assert nlp.hard_split_by_token_limit(" \n\t ", 512) == []

def test_zero_limit_floors_to_one(self):
text = "a b c" # 3 个独立 token(单词)
result = nlp.hard_split_by_token_limit(text, 0)
        # max_tokens = max(0, 1) = 1, so each token becomes its own chunk
assert len(result) == 3

def test_punctuation_only_text(self):
text = ",。!?"
result = nlp.hard_split_by_token_limit(text, 512)
assert result == [",。!?"]


# ── general._ensure_chunk_token_limit ──────────────────────────────


class TestEnsureChunkTokenLimit:
def test_all_chunks_within_limit_pass_through(self):
chunks = ["短文本一", "短文本二", "短文本三"]
result = general._ensure_chunk_token_limit(chunks, 512)
assert result == ["短文本一", "短文本二", "短文本三"]

def test_oversized_chunk_gets_split(self):
long_text = "内容" * 300 # ~600 CJK tokens
chunks = ["短文本", long_text, "短文本二"]
result = general._ensure_chunk_token_limit(chunks, 512)
assert result[0] == "短文本"
assert result[-1] == "短文本二"
middle_chunks = result[1:-1]
assert len(middle_chunks) > 1
for chunk in middle_chunks:
assert nlp.count_tokens(chunk) <= 512

def test_empty_chunks_filtered(self):
chunks = ["有效文本", "", " ", "另一段"]
result = general._ensure_chunk_token_limit(chunks, 512)
assert result == ["有效文本", "另一段"]

def test_zero_limit_returns_stripped(self):
chunks = [" 文本一 ", "文本二"]
result = general._ensure_chunk_token_limit(chunks, 0)
assert result == ["文本一", "文本二"]


# ── general.chunk_markdown integration ────────────────────────────


class TestGeneralChunkMarkdown:
def test_normal_document_chunks_within_limit(self):
doc = "# 标题\n\n第一段内容\n\n第二段内容\n\n第三段内容"
chunks = general.chunk_markdown(doc, {"chunk_token_num": 512})
assert len(chunks) > 0
for chunk in chunks:
assert nlp.count_tokens(chunk) <= 512

def test_oversized_single_line_gets_split(self):
long_line = "运维知识" * 800 # ~3200 CJK tokens
doc = f"# 运维知识库\n\n{long_line}"
chunks = general.chunk_markdown(doc, {"chunk_token_num": 512})
assert len(chunks) > 1
for chunk in chunks:
assert nlp.count_tokens(chunk) <= 512

def test_empty_document_returns_empty(self):
assert general.chunk_markdown("", {"chunk_token_num": 512}) == []

def test_default_config_uses_512(self):
doc = "测试\n" * 200
chunks = general.chunk_markdown(doc)
for chunk in chunks:
assert nlp.count_tokens(chunk) <= 512


# ── laws parser regression ───────────────────────────────────────


class TestLawsParserRegression:
"""验证 nlp.hard_split_by_token_limit 可被 laws parser 正常调用。"""

def test_hard_split_produces_same_result(self):
text = "法规内容" * 300
result = nlp.hard_split_by_token_limit(text, 512)
assert len(result) > 1
for chunk in result:
assert nlp.count_tokens(chunk) <= 512
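Assuming the repository layout referenced above, the suite should run standalone, e.g. pytest backend/test/unit/test_chunking_token_limit.py -q (the exact invocation may differ per project configuration).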