In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
import torch
# from unsloth import FastLanguageModel
import torch
max_seq_length = 10000 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

llm_path = 'Qwen/Qwen2.5-0.5B-Instruct'
model = AutoModelForCausalLM.from_pretrained(
            llm_path,
            device_map="auto",  # 自动分配设备
            trust_remote_code=True,
            torch_dtype=torch.float16,  # 使用float16精度
            load_in_4bit=load_in_4bit,  # 4bit量化
        )
tokenizer = AutoTokenizer.from_pretrained(
                llm_path,
                trust_remote_code=True,
            )
device = "cuda" if torch.cuda.is_available() else "cpu"

print("model ok")

embeddings_path = 'BAAI/bge-large-zh-v1.5'
# embeddings_path = "/home/ubuntu/embedding_models/bge-large-zh-v1.5"
embeddings = HuggingFaceEmbeddings(
    model_name=embeddings_path,
    model_kwargs={
        'device': device,
        'local_files_only': True  # 指定使用本地模型
    },
    encode_kwargs={
        'normalize_embeddings': True,
        'batch_size': 32
    }
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model ok


In [2]:
from langchain_huggingface import HuggingFacePipeline
pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
#         max_length=10000,
        max_new_tokens=500,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=3,    # 避免重复生成n-gram
        pad_token_id=tokenizer.eos_token_id,
        # 重要：禁用截断
        truncation=False,
        # 处理长文本
        return_full_text=True
    )
llm = HuggingFacePipeline(pipeline=pipe)

In [3]:
import oss2
import pandas as pd
import pdfplumber
import docx
import os
import chardet
import javalang
import ast
import esprima
import clang
## 文件处理相关函数
def init_oss():
    # OSS初始化配置
    accessKeyId = os.getenv('ACCESSKEY_ID')
    accessKeySecret = os.getenv('ACCESSKEY_SECRET')
    auth = oss2.Auth(accessKeyId, accessKeySecret)

    endpoint = 'http://oss-cn-beijing.aliyuncs.com'
    bucketName = 'csgroup'
    return oss2.Bucket(auth, endpoint, bucketName)

bucket = init_oss()
def process_folder(folder_path):
    result = []
    supported_extensions = ['.txt', '.csv', '.docx', '.pdf', '.xlsx', '.cpp', '.py', '.c', '.h', '.hpp', '.java', '.js']

    for root, dirs, files in os.walk(folder_path):
        for file in files:
            # 获取文件路径
            file_path = os.path.join(root, file)
            # 获取文件扩展名，不区分大小写
            file_extension = os.path.splitext(file)[1].lower()
            relative_path = os.path.relpath(file_path, folder_path)

            if file_extension in supported_extensions:
                try:
                    # content = extract_content(file_path, file_extension)
                    # result.append(content)
                    content = extract_content(file_path, file_extension)
                    # 创建包含文件信息的字典
                    doc_info = {
                        'filename': file,
                        'path': relative_path,
                        'extension': file_extension,
                        'content': content
                    }
                    result.append(doc_info)
                    print(f"成功处理文件: {relative_path}")
                except Exception as e:
                    print(f"处理文件 {file_path} 时出错: {str(e)}")

    return result

## 普通文件
def extract_document(file_path, file_extension):

    if file_extension == '.txt':
        with open(file_path, 'rb') as f:
            raw_data = f.read()
            detected = chardet.detect(raw_data)
            return raw_data.decode(detected['encoding'])

    elif file_extension == '.csv':
        df = pd.read_csv(file_path)
        return df.to_string()

    elif file_extension == '.docx':
        doc = docx.Document(file_path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)

    elif file_extension == '.pdf':
        text = ''
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                text += page.extract_text() + '\n'
        return text

    elif file_extension == '.xlsx':
        df = pd.read_excel(file_path)
        return df.to_string()

## 代码文件
def extract_code(file_path, file_extension):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    # Java代码解析
    if file_extension == '.java':
        try:
            tree = javalang.parse.parse(content)
            # 提取类名、方法名、变量名等
            analysis = []
            for path, node in tree.filter(javalang.tree.ClassDeclaration):
                analysis.append(f"类名: {node.name}")
                for method in node.methods:
                    analysis.append(f"方法: {method.name}")
                    if method.documentation:
                        analysis.append(f"文档: {method.documentation}")
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # Python代码解析
    elif file_extension == '.py':
        try:
            tree = ast.parse(content)
            analysis = []
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    analysis.append(f"类名: {node.name}")
                elif isinstance(node, ast.FunctionDef):
                    analysis.append(f"函数: {node.name}")
                    if ast.get_docstring(node):
                        analysis.append(f"文档: {ast.get_docstring(node)}")
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # C/C++代码解析
    elif file_extension in ['.cpp', '.c', '.h', '.hpp']:
        try:
            index = clang.cindex.Index.create()
            tu = index.parse(file_path)
            analysis = []

            def process_node(node):
                if node.kind == clang.cindex.CursorKind.FUNCTION_DECL:
                    analysis.append(f"函数: {node.spelling}")
                elif node.kind == clang.cindex.CursorKind.CLASS_DECL:
                    analysis.append(f"类名: {node.spelling}")
                for child in node.get_children():
                    process_node(child)

            process_node(tu.cursor)
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # JavaScript代码解析
    elif file_extension == '.js':
        try:
            ast = esprima.parseScript(content)
            analysis = []

            def process_node(node):
                if node.type == 'FunctionDeclaration':
                    analysis.append(f"函数: {node.id.name}")
                elif node.type == 'ClassDeclaration':
                    analysis.append(f"类名: {node.id.name}")

            for node in ast.body:
                process_node(node)
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

def extract_content(file_path, file_extension):
    if file_extension in ['.txt', '.csv', '.docx', '.pdf', '.xlsx']:
        return extract_document(file_path, file_extension)

    elif file_extension in ['.cpp', '.py', '.c', '.h', '.hpp', '.java', '.js']:
        return extract_code(file_path, file_extension)

    else:
        raise ValueError(f"不支持的文件类型: {file_extension}")

In [4]:
## 创建检索器
import torch
import shutil
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
import chardet

import chromadb
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""你是一个AI助手。请基于用户的原始问题，生成几个不同的查询表达方式，以便从文档数据库中检索相关信息。
    这些查询应该从不同角度来表达相同的意图。每个查询用换行符分隔。

    原始问题: {question}
    生成的查询:""",
)

def create_retriever(folder_path):
    # 处理文件夹中的所有文件
    file_contents = process_folder(folder_path)
    
    # 创建Document对象列表
    documents = []
    for doc in file_contents:
        formatted_content = f"""
        文件名: {doc['filename']}
        文件路径: {doc['path']}
        文件类型: {doc['extension']}
        ---
        {doc['content']}
        """
        # 创建Document对象
        documents.append(
            Document(
                # page_content=formatted_content,
                page_content=doc['content'],
                metadata={
                    'source': doc['path'],
                    'filename': doc['filename'],
                    'extension': doc['extension']
                }
            )
        )

    # 使用文本分割器
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10000,
        chunk_overlap=200,
        separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " "],
        length_function=len,
    )
    
    # 直接分割Document对象
    split_documents = text_splitter.split_documents(documents)
    
    persist_directory = "chroma_db"
    if os.path.exists(persist_directory):
        shutil.rmtree(persist_directory)
        print(f"已删除原有数据库: {persist_directory}")
    # 创建向量存储
    docsearch = Chroma.from_documents(
        documents=split_documents,
        embedding=embeddings,
        persist_directory=persist_directory,  # 添加持久化目录
            client_settings=chromadb.config.Settings(
                anonymized_telemetry=False,
                is_persistent=True
            )
    )
    # 6. 确保持久化
    docsearch.persist()
    def check_document_count(vectordb, required_k=3):
        total_docs = len(vectordb.get())
        if total_docs < required_k:
            print(f"警告: 数据库中只有 {total_docs} 个文档，少于请求的 {required_k} 个")
            print("建议: 添加更多文档或减少请求数量")
        return total_docs

    # 使用前检查
    total_docs = check_document_count(docsearch)
    # 创建检索器
    base_retriever = docsearch.as_retriever(
        search_type="mmr",  # 使用MMR搜索策略
        search_kwargs={
            "k": min(1, total_docs),         # 返回的文档数
            "fetch_k": 3,   # 初始获取的文档数
            "lambda_mult": 0.9  # MMR多样性参数
        }
    )
    retriever = MultiQueryRetriever.from_llm(
        base_retriever, 
        llm,
        prompt=QUERY_PROMPT,
#         verbose=True  # 显示生成的查询
    )
    return retriever

In [6]:
# 获取相关文档
folder_path = "D:/Desktop/NLP/病例样本"
retriever = create_retriever(folder_path)
question = "李建华的病历"
relevant_docs = retriever.get_relevant_documents(question)

# 打印检索结果
for i, doc in enumerate(relevant_docs, 1):
    print(f"\n文档 {i}:")
    print(f"内容: {doc.page_content[:200]}...")  # 显示前200个字符
    print(f"来源: {doc.metadata['source']}")
    print(f"文件名: {doc.metadata['filename']}")

成功处理文件: 病例1.txt
成功处理文件: 病例10.txt
成功处理文件: 病例100.txt
成功处理文件: 病例11.txt
成功处理文件: 病例12.txt
成功处理文件: 病例13.txt
成功处理文件: 病例14.txt
成功处理文件: 病例15.txt
成功处理文件: 病例16txt.txt
成功处理文件: 病例17txt.txt
成功处理文件: 病例18txt.txt
成功处理文件: 病例19txt.txt
成功处理文件: 病例2.txt
成功处理文件: 病例20txt.txt
成功处理文件: 病例21txt.txt
成功处理文件: 病例22txt.txt
成功处理文件: 病例23txt.txt
成功处理文件: 病例24txt.txt
成功处理文件: 病例25txt.txt
成功处理文件: 病例26txt.txt
成功处理文件: 病例27txt.txt
成功处理文件: 病例28txt.txt
成功处理文件: 病例29txt.txt
成功处理文件: 病例3.txt
成功处理文件: 病例30txt.txt
成功处理文件: 病例31txt.txt
成功处理文件: 病例32txt.txt
成功处理文件: 病例33txt.txt
成功处理文件: 病例34txt.txt
成功处理文件: 病例35txt.txt
成功处理文件: 病例36.txt
成功处理文件: 病例37.txt
成功处理文件: 病例38.txt
成功处理文件: 病例39.txt
成功处理文件: 病例4.txt
成功处理文件: 病例40.txt
成功处理文件: 病例41.txt
成功处理文件: 病例42.txt
成功处理文件: 病例43.txt
成功处理文件: 病例44.txt
成功处理文件: 病例45.txt
成功处理文件: 病例46.txt
成功处理文件: 病例47.txt
成功处理文件: 病例48.txt
成功处理文件: 病例49.txt
成功处理文件: 病例5.txt
成功处理文件: 病例50.txt
成功处理文件: 病例51.txt
成功处理文件: 病例52.txt
成功处理文件: 病例53.txt
成功处理文件: 病例54.txt
成功处理文件: 病例55.txt
成功处理文件: 病例56.txt
成功处理文件: 病例57.txt
成功处理文件: 病例58.txt
成功处理文件: 病

PermissionError: [WinError 32] 另一个程序正在使用此文件，进程无法访问。: 'chroma_db\\456e9311-1cbd-4ff9-a2f0-83bf1290c853\\data_level0.bin'