In [3]:
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import os

class vector_database():
    """Persistent semantic-search store backed by a FAISS flat L2 index.

    Each entry is a dict ``{'id', 'text', 'prompt'}``. The embedding of an
    entry's ``text`` lives in ``self.index`` at the same row as the entry's
    position in ``self.data``, so list position == FAISS row. Metadata is
    persisted as JSON at ``data_path`` and the index at ``index_path``; every
    mutation is saved immediately.
    """

    def __init__(self, index_path='vector_db_index.bin', data_path='vector_db_metadata.json'):
        self.index_path = index_path
        self.data_path = data_path
        self.model = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
        self.data = []  # list of dict: {id, text, prompt}; position == FAISS row
        self.id2idx = {}  # id -> position in self.data
        self.next_id = 1
        self.index = None
        self.load()

    def _encode(self, texts):
        """Embed an iterable of strings as a C-contiguous float32 (n, dim) array."""
        return np.ascontiguousarray(self.model.encode(list(texts)), dtype='float32')

    def _rebuild_index(self):
        """Re-embed every stored text and rebuild the FAISS index from scratch."""
        if not self.data:
            self.index = None
            return
        vecs = self._encode(item['text'] for item in self.data)
        self.index = faiss.IndexFlatL2(vecs.shape[1])
        self.index.add(vecs)

    def add_text(self, text, prompt):
        """Embed ``text``, store it together with ``prompt`` and persist.

        Returns:
            The integer id assigned to the new entry.
        """
        vecs = self._encode([text])
        if self.index is None:
            self.index = faiss.IndexFlatL2(vecs.shape[1])
        # Add the vector before touching metadata so a failure cannot leave
        # self.data one row ahead of the index.
        self.index.add(vecs)
        new_id = self.next_id
        self.next_id += 1
        self.data.append({'id': new_id, 'text': text, 'prompt': prompt})
        self.id2idx[new_id] = len(self.data) - 1
        self.save()
        return new_id

    def add_file(self, text_path, prompt_path):
        """Read text and prompt from two UTF-8 files and store them via add_text.

        Returns:
            The integer id assigned to the new entry.
        """
        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read()
        with open(prompt_path, 'r', encoding='utf-8') as f:
            prompt = f.read()
        return self.add_text(text, prompt)

    def search(self, query_text, k=3):
        """Return up to ``k`` stored entries nearest to ``query_text``.

        Each result is a copy of the stored entry plus ``'score'``, the
        distance reported by FAISS (IndexFlatL2 reports squared L2; smaller
        is closer). Returns [] when the store is empty or ``k`` < 1.
        """
        if self.index is None or not self.data:
            return []
        # Never ask for more neighbours than exist: FAISS would pad with -1.
        k = min(int(k), self.index.ntotal)
        if k < 1:
            return []
        distances, indices = self.index.search(self._encode([query_text]), k)
        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # -1 marks "no result"; guard the lower bound so it never wraps to data[-1].
            if 0 <= idx < len(self.data):
                item = self.data[idx].copy()
                item['score'] = float(dist)
                results.append(item)
        return results

    def delete(self, id):
        """Remove the entry with the given id and persist.

        Returns:
            True if the entry existed and was removed, False otherwise.
        """
        idx = self.id2idx.get(id)
        if idx is None:
            return False
        self.data.pop(idx)
        self.id2idx = {item['id']: i for i, item in enumerate(self.data)}
        if not self.data:
            self.index = None
        elif self.index is not None and self.index.ntotal == len(self.data) + 1:
            # Flat indexes compact remaining rows in order on removal, so the
            # index stays aligned with self.data without re-embedding anything.
            self.index.remove_ids(np.array([idx], dtype='int64'))
        else:
            # Index was already out of sync; rebuild it from the texts.
            self._rebuild_index()
        self.save()
        return True

    def save(self):
        """Persist metadata (JSON) and the FAISS index to disk."""
        # Write-then-rename so a crash mid-write cannot truncate the metadata file.
        tmp_path = self.data_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump({'data': self.data, 'next_id': self.next_id}, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, self.data_path)
        if self.index is not None:
            faiss.write_index(self.index, self.index_path)
        elif os.path.exists(self.index_path):
            # Store is empty: drop the stale index so it cannot be reloaded later.
            os.remove(self.index_path)

    def load(self):
        """Load metadata and index from disk.

        Accepts either ``{'data': [...], 'next_id': n}`` or a legacy bare list
        of entries. The index is rebuilt from the texts when its file is
        missing or its row count disagrees with the metadata.
        """
        self.data = []
        self.next_id = 1
        if os.path.exists(self.data_path):
            with open(self.data_path, 'r', encoding='utf-8') as f:
                obj = json.load(f)
            if isinstance(obj, dict):
                self.data = obj.get('data', [])
                self.next_id = obj.get('next_id', 1)
            elif isinstance(obj, list):
                self.data = obj
        # Never hand out an id that is already in use, even if next_id is stale.
        max_id = max((item.get('id', 0) for item in self.data), default=0)
        self.next_id = max(self.next_id, max_id + 1)
        self.id2idx = {item['id']: i for i, item in enumerate(self.data)}
        self.index = None
        if self.data:
            if os.path.exists(self.index_path):
                self.index = faiss.read_index(self.index_path)
            if self.index is None or self.index.ntotal != len(self.data):
                self._rebuild_index()


In [None]:
import gradio as gr

# Shared store used by every Gradio callback below; its constructor calls
# load(), so previously persisted entries are available immediately.
db = vector_database()

# Assumes the module-level `db` has already been initialized.
def search_func(query, k):
    """Search `db` for `query` and format up to `k` hits as numbered lines.

    `k` comes from a gr.Number, so it may be a float or None (cleared field).
    None or unparsable values fall back to 3 (the UI default), and the result
    is clamped to at least 1 before querying.
    """
    try:
        k = int(k)
    except (TypeError, ValueError, OverflowError):
        k = 3
    results = db.search(query, k=max(k, 1))
    if not results:
        return "未找到相关章节。"
    output = ["推荐的章节："]
    for i, item in enumerate(results, 1):
        output.append(f"{i}. {item['prompt']}，内容是{item['text']}")
    return "\n".join(output)

def insert_func(text, prompt):
    """Insert `text` with `prompt` into `db` and report the new id.

    Blank input is rejected instead of being embedded as a meaningless
    vector that would then show up in search results.
    """
    if not text or not text.strip():
        return "插入失败：文本内容不能为空"
    new_id = db.add_text(text, prompt)
    return f"插入成功，ID: {new_id}"

def delete_func(id_to_delete):
    """Parse `id_to_delete` as an integer id and delete that entry from `db`.

    Returns a user-facing status message: invalid input, success, or not found.
    """
    try:
        id_int = int(id_to_delete)
    except (TypeError, ValueError):
        # Only the conversion errors int() raises; anything else is a real bug.
        return "请输入有效的数字ID"
    ok = db.delete(id_int)
    return "删除成功" if ok else "未找到该ID"

# Three-tab UI (search / insert / delete). Each button's click handler wires
# one of the callbacks above to its input component(s) and output textbox.
# Components render in creation order, so statement order defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 向量数据库交互界面")
    # Search tab: query text + result count -> formatted result list.
    with gr.Tab("查询"):
        query = gr.Textbox(label="查询内容")
        # gr.Number may deliver a float; search_func converts it with int().
        k = gr.Number(label="返回条数", value=3)
        search_btn = gr.Button("查询")
        search_output = gr.Textbox(label="查询结果", lines=10)
        search_btn.click(search_func, [query, k], search_output)
    # Insert tab: text + prompt -> status message containing the new id.
    with gr.Tab("插入"):
        text = gr.Textbox(label="文本内容")
        prompt = gr.Textbox(label="Prompt", value="")
        insert_btn = gr.Button("插入")
        insert_output = gr.Textbox(label="插入结果")
        insert_btn.click(insert_func, [text, prompt], insert_output)
    # Delete tab: id typed as free text; delete_func validates it.
    with gr.Tab("删除"):
        del_id = gr.Textbox(label="要删除的ID")
        del_btn = gr.Button("删除")
        del_output = gr.Textbox(label="删除结果")
        del_btn.click(delete_func, del_id, del_output)

# Start the web server on a fixed local port.
demo.launch(server_port=9080)

* Running on local URL:  http://127.0.0.1:9080
* To create a public link, set `share=True` in `launch()`.
* To create a public link, set `share=True` in `launch()`.




: 

: 