In [None]:
# !python -m pip install --upgrade pip
# !pip install langchain_community
# !pip install langchain_huggingface
# !pip install oss2
# !pip install python-docx
# !pip install pandas
# !pip install pdfplumber
# !pip install chardet
# !pip install langchain
# !pip install transformers
# !pip install torch
# !pip install sentence_transformers
# !pip install chromadb
# !pip install pyngrok
# !pip install flask_cors
# !pip install javalang
# !pip install libclang
# !pip install esprima
# !pip install unsloth
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# !curl -s https://ngrok-agent.s3.amazonaws.com/ngrok.asc | sudo tee /etc/apt/trusted.gpg.d/ngrok.asc >/dev/null && echo "deb https://ngrok-agent.s3.amazonaws.com buster main" | sudo tee /etc/apt/sources.list.d/ngrok.list && sudo apt update && sudo apt install ngrok

In [2]:
import os
os.environ['ACCESSKEY_ID'] = ''
os.environ['ACCESSKEY_SECRET'] = ''
!ngrok authtoken 

In [3]:
import oss2
import pandas as pd
import pdfplumber
import docx
## 文件处理相关函数
def init_oss():
    # OSS初始化配置
    accessKeyId = os.getenv('ACCESSKEY_ID')
    accessKeySecret = os.getenv('ACCESSKEY_SECRET')
    auth = oss2.Auth(accessKeyId, accessKeySecret)

    endpoint = 'http://oss-cn-beijing.aliyuncs.com'
    bucketName = 'csgroup'
    return oss2.Bucket(auth, endpoint, bucketName)

bucket = init_oss()
def process_folder(folder_path):
    result = []
    supported_extensions = ['.txt', '.csv', '.docx', '.pdf', '.xlsx', '.cpp', '.py', '.c', '.h', '.hpp', '.java', '.js']

    for root, dirs, files in os.walk(folder_path):
        for file in files:
            # 获取文件路径
            file_path = os.path.join(root, file)
            # 获取文件扩展名，不区分大小写
            file_extension = os.path.splitext(file)[1].lower()
            relative_path = os.path.relpath(file_path, folder_path)

            if file_extension in supported_extensions:
                try:
                    # content = extract_content(file_path, file_extension)
                    # result.append(content)
                    content = extract_content(file_path, file_extension)
                    # 创建包含文件信息的字典
                    doc_info = {
                        'filename': file,
                        'path': relative_path,
                        'extension': file_extension,
                        'content': content
                    }
                    result.append(doc_info)
                    print(f"成功处理文件: {relative_path}")
                except Exception as e:
                    print(f"处理文件 {file_path} 时出错: {str(e)}")

    return result

## 普通文件
def extract_document(file_path, file_extension):

    if file_extension == '.txt':
        with open(file_path, 'rb') as f:
            raw_data = f.read()
            detected = chardet.detect(raw_data)
            return raw_data.decode(detected['encoding'])

    elif file_extension == '.csv':
        df = pd.read_csv(file_path)
        return df.to_string()

    elif file_extension == '.docx':
        doc = docx.Document(file_path)
        return "\n".join(paragraph.text for paragraph in doc.paragraphs)

    elif file_extension == '.pdf':
        text = ''
        with pdfplumber.open(file_path) as pdf:
            for page in pdf.pages:
                text += page.extract_text() + '\n'
        return text

    elif file_extension == '.xlsx':
        df = pd.read_excel(file_path)
        return df.to_string()

## 代码文件
def extract_code(file_path, file_extension):
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    # Java代码解析
    if file_extension == '.java':
        try:
            tree = javalang.parse.parse(content)
            # 提取类名、方法名、变量名等
            analysis = []
            for path, node in tree.filter(javalang.tree.ClassDeclaration):
                analysis.append(f"类名: {node.name}")
                for method in node.methods:
                    analysis.append(f"方法: {method.name}")
                    if method.documentation:
                        analysis.append(f"文档: {method.documentation}")
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # Python代码解析
    elif file_extension == '.py':
        try:
            tree = ast.parse(content)
            analysis = []
            for node in ast.walk(tree):
                if isinstance(node, ast.ClassDef):
                    analysis.append(f"类名: {node.name}")
                elif isinstance(node, ast.FunctionDef):
                    analysis.append(f"函数: {node.name}")
                    if ast.get_docstring(node):
                        analysis.append(f"文档: {ast.get_docstring(node)}")
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # C/C++代码解析
    elif file_extension in ['.cpp', '.c', '.h', '.hpp']:
        try:
            index = clang.cindex.Index.create()
            tu = index.parse(file_path)
            analysis = []

            def process_node(node):
                if node.kind == clang.cindex.CursorKind.FUNCTION_DECL:
                    analysis.append(f"函数: {node.spelling}")
                elif node.kind == clang.cindex.CursorKind.CLASS_DECL:
                    analysis.append(f"类名: {node.spelling}")
                for child in node.get_children():
                    process_node(child)

            process_node(tu.cursor)
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

    # JavaScript代码解析
    elif file_extension == '.js':
        try:
            ast = esprima.parseScript(content)
            analysis = []

            def process_node(node):
                if node.type == 'FunctionDeclaration':
                    analysis.append(f"函数: {node.id.name}")
                elif node.type == 'ClassDeclaration':
                    analysis.append(f"类名: {node.id.name}")

            for node in ast.body:
                process_node(node)
            return "\n".join(analysis) + "\n原始代码:\n" + content
        except:
            return content

def extract_content(file_path, file_extension):
    if file_extension in ['.txt', '.csv', '.docx', '.pdf', '.xlsx']:
        return extract_document(file_path, file_extension)

    elif file_extension in ['.cpp', '.py', '.c', '.h', '.hpp', '.java', '.js']:
        return extract_code(file_path, file_extension)

    else:
        raise ValueError(f"不支持的文件类型: {file_extension}")

In [4]:
import shutil

def delete_all_files(folder_path='documents_for_analyse'):
    try:
        # 确保路径存在
        if not os.path.exists(folder_path):
            print(f"路径不存在: {folder_path}")
            return False
            
        # 遍历文件夹中的所有内容
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path):
                    # 如果是文件，直接删除
                    os.unlink(file_path)
                    print(f"已删除文件: {filename}")
                elif os.path.isdir(file_path):
                    # 如果是文件夹，删除整个文件夹及其内容
                    shutil.rmtree(file_path)
                    print(f"已删除文件夹: {filename}")
            except Exception as e:
                print(f"删除 {filename} 时出错: {str(e)}")
                
        print(f"已清空文件夹: {folder_path}")
        return True
        
    except Exception as e:
        print(f"删除过程中出错: {str(e)}")
        return False

def download_all_files():
    global bucket
    local_dir = 'documents_for_analyse'

    # 5. 列出存储桶中的所有文件并下载
    for obj in oss2.ObjectIterator(bucket):
        objectName = obj.key
        local_file_path = os.path.join(local_dir, objectName)

        # 如果对象路径包含目录结构，确保在本地创建相应的目录
        os.makedirs(os.path.dirname(local_file_path), exist_ok=True)

        # 下载文件
        print(f'Downloading {objectName} to {local_file_path}')
        bucket.get_object_to_file(objectName, local_file_path)

download_all_files()

Downloading colab_url.txt to documents_for_analyse/colab_url.txt
Downloading 病例样本/丁珊.txt to documents_for_analyse/病例样本/丁珊.txt
Downloading 病例样本/丁磊.txt to documents_for_analyse/病例样本/丁磊.txt
Downloading 病例样本/刘东风.txt to documents_for_analyse/病例样本/刘东风.txt
Downloading 病例样本/刘婷.txt to documents_for_analyse/病例样本/刘婷.txt
Downloading 病例样本/刘婷婷.txt to documents_for_analyse/病例样本/刘婷婷.txt
Downloading 病例样本/刘宇.txt to documents_for_analyse/病例样本/刘宇.txt
Downloading 病例样本/刘志.txt to documents_for_analyse/病例样本/刘志.txt
Downloading 病例样本/刘旭东.txt to documents_for_analyse/病例样本/刘旭东.txt
Downloading 病例样本/刘晓云.txt to documents_for_analyse/病例样本/刘晓云.txt
Downloading 病例样本/刘涛.txt to documents_for_analyse/病例样本/刘涛.txt
Downloading 病例样本/刘琳.txt to documents_for_analyse/病例样本/刘琳.txt
Downloading 病例样本/刘红.txt to documents_for_analyse/病例样本/刘红.txt
Downloading 病例样本/刘雨婷.txt to documents_for_analyse/病例样本/刘雨婷.txt
Downloading 病例样本/吴婷.txt to documents_for_analyse/病例样本/吴婷.txt
Downloading 病例样本/吴洁.txt to documents_for_analyse/病例样本/吴洁.txt
Downloadin

In [5]:
# from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFaceEmbeddings
import torch
from unsloth import FastLanguageModel
max_seq_length = 10000 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

llm_path = 'shared-nvme/llm_models/Qwen2.5-7B-Instruct'
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = llm_path,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FastLanguageModel.for_inference(model)

print("model ok")

embeddings_path = 'shared-nvme/embedding_models/bge-large-zh-v1.5'
# embeddings_path = "/home/ubuntu/embedding_models/bge-large-zh-v1.5"
embeddings = HuggingFaceEmbeddings(
    model_name=embeddings_path,
    model_kwargs={
        'device': device,
        'local_files_only': True  # 指定使用本地模型
    },
    encode_kwargs={
        'normalize_embeddings': True,
        'batch_size': 32
    }
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


Unsloth: Your Flash Attention 2 installation seems to be broken?
A possible explanation is you have a new CUDA version which isn't
yet compatible with FA2? Please file a ticket to Unsloth or FA2.
We shall now use Xformers instead, which does not have any performance hits!
We found this negligible impact by benchmarking on 1x A100.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2024.11.9: Fast Qwen2 patching. Transformers = 4.46.3.
   \\   /|    GPU: NVIDIA GeForce RTX 3090. Max memory: 23.684 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.5.1+cu124. CUDA = 8.6. CUDA Toolkit = 12.4.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.93s/it]


model ok


In [6]:
## 创建检索器
from flask import Flask, request, jsonify
from flask_cors import CORS
import torch
from transformers import pipeline
from langchain_huggingface import HuggingFacePipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
import chardet

from sentence_transformers import SentenceTransformer
import chromadb
import pickle
from datetime import datetime
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain
# from langchain_community.chains import ConversationalRetrievalQA
from langchain.schema import HumanMessage, AIMessage, SystemMessage  # 添加这行
from langchain.schema import Document
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.retrievers.multi_query import MultiQueryRetriever
QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""你是一个AI助手。请针对用户的原始问题，生成几个很简单的查询表达方式，以便从文档数据库中检索相关信息。
    这些查询应该从不同角度来表达相同的意图。每个查询用换行符分隔。

    原始问题: {question}
    生成的查询:""",
)

def create_retriever(folder_path):
    # 处理文件夹中的所有文件
    file_contents = process_folder(folder_path)
    
    # 创建Document对象列表
    documents = []
    for doc in file_contents:
        formatted_content = f"""
        文件名: {doc['filename']}
        文件路径: {doc['path']}
        文件类型: {doc['extension']}
        ---
        {doc['content']}
        """
        # 创建Document对象
        documents.append(
            Document(
                page_content=formatted_content,
                metadata={
                    'source': doc['path'],
                    'filename': doc['filename'],
                    'extension': doc['extension']
                }
            )
        )

    # 使用文本分割器
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=10000,
        chunk_overlap=200,
        separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", " "],
        length_function=len,
    )
    
    # 直接分割Document对象
    split_documents = text_splitter.split_documents(documents)
    
    persist_directory = "chroma_db"
    # 创建向量存储
    docsearch = Chroma.from_documents(
        documents=split_documents,
        embedding=embeddings,
        persist_directory=persist_directory,  # 添加持久化目录
            client_settings=chromadb.config.Settings(
                anonymized_telemetry=False,
                is_persistent=True
            )
    )
    # 6. 确保持久化
    docsearch.persist()
    def check_document_count(vectordb, required_k=3):
        total_docs = len(vectordb.get())
        if total_docs < required_k:
            print(f"警告: 数据库中只有 {total_docs} 个文档，少于请求的 {required_k} 个")
            print("建议: 添加更多文档或减少请求数量")
        return total_docs

    # 使用前检查
    total_docs = check_document_count(docsearch)
    # 创建检索器
    base_retriever = docsearch.as_retriever(
        search_type="mmr",  # 使用MMR搜索策略
        search_kwargs={
            "k": min(2, total_docs),         # 返回的文档数
            "fetch_k": 3,   # 初始获取的文档数
            "lambda_mult": 0.9  # MMR多样性参数
        }
    )
    retriever = MultiQueryRetriever.from_llm(
        base_retriever, 
        llm,
        prompt=QUERY_PROMPT,
#         verbose=True  # 显示生成的查询
    )
    return retriever

In [7]:
from transformers import StoppingCriteria, StoppingCriteriaList, MaxLengthCriteria
from transformers import TextStreamer
def generate_with_retrieval(question, retriever, model, tokenizer):
    # 1. 使用检索器获取相关文档
    retrieved_docs = retriever.get_relevant_documents(question)
    print('检索成功')
    # 2. 构建上下文
    context = "\n\n".join([
        f"文档内容：{doc.page_content}\n来源：{doc.metadata['source']}"
        for doc in retrieved_docs
    ])

    # 3. 构建提示模板
    prompt = f"""基于以下上下文信息回答问题。如果上下文中没有相关信息，请说明。

上下文信息：
{context}

问题：{question}

回答："""
    messages = [
        {"role": "user", "content": prompt},
    ]
    # 4. 生成回答
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to("cuda")

    text_streamer = TextStreamer(tokenizer, skip_prompt = True)
    outputs = model.generate(input_ids,
                             streamer = text_streamer,
                             max_new_tokens = 512,
                             pad_token_id = tokenizer.eos_token_id,
                             eos_token_id=tokenizer.eos_token_id
                            )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    start_index = response.find('回答：\nassistant')+len('回答：\nassistant')
    response = response[start_index:]
    return response
def generate(question, model, tokenizer):
    prompt = f"""请回答以下问题。

问题：{question}

回答："""
    messages = [
        {"role": "user", "content": prompt},
    ]
    # 4. 生成回答
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,
        return_tensors = "pt",
    ).to("cuda")

    text_streamer = TextStreamer(tokenizer, skip_prompt = True)
    outputs = model.generate(input_ids,
                             streamer = text_streamer,
                             max_new_tokens = 512,
                             pad_token_id = tokenizer.eos_token_id,
                             eos_token_id=tokenizer.eos_token_id
                            )
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    start_index = response.find('回答：\nassistant')+len('回答：\nassistant')
    response = response[start_index:]
    return response

In [None]:
from threading import Thread
from pyngrok import ngrok
import time
FastLanguageModel.for_inference(model)
download_all_files()

def save_url_to_oss(url):
    global bucket
    try:
        bucket.put_object('colab_url.txt', url)
        print("成功将URL保存到OSS")
    except Exception as e:
        print("保存URL到OSS时发生错误:", str(e))

def maintain_ngrok_connection():
    while True:
        try:
            # 检查现有隧道
            tunnels = ngrok.get_tunnels()
            if not tunnels:
                print("Ngrok隧道断开，正在重新连接...")
                # 关闭所有现有隧道
                ngrok.kill()
                # 创建新隧道
                ngrok_tunnel = ngrok.connect(5000)
                # 更新OSS中的URL
                save_url_to_oss(ngrok_tunnel.public_url)
                print(f"新的URL: {ngrok_tunnel.public_url}")

            time.sleep(60)  # 每30秒检查一次

        except Exception as e:
            print(f"Ngrok维护出错: {str(e)}")
            time.sleep(10)

# 2. 启动维护线程
def start_ngrok_with_maintenance():
    try:
        # 先关闭现有隧道
        ngrok.kill()
        # 创建新隧道
        ngrok_tunnel = ngrok.connect(5000)
        print('Colab服务器URL:', ngrok_tunnel.public_url)
        save_url_to_oss(ngrok_tunnel.public_url)

        # 启动维护线程
        maintenance_thread = Thread(target=maintain_ngrok_connection, daemon=True)
        maintenance_thread.start()

    except Exception as e:
        print(f"启动Ngrok时出错: {str(e)}")

start_ngrok_with_maintenance()
app = Flask(__name__)
CORS(app)

pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
#         max_length=10000,
        max_new_tokens=500,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=3,    # 避免重复生成n-gram
        pad_token_id=tokenizer.eos_token_id,
        # 重要：禁用截断
        truncation=False,
        # 处理长文本
        return_full_text=True
    )
llm = HuggingFacePipeline(pipeline=pipe)
retriever = create_retriever('documents_for_analyse')
#不同base_retriever
    
@app.route('/qa', methods=['POST'])
def qa():
    try:
        chromadb.api.client.SharedSystemClient.clear_system_cache()
        data = request.get_json()  # 获取POST请求中的JSON 数据
        question = data.get('question')  # 提取'question'字段
        model_type = data.get('model_type', 'without RAG')
        print(model_type)
        print("提取问题成功")
        
        if model_type == 'RAG combined':
            result = generate_with_retrieval(
                question=question,
                retriever=retriever,  # 你之前创建的检索器
                model=model,
                tokenizer=tokenizer
            )
        else:
            result = generate(question=question, model=model, tokenizer=tokenizer)
        
        # print("有用回答："+helpful_answer_content)
        print("返回前")
        return jsonify({
            'status': 'success',
            'question': question,
            'answer': result
        })

    except Exception as e:
        print(str(e))
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500

@app.route('/update', methods=['POST'])
def update_knowledge_base():
    global retriever
    try:
        chromadb.api.client.SharedSystemClient.clear_system_cache()
         
        # 2. 删除旧的数据库目录
        persist_directory = "chroma_db"
        if os.path.exists(persist_directory):
            shutil.rmtree(persist_directory)
            print("已删除旧的数据库")
            
        delete_all_files()
        # 下载新文件
        download_all_files()
        print('文件下载成功')
        retriever = create_retriever('documents_for_analyse')
        return jsonify({
            'status': 'success',
            'message': '知识库更新成功'
        })
    except Exception as e:
        print(f"更新知识库时出错: {str(e)}")
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500


ngrok_tunnel = ngrok.connect(5000)
print('Colab服务器URL:', ngrok_tunnel.public_url)
save_url_to_oss(ngrok_tunnel.public_url)

# 启动Flask应用
app.run(host='0.0.0.0', port=5000)

Downloading colab_url.txt to documents_for_analyse/colab_url.txt
Downloading 病例样本/丁珊.txt to documents_for_analyse/病例样本/丁珊.txt
Downloading 病例样本/丁磊.txt to documents_for_analyse/病例样本/丁磊.txt
Downloading 病例样本/刘东风.txt to documents_for_analyse/病例样本/刘东风.txt
Downloading 病例样本/刘婷.txt to documents_for_analyse/病例样本/刘婷.txt
Downloading 病例样本/刘婷婷.txt to documents_for_analyse/病例样本/刘婷婷.txt
Downloading 病例样本/刘宇.txt to documents_for_analyse/病例样本/刘宇.txt
Downloading 病例样本/刘志.txt to documents_for_analyse/病例样本/刘志.txt
Downloading 病例样本/刘旭东.txt to documents_for_analyse/病例样本/刘旭东.txt
Downloading 病例样本/刘晓云.txt to documents_for_analyse/病例样本/刘晓云.txt
Downloading 病例样本/刘涛.txt to documents_for_analyse/病例样本/刘涛.txt
Downloading 病例样本/刘琳.txt to documents_for_analyse/病例样本/刘琳.txt
Downloading 病例样本/刘红.txt to documents_for_analyse/病例样本/刘红.txt
Downloading 病例样本/刘雨婷.txt to documents_for_analyse/病例样本/刘雨婷.txt
Downloading 病例样本/吴婷.txt to documents_for_analyse/病例样本/吴婷.txt
Downloading 病例样本/吴洁.txt to documents_for_analyse/病例样本/吴洁.txt
Downloadin

  docsearch.persist()
 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://10.64.100.92:5000
[33mPress CTRL+C to quit[0m
t=2024-12-03T18:46:04+0800 lvl=warn msg="failed to check for update" obj=updater err="Post \"https://update.equinox.io/check\": context deadline exceeded"
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


without RAG
提取问题成功
您提供的信息不足以确定是谁头痛。如果您能提供更多背景信息，我将能够更好地回答您的问题。<|im_end|>


127.0.0.1 - - [03/Dec/2024 18:46:36] "POST /qa HTTP/1.1" 200 -


返回前
RAG combined
提取问题成功


  retrieved_docs = retriever.get_relevant_documents(question)


检索成功
根据提供的信息，有头痛症状的病人是郑浩和王芳。

- 郑浩的主诉是头痛伴视力模糊2天。
- 王芳的主诉是头痛3天。<|im_end|>


127.0.0.1 - - [03/Dec/2024 18:47:16] "POST /qa HTTP/1.1" 200 -


返回前
