In [1]:
from langchain.llms import tongyi
from langchain_openai import ChatOpenAI
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import alibabacloud_opensearch, faiss 
from langchain_community.embeddings import baidu_qianfan_endpoint
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.document_loaders import TextLoader
from langchain_core.prompts import PromptTemplate
from FlagEmbedding import FlagReranker
from langchain_community.document_loaders import DirectoryLoader, text
from langchain_community.vectorstores.utils import DistanceStrategy
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForSequenceClassification
from tqdm import tqdm
import numpy as np
import torch

class RAGSystem:
    """Retrieval-augmented article generator.

    Retrieval combines a FAISS dense retriever and a BM25 keyword retriever
    through an ``EnsembleRetriever`` (equal weights), reranks the candidates
    with a ``FlagReranker`` cross-encoder, and feeds the best passages to a
    ``ChatOpenAI`` model that first drafts an outline for a topic and then
    writes the final article from that outline.
    """

    def __init__(self, data_directory,
                 model_name="gpt-4o",
                 embedding_model_name='BAAI/bge-m3',
                 cache_folder="your_model_save_path",
                 vector_db_path="your_vec_model_save_path",
                 temperature=1.2,
                 api_key="your_api"):
        """Load the LLM, the corpus, the embedding model, the index and the reranker.

        Args:
            data_directory: folder scanned recursively for ``*.txt`` files;
                the split, de-duplicated chunks feed the BM25 retriever.
            model_name: chat model name passed to ``ChatOpenAI``.
            embedding_model_name: Hugging Face embedding model id.
            cache_folder: local directory for downloaded Hugging Face models
                (used for both the embedding model and the reranker).
            vector_db_path: directory of a FAISS index saved with
                ``save_local``.
            temperature: LLM sampling temperature. Higher values give looser,
                more varied text; lower values give stricter text. Range
                (0.0, 2.0).
            api_key: OpenAI API key. Pass it in (e.g. read from an environment
                variable) rather than editing the source.
        """
        self.llm = ChatOpenAI(model_name=model_name,
                              api_key=api_key,
                              temperature=temperature,
                              )

        self.data_directory = data_directory
        self.text_splitter = self._initialize_text_splitter()
        self.RDATA = self._load_and_split_data(self.data_directory)
        self.embedding_model = self._initialize_embedding_model(embedding_model_name, cache_folder)
        self.cosine_knowledge_vector_database = self._load_vector_db(vector_db_path)
        self.retriever_vectordb = self.cosine_knowledge_vector_database.as_retriever()
        # Keyword retriever over the local chunks: returns up to 50 hits, BM25 k1 = 1.5.
        self.keyword_retriever = BM25Retriever.from_documents(self.RDATA, k=50, bm25_params={'k1': 1.5})
        self.ensemble_retriever = EnsembleRetriever(retrievers=[self.retriever_vectordb, self.keyword_retriever], weights=[0.5, 0.5])
        self.reranker = self._initialize_reranker(cache_folder)

    def _initialize_text_splitter(self):
        """Build the chunker used for the BM25 corpus and for new FAISS data.

        Chunks are at most 1000 characters with a 100-character overlap;
        ``add_start_index=True`` records each chunk's start offset in its
        metadata.
        """
        # Separators are tried in order: the splitter cuts on the first one
        # present in the text and recurses with the later ones for pieces that
        # are still too long. The empty string ("cut anywhere") must be LAST:
        # the splitter stops scanning when it reaches "", so every entry that
        # used to follow it -- including the Chinese punctuation -- was dead.
        # NOTE(review): this changes chunk boundaries; if the FAISS index at
        # vector_db_path was built with the old order, its chunks will not
        # line up with the BM25 chunks -- consider rebuilding it.
        # NOTE(review): is_separator_regex is left at its default (False), so
        # the regex-looking entries ("\n#{1,6} ", "\n\\*\\*\\*+\n", ...) are
        # matched literally -- confirm whether regex matching was intended.
        MARKDOWN_SEPARATORS = [
            "\n#{1,6} ", "```\n", "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", "\n\n", "\n", " ", ".", ",",
            "\u200B", "\uff0c", "\u3001", "\uff0e", "\u3002", "",
        ]
        return RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            add_start_index=True,
            strip_whitespace=True,
            separators=MARKDOWN_SEPARATORS,
        )

    def _load_and_split_data(self, path):
        """Load all ``*.txt`` files under ``path`` (recursively) and split them.

        Returns:
            list of chunk ``Document``s in load order, keeping only the first
            chunk for each distinct ``page_content``.
        """
        data_loader = DirectoryLoader(path, glob="*.txt", recursive=True, loader_cls=text.TextLoader)
        raw_knowledge_base = data_loader.load()
        print("正在加载数据库")
        chunks = []
        for doc in tqdm(raw_knowledge_base):
            chunks.extend(self.text_splitter.split_documents([doc]))

        # Remove duplicate chunks (identical text), keeping the first occurrence.
        seen_texts = set()
        unique_chunks = []
        for doc in chunks:
            if doc.page_content not in seen_texts:
                seen_texts.add(doc.page_content)
                unique_chunks.append(doc)
        return unique_chunks

    def _initialize_embedding_model(self, model_name, cache_folder):
        """Create the Hugging Face embedding model (L2-normalized vectors, cuda:0)."""
        print("正在加载embedding模型")
        return HuggingFaceEmbeddings(
            model_name=model_name,
            cache_folder=cache_folder,
            multi_process=True,
            model_kwargs={"device": "cuda:0", "trust_remote_code": True},
            encode_kwargs={"normalize_embeddings": True},
        )

    def _load_vector_db(self, vector_db_path):
        """Load a FAISS index saved with ``save_local`` from ``vector_db_path``.

        ``allow_dangerous_deserialization=True`` makes LangChain unpickle the
        stored docstore: only load indexes you created yourself.
        """
        print("正在加载向量数据库")
        cosine_knowledge_vector_database = faiss.FAISS.load_local(vector_db_path, embeddings=self.embedding_model, allow_dangerous_deserialization=True)
        return cosine_knowledge_vector_database

    def _initialize_reranker(self, cache_dir="your_model_save_path"):
        """Load the bge-reranker-v2-m3 cross-encoder in fp16, caching it in ``cache_dir``."""
        return FlagReranker('BAAI/bge-reranker-v2-m3', use_fp16=True, cache_dir=cache_dir)

    def add_vectorize_data(self, new_data_path, vec_data_path, save_path):
        """Merge a folder of new documents into an existing FAISS index.

        Optional; run it whenever new source documents are added. The merged
        index is written to ``save_path``. This instance's retrievers are not
        refreshed and keep using the index loaded in ``__init__``.

        Args:
            new_data_path: folder with the new ``*.txt`` files.
            vec_data_path: directory of the existing FAISS index.
            save_path: directory where the merged index is saved.
        """
        print("正在合并数据库")
        knowledge_vector_database = self._load_vector_db(vec_data_path)
        new_chunks = self._load_and_split_data(new_data_path)
        knowledge_vector_database.add_documents(new_chunks)
        knowledge_vector_database.save_local(save_path)

    def get_relevant_knowledge(self, query, final_extract_num=20):
        """Retrieve passages for ``query`` and return the best ones after reranking.

        Args:
            query: search text.
            final_extract_num: maximum number of passages to return; values
                <= 0 return an empty list.

        Returns:
            list of passage strings, highest reranker score first.
        """
        relevant_doc = self.ensemble_retriever.invoke(query)
        if not relevant_doc or final_extract_num <= 0:
            return []
        pairs = [[query, doc.page_content] for doc in relevant_doc]
        # compute_score may return a bare float instead of a list when given a
        # single pair; atleast_1d makes the ranking below work either way.
        scores = np.atleast_1d(np.asarray(self.reranker.compute_score(pairs), dtype=float))
        # Reverse first, then truncate: ``[-k:]`` with k == 0 would be
        # ``[-0:]`` and select everything.
        top_indices = np.argsort(scores)[::-1][:final_extract_num]
        return [relevant_doc[i].page_content for i in top_indices]

    @staticmethod
    def _format_context(passages):
        """Label passages as ``文件{i}: ...`` and put each on its own line."""
        return '\n'.join(f'文件{i}: {doc}' for i, doc in enumerate(passages))

    def generate_script(self, query):
        """Ask the LLM for an outline on ``query``, grounded in retrieved passages.

        Returns:
            the outline text (``content`` of the LLM reply).
        """
        print("正在生成大纲")
        query_retrieval = self.get_relevant_knowledge(query)
        context = self._format_context(query_retrieval)
        # NOTE(review): "你的大纲。" looks like a placeholder for the actual
        # outline-writing instruction -- replace it with a real prompt.
        script_prompt_template = PromptTemplate.from_template('''
            参考资料:
            {context}
            你的大纲。
            主题:
            {query}
            答案:
            ''')
        chain = script_prompt_template | self.llm
        script = chain.invoke({'query': query, 'context': context})
        return script.content

    def extract_keywords(self, query, script):
        """Ask the LLM to extract search keywords from an outline.

        Returns:
            the LLM reply text, used as the query for a second retrieval.
        """
        important_extraction_prompt = PromptTemplate.from_template('''
            以下是一篇根据主题写的大纲，提取其中的关键词以供后续在资料库中搜索。
            主题：
            {query}                                                           
            大纲：
            {script}
            ''')
        chain = important_extraction_prompt | self.llm
        important = chain.invoke({"query": query, "script": script})
        return important.content

    def generate_article(self, query):
        """Generate the final article for ``query``.

        Steps: draft an outline, extract keywords from it, retrieve again with
        those keywords, then write the article from the outline plus the
        newly retrieved passages.

        Returns:
            the article text (``content`` of the LLM reply).
        """
        # Draft an outline around the topic.
        script = self.generate_script(query)
        # Extract keywords from the outline.
        important = self.extract_keywords(query, script)
        # Second retrieval pass driven by the keywords.
        important_retrieval = self.get_relevant_knowledge(important)
        context = self._format_context(important_retrieval)
        print("正在生成文案")
        # NOTE(review): "你的prompt。" looks like a placeholder for the actual
        # article-writing instruction -- replace it with a real prompt.
        prompt_template = PromptTemplate.from_template('''
                                                    
        参考资料:
        {context} 
                                                    
        你的prompt。
                                                                                                
        大纲:
        {script}

        答案:
        ''')
        chain = prompt_template | self.llm
        article = chain.invoke({'script': script, 'context': context})
        return article.content

# Example usage. `data_directory` is the folder of plain-text (.txt) source
# documents. The guard keeps module import free of heavy model loading; in a
# notebook __name__ is "__main__", so the cell still runs as before.
if __name__ == "__main__":
    rag_system = RAGSystem(data_directory="your_data_path")
    query = ""  # topic to write about -- fill in before running
    answer = rag_system.generate_article(query)
    print(answer)
    # Step-by-step alternative:
    # script = rag_system.generate_script(query)
    # keywords = rag_system.extract_keywords(query, script)
    # important_retrieval = rag_system.get_relevant_knowledge(keywords, final_extract_num=30)


正在加载数据库


100%|██████████| 4498/4498 [00:00<00:00, 7333.04it/s] 
  warn_deprecated(


正在加载embedding模型
正在加载向量数据库
----------using 2*GPUs----------
