In [19]:
from rank_bm25 import BM25Okapi
import nltk
from nltk.tokenize import word_tokenize
import numpy as np

# Add a local "./nltk_data" directory (a relative path, so resolved against the
# current working directory) to NLTK's data search path. Nothing is downloaded
# here; the tokenizer data that word_tokenize needs is presumably expected to
# be present there or on another NLTK search path already -- confirm.
# The membership check keeps re-running this cell/module from appending the
# same directory again and again.
_NLTK_DATA_DIR = "./nltk_data"
if _NLTK_DATA_DIR not in nltk.data.path:
    nltk.data.path.append(_NLTK_DATA_DIR)

class BM25Retrieval:
    """Rank a fixed set of text paragraphs against keyword queries with BM25 (Okapi)."""

    def __init__(self, paragraphs, tokenizer=word_tokenize):
        """
        Build the BM25 index over the candidate paragraphs.

        :param paragraphs: iterable of strings, the candidate paragraphs. It is
            materialized into a list once, so one-shot iterators (e.g.
            generators) work: the same items are used both for indexing and
            for pairing scores with paragraphs in ``search``.
        :param tokenizer: callable mapping a string to a list of tokens;
            defaults to NLTK's ``word_tokenize``. Text is lowercased before it
            is passed to the tokenizer.
        """
        self.paragraphs = list(paragraphs)
        self.tokenizer = tokenizer
        self.tokenized_corpus = self._tokenize_corpus(self.paragraphs)
        # BM25Okapi cannot be built from an empty corpus (it raises
        # ZeroDivisionError computing the average document length), so no
        # index is built in that case and search() returns [].
        self.bm25 = BM25Okapi(self.tokenized_corpus) if self.tokenized_corpus else None

    def _tokenize_corpus(self, corpus):
        """Lowercase and tokenize every paragraph in ``corpus``."""
        return [self.tokenizer(doc.lower()) for doc in corpus]

    def _tokenize_query(self, query):
        """Lowercase and tokenize the query string."""
        return self.tokenizer(query.lower())

    def search(self, keywords, top_n=None, threshold=None):
        """
        Score every paragraph against the keywords and return them ranked.

        :param keywords: list of keyword strings (a single string is also
            accepted and treated as one keyword). The keywords are joined with
            spaces and tokenized as a single query.
        :param top_n: if not None, return at most the first ``top_n`` ranked results
        :param threshold: if not None, keep only results with ``score >= threshold``
        :return: list of ``(paragraph, score)`` tuples sorted by score, highest
            first; paragraphs with equal scores keep their original order.
            Empty if the index was built from no paragraphs.
        """
        if self.bm25 is None:
            return []

        # A bare string would otherwise be joined character by character
        # ("deep" -> "d e e p"), producing a meaningless query.
        if isinstance(keywords, str):
            keywords = [keywords]
        query = " ".join(keywords)
        tokenized_query = self._tokenize_query(query)

        # One BM25 score per paragraph, in corpus order.
        doc_scores = self.bm25.get_scores(tokenized_query)
        results = list(zip(self.paragraphs, doc_scores))

        if threshold is not None:
            results = [(doc, score) for doc, score in results if score >= threshold]

        # list.sort is stable even with reverse=True, so ties keep corpus order.
        results.sort(key=lambda pair: pair[1], reverse=True)

        return results if top_n is None else results[:top_n]

# Example usage
if __name__ == "__main__":
    def _show(title, hits):
        """Print a heading, then one numbered line per (paragraph, score) hit."""
        print(title)
        for rank, (text, value) in enumerate(hits, start=1):
            print(f"{rank}. [Score: {value:.4f}] {text}")

    # A small toy corpus of machine-learning sentences.
    paragraphs = [
        "Deep learning requires significant computational resources and large datasets.",
        "Neural networks are a branch of machine learning algorithms.",
        "Artificial intelligence is transforming various industries.",
        "GPU acceleration dramatically speeds up model training.",
        "Natural language processing is a key application of AI.",
        "Reinforcement learning differs from supervised learning approaches.",
        "Convolutional neural networks excel at image recognition tasks."
    ]

    # Multi-word phrases are allowed; search() joins and tokenizes them together.
    keywords = ["deep learning", "neural networks"]

    retriever = BM25Retrieval(paragraphs)

    # Option 1: every paragraph, ranked by relevance.
    all_results = retriever.search(keywords)
    _show("所有结果(按相关性排序):", all_results)

    # Option 2: only the best N matches.
    # _show("\nTop 3 结果:", retriever.search(keywords, top_n=3))

    # Option 3: only matches scoring at least the threshold.
    # _show("\n分数高于1.5的结果:", retriever.search(keywords, threshold=1.5))


所有结果(按相关性排序):
1. [Score: 1.7279] Neural networks are a branch of machine learning algorithms.
2. [Score: 1.6234] Deep learning requires significant computational resources and large datasets.
3. [Score: 1.5656] Convolutional neural networks excel at image recognition tasks.
4. [Score: 0.3705] Reinforcement learning differs from supervised learning approaches.
5. [Score: 0.0000] Artificial intelligence is transforming various industries.
6. [Score: 0.0000] GPU acceleration dramatically speeds up model training.
7. [Score: 0.0000] Natural language processing is a key application of AI.
