<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch04/context_builder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install numpy transformers torch sentence-transformers

In [1]:
"""
上下文构建器
实现多文档信息融合、冲突解决和智能截断
"""

import re
import hashlib
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from collections import defaultdict


@dataclass
class Document:
    """A retrieved document: an identifier, its text body and free-form metadata.

    When ``id`` is empty (or otherwise falsy) it is replaced after
    initialisation by the first 8 hex characters of the MD5 digest of
    ``page_content``.
    """
    id: str  # caller-supplied id, or a content-derived fallback (see above)
    page_content: str  # the document text
    metadata: Dict[str, Any]  # arbitrary per-document attributes

    def __post_init__(self):
        # A caller-supplied id is always kept as-is.
        if self.id:
            return
        digest = hashlib.md5(self.page_content.encode()).hexdigest()
        self.id = digest[:8]


@dataclass
class Fact:
    """Data structure for a single piece of extracted factual information."""
    content: str  # the source sentence the fact was taken from
    entity: str  # the subject the fact is about
    attribute: str  # which property of the entity the fact describes
    value: str  # the property's value, as text
    confidence: float = 1.0  # extraction confidence; 1.0 unless the extractor sets it


class FactChecker:
    """Rule-based fact extractor and pairwise conflict detector.

    This is a deliberately simplified heuristic; a production system would
    use named-entity recognition and relation extraction instead.
    """

    # Simple "A <verb> B" patterns: 是 (is), 有 (has), 包含 (contains),
    # 支持 (supports). Compiled once here rather than rebuilt for every
    # sentence. Order matters: the first pattern that matches wins.
    _FACT_PATTERNS = (
        re.compile(r'(.+?)是(.+)'),
        re.compile(r'(.+?)有(.+)'),
        re.compile(r'(.+?)包含(.+)'),
        re.compile(r'(.+?)支持(.+)'),
    )

    def extract_facts(self, text: str) -> List[Fact]:
        """Extract simple entity/value facts from ``text``.

        The text is split on the Chinese full stop '。'. Sentences shorter
        than 10 characters (after stripping whitespace) are ignored. For every
        other sentence the first matching pattern produces one ``Fact`` with
        ``attribute='description'`` and ``confidence=0.8``.

        Args:
            text: Raw document text.

        Returns:
            Extracted facts in sentence order, at most one per sentence.
            ``content`` holds the whitespace-stripped sentence.
        """
        facts = []

        for raw_sentence in text.split('。'):
            # Strip once so both the length filter and the stored content
            # ignore surrounding whitespace/newlines.
            sentence = raw_sentence.strip()
            if len(sentence) < 10:
                continue

            for pattern in self._FACT_PATTERNS:
                match = pattern.search(sentence)
                if match:
                    facts.append(Fact(
                        content=sentence,
                        entity=match.group(1).strip(),
                        attribute='description',
                        value=match.group(2).strip(),
                        confidence=0.8,
                    ))
                    break  # at most one fact per sentence

        return facts

    def are_conflicting(self, fact1: Fact, fact2: Fact) -> bool:
        """Return True if two facts give different values for the same entity/attribute.

        Entity and value are compared case-insensitively via ``str.lower``.
        Because ``extract_facts`` always sets ``attribute='description'``,
        any two differently worded descriptions of the same entity count as
        a conflict under this simplified rule.
        """
        return (
            fact1.entity.lower() == fact2.entity.lower()
            and fact1.attribute == fact2.attribute
            and fact1.value.lower() != fact2.value.lower()
        )


class SourceRanker:
    """Score how authoritative a document's source is, on a 0..1 scale."""

    def __init__(self):
        # Base authority per ``metadata['source_type']``. The 'unknown'
        # entry doubles as the fallback for unrecognised source types.
        self.authority_scores = {
            'wikipedia': 0.9,
            'academic_paper': 0.95,
            'official_doc': 0.98,
            'news': 0.7,
            'blog': 0.5,
            'social_media': 0.3,
            'unknown': 0.6
        }

    def get_source_authority(self, document: Document) -> float:
        """Return the authority score of ``document``, capped at 1.0.

        The score starts from the base score for ``metadata['source_type']``
        and is adjusted as follows:
        - +0.05 when ``metadata['publish_date']`` is present and not None.
        - +0.1 when ``metadata['citation_count']`` is a number (or numeric
          string) greater than 100. Missing, None or non-numeric counts are
          treated as 0 instead of raising ``TypeError``.
        """
        metadata = document.metadata
        source_type = metadata.get('source_type', 'unknown')
        base_score = self.authority_scores.get(
            source_type, self.authority_scores.get('unknown', 0.6)
        )

        # NOTE(review): intended to favour newer documents, but only the
        # presence of a date is checked, not how recent it is.
        if metadata.get('publish_date') is not None:
            base_score += 0.05

        # Highly cited documents are considered more authoritative.
        if self._citation_count(metadata) > 100:
            base_score += 0.1

        return min(base_score, 1.0)

    @staticmethod
    def _citation_count(metadata: Dict[str, Any]) -> float:
        """Return ``metadata['citation_count']`` as a number (0 if missing, None or non-numeric)."""
        raw = metadata.get('citation_count')
        if raw is None:
            return 0
        try:
            return float(raw)
        except (TypeError, ValueError):
            return 0


class ConflictResolver:
    """Detect conflicting facts across documents and flag less authoritative ones."""

    # A document involved in a conflict keeps its content unchanged only if
    # its source authority is strictly above this value; otherwise a caution
    # note is appended.
    AUTHORITY_THRESHOLD = 0.8

    def __init__(self):
        self.fact_checker = FactChecker()
        self.source_ranker = SourceRanker()

    def resolve_conflicts(self, documents: List[Document]) -> List[Document]:
        """Return ``documents`` with conflict handling applied, in the same order.

        The input ``Document`` objects are never modified. Each document
        involved in a conflict is replaced in the output by a new
        ``Document`` with the same id and metadata and the content chosen by
        ``select_authoritative_info``. When no conflicts are detected, the
        input list itself is returned.
        """
        conflicts = self.detect_conflicts(documents)

        if not conflicts:
            return documents

        resolved_docs = []
        for doc in documents:
            if doc.id in conflicts:
                # Build a new object instead of assigning doc.page_content so
                # repeated calls don't stack warnings onto the caller's data.
                doc = Document(
                    id=doc.id,
                    page_content=self.select_authoritative_info(
                        doc, conflicts[doc.id]
                    ),
                    metadata=doc.metadata,
                )
            resolved_docs.append(doc)

        return resolved_docs

    def detect_conflicts(self, documents: List[Document]) -> Dict[str, List[Dict]]:
        """Find conflicting facts between every pair of distinct documents.

        Returns:
            A mapping from document id to a list of conflict records of the
            form ``{'conflicting_doc': other_id, 'fact1': own_fact,
            'fact2': other_fact}``. Each conflict is recorded under BOTH
            document ids, so each side is later judged on its own authority.
            Documents without conflicts do not appear in the mapping.
        """
        conflicts = defaultdict(list)

        # Facts per document id (if ids repeat, the last document wins).
        doc_facts = {
            doc.id: self.fact_checker.extract_facts(doc.page_content)
            for doc in documents
        }

        items = list(doc_facts.items())
        # Compare each unordered pair of documents exactly once.
        for index, (doc_id1, facts1) in enumerate(items):
            for doc_id2, facts2 in items[index + 1:]:
                for fact1 in facts1:
                    for fact2 in facts2:
                        if self.fact_checker.are_conflicting(fact1, fact2):
                            conflicts[doc_id1].append({
                                'conflicting_doc': doc_id2,
                                'fact1': fact1,
                                'fact2': fact2
                            })
                            conflicts[doc_id2].append({
                                'conflicting_doc': doc_id1,
                                'fact1': fact2,
                                'fact2': fact1
                            })

        return conflicts

    def select_authoritative_info(
        self,
        document: Document,
        conflict_info: List[Dict]
    ) -> str:
        """Return the content to keep for a document involved in a conflict.

        Simplified policy: content from a source whose authority exceeds
        ``AUTHORITY_THRESHOLD`` is kept unchanged. Otherwise a note saying the
        information may be disputed is appended. ``conflict_info`` is
        accepted for future policies but is not used by this one.
        """
        doc_authority = self.source_ranker.get_source_authority(document)

        if doc_authority > self.AUTHORITY_THRESHOLD:
            return document.page_content
        return (document.page_content +
                "\n\n[注意：此信息可能存在不同观点，请谨慎参考]")


class ContextBuilder:
    """Assemble a labelled context string from retrieved documents.

    Pipeline: deduplicate -> rank by relevance -> resolve conflicts ->
    concatenate with source labels within ``max_context_length`` characters.
    """

    # Placed between consecutive source sections. Each section already ends
    # with "\n", so this produces a blank line between sources.
    _SECTION_SEPARATOR = "\n"

    def __init__(self, max_context_length: int = 4000):
        # Upper bound, in characters, on the string build_context returns.
        self.max_context_length = max_context_length
        self.conflict_resolver = ConflictResolver()

    def build_context(
        self,
        query: str,
        documents: List[Document]
    ) -> str:
        """Build a context string for ``query`` from ``documents``.

        Documents are deduplicated, ranked with ``rank_by_relevance`` and
        passed through the conflict resolver. Each one is then emitted as a
        section ``"来源{n}[{title}]:\\n{content}\\n"``, and sections are
        joined with a newline.

        The result never exceeds ``max_context_length`` characters. The
        first section that doesn't fit keeps its source label, its body is
        shortened with ``intelligent_truncate``, and assembly stops there.
        """
        # 1. Deduplicate and rank.
        unique_docs = self.deduplicate_documents(documents)
        ranked_docs = self.rank_by_relevance(query, unique_docs)

        # 2. Detect and resolve conflicts between documents.
        resolved_docs = self.conflict_resolver.resolve_conflicts(ranked_docs)

        # 3. Assemble the sections under the length budget.
        separator = self._SECTION_SEPARATOR
        context_parts = []
        current_length = 0  # always equals len(separator.join(context_parts))

        for i, doc in enumerate(resolved_docs):
            # Source label so the answer can cite where information came from.
            source_info = f"来源{i+1}[{doc.metadata.get('title', 'Unknown')}]:"
            content = f"{source_info}\n{doc.page_content}\n"
            # The join separator counts toward the budget as well; ignoring
            # it let the result overshoot max_context_length.
            sep_len = len(separator) if context_parts else 0

            if current_length + sep_len + len(content) > self.max_context_length:
                # Space left for the body once the separator, the label line
                # and the trailing newline are paid for. Only the body is
                # truncated, so the source label is never lost.
                body_budget = (self.max_context_length - current_length
                               - sep_len - len(source_info) - 2)
                if body_budget > 0:
                    body = self.intelligent_truncate(
                        doc.page_content, body_budget, query
                    )
                    if body:
                        context_parts.append(f"{source_info}\n{body}\n")
                break

            context_parts.append(content)
            current_length += sep_len + len(content)

        return separator.join(context_parts)

    def deduplicate_documents(self, documents: List[Document]) -> List[Document]:
        """Drop documents whose ``page_content`` exactly repeats an earlier one.

        The first occurrence is kept and the input order is preserved.
        """
        seen_content = set()
        unique_docs = []

        for doc in documents:
            # Compare by a hash of the content rather than by document id.
            content_hash = hashlib.md5(doc.page_content.encode()).hexdigest()

            if content_hash not in seen_content:
                seen_content.add(content_hash)
                unique_docs.append(doc)

        return unique_docs

    def rank_by_relevance(
        self,
        query: str,
        documents: List[Document]
    ) -> List[Document]:
        """Return ``documents`` sorted by ``calculate_relevance``, highest first.

        The sort is stable, so ties keep their input order.
        """
        return sorted(
            documents,
            key=lambda doc: self.calculate_relevance(query, doc),
            reverse=True,
        )

    def calculate_relevance(self, query: str, document: Document) -> float:
        """Score a document as ``0.7 * jaccard(query, text) + 0.3 * authority``.

        NOTE: tokens come from ``str.split()``, i.e. whitespace. Text without
        spaces, such as Chinese, becomes a single token, so the Jaccard term
        is usually 0 and the ranking then depends on source authority alone.
        """
        query_words = set(query.lower().split())
        doc_words = set(document.page_content.lower().split())

        # Simple lexical-overlap relevance (Jaccard similarity).
        intersection = query_words & doc_words
        union = query_words | doc_words

        jaccard_score = len(intersection) / len(union) if union else 0

        # Blend in how authoritative the document's source is.
        authority_boost = self.conflict_resolver.source_ranker.get_source_authority(document)

        return jaccard_score * 0.7 + authority_boost * 0.3

    def intelligent_truncate(
        self,
        content: str,
        max_length: int,
        query: str
    ) -> str:
        """Shorten ``content`` to at most ``max_length`` characters, keeping relevant sentences.

        ``content`` is returned unchanged if it already fits. Otherwise it is
        split on '。', and fragments shorter than 5 non-whitespace characters
        are discarded. Fragments are then taken greedily from most to least
        relevant, skipping any that no longer fit. The selected fragments are
        emitted in their original order so the excerpt still reads
        coherently. Returns "" if nothing fits.
        """
        if len(content) <= max_length:
            return content
        if max_length <= 0:
            return ""

        pieces = content.split('。')
        last_index = len(pieces) - 1
        candidates = []  # (original position, text with delimiter, score)

        for index, piece in enumerate(pieces):
            if len(piece.strip()) < 5:
                continue
            # Re-add the delimiter only where split() actually removed one.
            text = piece + '。' if index < last_index else piece
            candidates.append(
                (index, text, self.calculate_sentence_relevance(piece, query))
            )

        # Most relevant first; the sort is stable, so ties keep document order.
        candidates.sort(key=lambda item: item[2], reverse=True)

        chosen = []
        used = 0
        for index, text, _score in candidates:
            # Skip (instead of stopping at) a sentence that is too long, so
            # shorter relevant sentences can still fill the remaining space.
            if used + len(text) <= max_length:
                chosen.append((index, text))
                used += len(text)

        chosen.sort()  # restore original document order
        return "".join(text for _index, text in chosen)

    def calculate_sentence_relevance(self, sentence: str, query: str) -> float:
        """Score a sentence as ``0.8 * query-word coverage + 0.2 * length score``.

        Coverage is the fraction of whitespace-separated query words that also
        occur in the sentence. The length score is ``min(len(sentence) / 50, 1)``,
        which penalises very short, low-information sentences.
        """
        query_words = set(query.lower().split())
        sentence_words = set(sentence.lower().split())

        # Share of query words that appear in the sentence.
        intersection = query_words & sentence_words
        overlap_score = len(intersection) / len(query_words) if query_words else 0

        # Short sentences carry little information.
        length_score = min(len(sentence) / 50, 1.0)

        return overlap_score * 0.8 + length_score * 0.2


# 使用示例
if __name__ == "__main__":
    # Demo corpus: three short descriptions of RAG from sources of different
    # authority. Each row is (id, content, title, source_type).
    documents = [
        Document(
            id=doc_id,
            page_content=text,
            metadata={"title": title, "source_type": source_type},
        )
        for doc_id, text, title, source_type in (
            (
                "doc1",
                "RAG是检索增强生成技术，它结合了信息检索和文本生成。RAG通过检索相关文档来提升生成质量。",
                "RAG技术介绍",
                "academic_paper",
            ),
            (
                "doc2",
                "检索增强生成（RAG）是一种新型的AI技术。它能够根据检索到的信息生成更准确的回答。",
                "AI技术概述",
                "blog",
            ),
            (
                "doc3",
                "RAG系统包含两个主要组件：检索器和生成器。检索器负责找到相关文档，生成器负责产生最终答案。",
                "RAG系统架构",
                "official_doc",
            ),
        )
    ]

    # A small character budget keeps the demo output short.
    context_builder = ContextBuilder(max_context_length=500)

    query = "什么是RAG技术？"
    context = context_builder.build_context(query, documents)

    divider = "=" * 50
    print("构建的上下文:")
    print(divider)
    print(context)
    print(divider)
    print(f"上下文长度: {len(context)} 字符")

构建的上下文:
来源1[RAG系统架构]:
RAG系统包含两个主要组件：检索器和生成器。检索器负责找到相关文档，生成器负责产生最终答案。

来源2[RAG技术介绍]:
RAG是检索增强生成技术，它结合了信息检索和文本生成。RAG通过检索相关文档来提升生成质量。

来源3[AI技术概述]:
检索增强生成（RAG）是一种新型的AI技术。它能够根据检索到的信息生成更准确的回答。

上下文长度: 180 字符
