In [11]:
from typing import List, Union

def split_into_chunks(file_path: str, encoding: str = "utf-8", skip_blank: bool = True) -> List[str]:
    """Read a text file and split it into one chunk per line.

    Args:
        file_path: Path of the text file to read.
        encoding: Encoding used to decode the file. Defaults to UTF-8 so that
            Chinese text decodes identically on every platform (a bare ``open``
            uses the locale's preferred encoding instead).
        skip_blank: Drop lines that are empty or whitespace-only. Such lines
            carry no content but would otherwise be embedded, stored, and
            returned as retrieval hits. Pass ``False`` to keep every
            ``"\\n"``-separated piece, including a trailing empty one.

    Returns:
        The file's lines, without the ``"\\n"`` terminators, in file order.
    """
    with open(file_path, encoding=encoding) as file:
        pieces = file.read().split("\n")
    if not skip_blank:
        return pieces
    return [piece for piece in pieces if piece.strip()]

# Source document to index. NOTE(review): despite the .doc extension this file is
# opened in text mode by split_into_chunks — confirm it is plain text, not a binary Word file.
DATA_PATH = "data.doc"
chunks = split_into_chunks(DATA_PATH)

In [None]:
from sentence_transformers import SentenceTransformer
# Embedding model (going by its name, a Chinese text2vec model); `m` is used by
# emabed_chunk and search to embed both chunks and queries.
EMBEDDING_MODEL_NAME = "shibing624/text2vec-base-chinese"
m = SentenceTransformer(EMBEDDING_MODEL_NAME)

def emabed_chunk(chunk: str) -> List[float]:
    """Embed one text chunk with the module-level model ``m``.

    The (misspelled) name is kept because other cells call it.

    Returns:
        The embedding produced by ``m.encode``, converted with ``.tolist()``.
    """
    vector = m.encode(chunk)
    return vector.tolist()
# Encode all chunks in one batched call instead of one model call per chunk:
# SentenceTransformer.encode batches a list of strings internally and returns one
# row per input. NOTE(review): padding inside a batch may change the last float
# digits versus single-string encoding; retrieval should be unaffected.
embeddings = m.encode(chunks).tolist() if chunks else []


In [None]:
import chromadb

# In-memory Chroma client: nothing is written to disk, so the index must be rebuilt
# from the chunks whenever the kernel restarts.
COLLECTION_NAME = "default"
chromadb_client = chromadb.EphemeralClient()
chromadb_collection = chromadb_client.get_or_create_collection(name=COLLECTION_NAME)

def save_embeddings(embeddings: List[List[float]], chunks: List[str], batch_size: int = 1000) -> None:
    """Store chunk embeddings, with their source text, in ``chromadb_collection``.

    Chunk ``i`` is stored under id ``str(i)`` with metadata ``{"text": chunks[i]}``;
    ``re_rank`` later reads that ``"text"`` key from the search results.

    Args:
        embeddings: One embedding vector per chunk, aligned by index with ``chunks``.
        chunks: The source text of each chunk.
        batch_size: Maximum number of records sent per ``add`` call. Batching avoids
            one round-trip per chunk while staying below the store's batch limit.

    Raises:
        ValueError: If ``embeddings`` and ``chunks`` differ in length (checked before
            anything is written), or if ``batch_size`` is less than 1.
    """
    if len(embeddings) != len(chunks):
        raise ValueError(
            f"embeddings and chunks must have the same length "
            f"({len(embeddings)} != {len(chunks)})"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    total = len(chunks)
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        chromadb_collection.add(
            ids=[str(i) for i in range(start, stop)],
            embeddings=list(embeddings[start:stop]),
            metadatas=[{"text": text} for text in chunks[start:stop]],
        )
# Index every chunk (vector + text) into the collection created above.
save_embeddings(embeddings=embeddings, chunks=chunks)

In [None]:
def search(query: str, top_k: int = 5) -> dict:
    """Return the ``top_k`` stored chunks nearest to ``query``.

    The query is embedded with the same model ``m`` used for the stored chunks,
    and the vector is passed to ``chromadb_collection.query``.

    Args:
        query: Natural-language question to search for.
        top_k: Maximum number of nearest chunks to return; must be >= 1.

    Returns:
        The Chroma query result, unchanged. It is a mapping whose fields hold one
        list per query embedding. There is a single query here, so callers index
        ``[0]``, e.g. ``results["metadatas"][0]``.

    Raises:
        ValueError: If ``top_k`` is less than 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    query_embedding = m.encode(query).tolist()
    return chromadb_collection.query(
        query_embeddings=[query_embedding],
        n_results=top_k,
    )

In [None]:
from sentence_transformers import CrossEncoder
# Cross-encoder used by re_rank to score each (query, chunk text) pair jointly.
RERANKER_MODEL_NAME = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
cross_encoder = CrossEncoder(RERANKER_MODEL_NAME)

def re_rank(query: str, results: List[dict]) -> List[dict]:
    """Re-order retrieved chunks by cross-encoder relevance to ``query``.

    Args:
        query: The user question.
        results: Chunk metadata dicts from the vector search; each must carry a
            ``"text"`` key.

    Returns:
        New dicts, sorted from most to least relevant. Each is a copy of an input
        dict plus a ``"score"`` entry (a Python ``float``). The input dicts are
        not modified. An empty ``results`` returns ``[]`` without calling the
        model.
    """
    if not results:
        return []
    scores = cross_encoder.predict([(query, result["text"]) for result in results])
    # Build copies instead of writing "score" into the caller's dicts, so re-running
    # a rerank (or reusing the search results) sees untouched metadata.
    scored = [{**result, "score": float(score)} for result, score in zip(results, scores)]
    return sorted(scored, key=lambda item: item["score"], reverse=True)


In [None]:
# Send the retrieved chunk texts together with the question to Gemini (google-genai) to generate an answer.
from dotenv import load_dotenv
from google import genai


# Load variables from a .env file (if one is found) into the process environment.
# This must run *before* the client is constructed, which is expected to pick up its
# API key from the environment.
# NOTE(review): presumably GEMINI_API_KEY / GOOGLE_API_KEY — confirm against the google-genai docs.
load_dotenv()
genai_client = genai.Client()


def generate_answer(
    question: str,
    chunks: List[Union[str, dict]],
    model: str = "gemini-2.5-flash",
) -> str:
    """Ask Gemini to answer ``question`` using the retrieved ``chunks`` as context.

    Args:
        question: The user question (the prompt template is in Chinese).
        chunks: Retrieved context, most relevant first. Strings are used as-is.
            Dicts (e.g. ``re_rank`` output) contribute only their ``"text"``
            value, so callers passing re-ranked search hits still work.
        model: Gemini model name passed to ``generate_content``.

    Returns:
        ``response.text`` from the model.
    """
    # Put one chunk per line. Interpolating the list object itself would inject its
    # Python repr (brackets, quotes, and for dicts the keys and scores) into the prompt.
    context = "\n".join(
        chunk["text"] if isinstance(chunk, dict) else str(chunk) for chunk in chunks
    )
    # Prompt (zh): "Answer the question based on the context below; answer only with
    # data relevant to the question." / "Context:" / "Question:"
    prompt = f"根据以下上下文回答问题, 只回答问题相关数据。\n\n上下文:\n{context}\n\n问题: {question}"
    response = genai_client.models.generate_content(
        model=model,
        contents=prompt,
    )
    return response.text


In [None]:
# End-to-end run: vector search -> cross-encoder rerank -> Gemini answer.
query = "易拉罐里有什么"
embeddingResults = search(query)
# Single query, so take the first (only) list of hit metadata.
re_ranked_results = re_rank(query, embeddingResults["metadatas"][0])
# generate_answer takes the context as strings; pass the chunk texts rather than the
# scored metadata dicts, so no dict reprs or scores end up in the prompt.
context_texts = [result["text"] for result in re_ranked_results]
answer = generate_answer(query, context_texts)
print(answer)