In [None]:
from typing import List

def split_into_chunks(doc_file: str, encoding: str = "utf-8") -> List[str]:
    """Split a text/Markdown document into paragraph chunks.

    Paragraphs are separated by a blank line (two consecutive newlines).
    Each chunk is stripped of surrounding whitespace, and chunks that are
    empty after stripping are dropped. Such chunks appear from runs of blank
    lines or a trailing newline, and would otherwise be embedded and indexed
    as meaningless entries.

    Args:
        doc_file: Path of the document to read.
        encoding: Text encoding of the file. Defaults to UTF-8 so the
            (Chinese) source decodes identically on every platform instead
            of depending on the locale's default encoding.

    Returns:
        The non-empty paragraph chunks, in document order.
    """
    with open(doc_file, "r", encoding=encoding) as file:
        content = file.read()
    # Text mode normalises "\r\n" to "\n", so this also splits Windows files.
    return [chunk.strip() for chunk in content.split("\n\n") if chunk.strip()]

chunks = split_into_chunks("doc.md")

# Show every chunk next to its index so paragraph boundaries are easy to inspect.
for idx, text in enumerate(chunks):
    print(f"[{idx}] {text}\n")

In [None]:
from sentence_transformers import SentenceTransformer

# Chinese sentence-embedding model, loaded once at module level and reused by every call below.
embedding_model = SentenceTransformer("shibing624/text2vec-base-chinese")
def embed_chunk(chunk: str) -> List[float]:
    """Embed one piece of text and return the vector as a plain list of floats.

    The list form (rather than the array returned by ``encode``) is what this
    notebook passes to ChromaDB as ``embeddings`` / ``query_embeddings``.
    """
    return embedding_model.encode(chunk).tolist()

# Sanity check: embed a short sample, then show the vector's dimension and values.
test_embedding = embed_chunk("测试内容")
print(len(test_embedding), test_embedding, sep="\n")

In [None]:
# One embedding per chunk, in the same order as `chunks`.
embeddings = list(map(embed_chunk, chunks))
print(len(embeddings))
print(embeddings[0])

In [None]:
import chromadb

# In-memory ChromaDB client: nothing is persisted to disk.
chromadb_client = chromadb.EphemeralClient()
chromadb_collection = chromadb_client.get_or_create_collection(
    name="default",
)

def save_embeddings(chunks: List[str], embeddings: List[List[float]]) -> None:
    """Store chunks and their embeddings in the module-level Chroma collection.

    Chunk ``i`` is stored under the ID ``str(i)``. ``upsert`` is used rather
    than ``add`` so that re-running this cell against a collection that
    already holds these IDs overwrites the records instead of colliding with
    them.

    Args:
        chunks: Document texts, stored as Chroma ``documents``.
        embeddings: One vector per chunk, in the same order as ``chunks``.

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length.
    """
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"got {len(chunks)} chunks but {len(embeddings)} embeddings"
        )
    if not chunks:
        # Nothing to store; Chroma rejects an empty ID list.
        return
    ids = [str(i) for i in range(len(chunks))]
    chromadb_collection.upsert(
        ids=ids,
        embeddings=embeddings,
        documents=chunks,
    )

# Index every chunk together with its precomputed embedding.
save_embeddings(chunks=chunks, embeddings=embeddings)

In [None]:
def retrieve(query: str, top_k: int) -> List[str]:
    """Return the ``top_k`` stored chunks closest to ``query`` in embedding space.

    Chroma's ``query`` is batched (one result row per query embedding). Only a
    single embedding is sent here, so only the first row of ``documents`` is
    returned.
    """
    hits = chromadb_collection.query(
        query_embeddings=[embed_chunk(query)],
        n_results=top_k,
    )
    first_query_documents = hits['documents'][0]
    return first_query_documents

query = "哆啦A梦使用的3个秘密道具分别是什么？"
# Coarse recall: pull 5 candidates by vector similarity before reranking.
retrieved_chunks = retrieve(query, top_k=5)

for rank, passage in enumerate(retrieved_chunks):
    print(f"[{rank}] {passage}\n")

In [None]:
from sentence_transformers import CrossEncoder

from functools import lru_cache

DEFAULT_RERANK_MODEL = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"


@lru_cache(maxsize=None)
def _load_cross_encoder(model_name: str):
    """Load a CrossEncoder once per model name and cache it for later rerank calls."""
    return CrossEncoder(model_name)


def rerank(query: str, retrieved_chunks: List[str], top_k: int,
           model_name: str = DEFAULT_RERANK_MODEL) -> List[str]:
    """Reorder retrieved chunks by cross-encoder relevance and keep the best ``top_k``.

    Each ``(query, chunk)`` pair is scored by the cross-encoder. Chunks are
    returned in descending score order; ties keep their retrieval order
    because ``sorted`` is stable. The model is loaded only on the first call
    for a given ``model_name`` rather than on every call.

    Args:
        query: The user question.
        retrieved_chunks: Candidate chunks, e.g. the output of ``retrieve``.
        top_k: Number of chunks to keep.
        model_name: Cross-encoder checkpoint used for scoring.

    Returns:
        Up to ``top_k`` chunks, most relevant first; ``[]`` when there are no
        candidates (the model is not loaded in that case).
    """
    if not retrieved_chunks:
        return []

    cross_encoder = _load_cross_encoder(model_name)
    scores = cross_encoder.predict([(query, chunk) for chunk in retrieved_chunks])

    ranked = sorted(
        zip(retrieved_chunks, scores),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return [chunk for chunk, _ in ranked[:top_k]]

# Fine-grained ranking: keep the 3 candidates the cross-encoder scores highest.
reranked_chunks = rerank(query, retrieved_chunks, top_k=3)

for position, passage in enumerate(reranked_chunks):
    print(f"[{position}] {passage}\n")

In [None]:
from dotenv import load_dotenv
import os

# Load environment variables (the API keys read below) from a .env file.
load_dotenv()


def build_prompt(query: str, chunks: List[str]) -> str:
    """Build the grounded-answer prompt from a question and its context chunks.

    The chunks are joined before formatting because a backslash inside an
    f-string replacement field is a SyntaxError before Python 3.12.

    Args:
        query: The user question.
        chunks: Context passages to ground the answer in.

    Returns:
        The full prompt text.
    """
    context = "\n\n".join(chunks)
    return (
        "你是一位知识助手，请根据用户的问题和下列片段生成准确的回答。\n\n"
        f"用户问题：{query}\n\n"
        "相关片段：\n"
        f"{context}\n\n"
        "请基于上述内容作答，不要编造信息。"
    )


# Ground the answer in the reranked chunks only, not the whole document (`chunks`).
prompt = build_prompt(query, reranked_chunks)

print(f"{prompt}\n\n---\n")

# Google Gemini API
from google import genai
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
google_client = genai.Client(api_key=GEMINI_API_KEY)


def generate_gemini_response(query: str, chunks: List[str],
                             model: str = "gemini-2.5-flash") -> str:
    """Answer ``query`` with Gemini, grounded only in the given ``chunks``.

    The prompt is built here from the arguments rather than taken from the
    module-level ``prompt``. That way the function answers exactly the
    question and context it is called with.

    Args:
        query: The user question.
        chunks: Context passages (e.g. the reranked chunks).
        model: Gemini model name.

    Returns:
        ``response.text`` from the Gemini API.
    """
    context = "\n\n".join(chunks)
    rag_prompt = (
        "你是一位知识助手，请根据用户的问题和下列片段生成准确的回答。\n\n"
        f"用户问题：{query}\n\n"
        "相关片段：\n"
        f"{context}\n\n"
        "请基于上述内容作答，不要编造信息。"
    )
    response = google_client.models.generate_content(
        model=model,
        contents=rag_prompt,
    )

    return response.text

gemini_answer = generate_gemini_response(query, reranked_chunks)
# Header line followed by the answer, each on its own line.
print("Google Gemini Answer:", gemini_answer, sep="\n")

# DeepSeek API
from openai import OpenAI

DEEPSEEK_API_KEY = os.getenv("DEEPSEEK_API_KEY")
# DeepSeek exposes an OpenAI-compatible API, so the OpenAI SDK is pointed at its base URL.
deepseek_client = OpenAI(
    api_key=DEEPSEEK_API_KEY,
    base_url="https://api.deepseek.com"
)


def generate_deepseek_response(query: str, chunks: List[str],
                               model: str = "deepseek-chat") -> str:
    """Answer ``query`` with DeepSeek, grounded only in the given ``chunks``.

    The grounded prompt (instructions + context) is built here from the
    arguments rather than taken from the module-level ``prompt``. It is sent
    as the system message, and the raw question is sent as the user message.

    Args:
        query: The user question.
        chunks: Context passages (e.g. the reranked chunks).
        model: DeepSeek model name, e.g. "deepseek-chat" or "deepseek-reasoner".

    Returns:
        The content of the first completion choice.
    """
    context = "\n\n".join(chunks)
    rag_prompt = (
        "你是一位知识助手，请根据用户的问题和下列片段生成准确的回答。\n\n"
        f"用户问题：{query}\n\n"
        "相关片段：\n"
        f"{context}\n\n"
        "请基于上述内容作答，不要编造信息。"
    )
    response = deepseek_client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": rag_prompt},
            {"role": "user", "content": query},
        ],
        stream=False,
    )

    return response.choices[0].message.content

print("\n")

deepseek_answer = generate_deepseek_response(query, reranked_chunks)
# Header line followed by the answer, each on its own line.
print("DeepSeek Answer:", deepseek_answer, sep="\n")
