In [1]:
%pip install -q "openai>=1.30.0" faiss-cpu PyMuPDF python-dotenv numpy


Note: you may need to restart the kernel to use updated packages.


In [None]:
import getpass
import os, numpy as np, faiss, fitz, pickle, textwrap
from openai import OpenAI

# ====== 1) OpenAI API key ======
# Option B: if the key lives in a .env file, uncomment the two lines below.
# They must run BEFORE the key is read, because load_dotenv() does not
# override variables that are already set in the environment.
# from dotenv import load_dotenv
# load_dotenv()

# Option A: export OPENAI_API_KEY before starting Jupyter. If it is not set,
# prompt for it instead of hard-coding it, so the key never ends up in the
# saved notebook (or a submitted assignment).
openai_api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OpenAI API key: ")
if not openai_api_key:
    raise RuntimeError("OPENAI_API_KEY is not set and no key was entered.")
os.environ["OPENAI_API_KEY"] = openai_api_key

client = OpenAI(api_key=openai_api_key)

# ====== 2) Model names (swap for the ones your course requires) ======
MODEL_EMBED = "text-embedding-3-small"  # or "text-embedding-3-large" / "text-embedding-ada-002"
MODEL_CHAT  = "gpt-4o-mini"             # or "gpt-3.5-turbo"

# ====== 3) Input PDF ======
# Defaults to the upload location of the original session. To run locally,
# set TENANCY_PDF_PATH (e.g. to "tenancy_agreement.pdf" placed next to the
# notebook) or edit the default below.
pdf_path = os.environ.get("TENANCY_PDF_PATH", "/mnt/data/Track_B_Tenancy_Agreement.pdf")

if not os.path.exists(pdf_path):
    raise FileNotFoundError(f"找不到PDF：{pdf_path}")

# ====== 4) Where the FAISS index and the chunk list are persisted ======
INDEX_PATH  = "tenancy_index.faiss"
CHUNKS_PATH = "tenancy_chunks.pkl"


In [None]:
# -------- Read the PDF --------
# Use the document as a context manager so it is closed even if text
# extraction fails part-way through.
with fitz.open(pdf_path) as doc:
    pages_text = [page.get_text() for page in doc]
full_text = "\n".join(pages_text)

# -------- Chunking: pack text lines into chunks of at most max_len characters --------
def chunk_text(text, max_len=1000):
    """Split ``text`` into chunks of at most ``max_len`` characters.

    The text is split on newlines. Each non-blank line is stripped and treated
    as one unit. Consecutive units are joined with a single space until adding
    the next one would exceed ``max_len``. A single line longer than
    ``max_len`` is first broken on whitespace (mid-word only if unavoidable),
    so no chunk exceeds the limit.

    Args:
        text: Raw text, e.g. the concatenated page text of a PDF.
        max_len: Maximum chunk length in characters; must be positive.

    Returns:
        List of non-empty chunk strings in document order (empty list if
        ``text`` has no non-blank lines).

    Raises:
        ValueError: if ``max_len`` is not positive.
    """
    if max_len <= 0:
        raise ValueError(f"max_len must be positive, got {max_len}")

    pieces = []
    for line in text.split("\n"):
        line = line.strip()
        if not line:
            continue
        if len(line) <= max_len:
            pieces.append(line)
        else:
            # Oversized line: break it up so it cannot produce a chunk longer
            # than max_len. Keep tabs and other whitespace characters as they are.
            pieces.extend(
                textwrap.wrap(
                    line,
                    width=max_len,
                    expand_tabs=False,
                    replace_whitespace=False,
                    break_on_hyphens=False,
                )
            )

    chunks, cur = [], ""
    for piece in pieces:
        # +1 accounts for the joining space.
        if cur and len(cur) + 1 + len(piece) > max_len:
            chunks.append(cur)
            cur = piece
        else:
            cur = piece if not cur else f"{cur} {piece}"
    if cur:
        chunks.append(cur)
    return chunks

chunks = chunk_text(full_text, max_len=1000)
if not chunks:
    # No text was extracted (presumably an image-only / scanned PDF without a
    # text layer). Stop here with a clear message rather than failing later
    # with an opaque shape error while building the FAISS index.
    raise ValueError(f"No extractable text found in {pdf_path}; is it a scanned PDF?")
print(f"Total chunks: {len(chunks)}")

# -------- Batched embedding --------
def embed_texts(texts, model=MODEL_EMBED, batch_size=32):
    """Embed ``texts`` with the OpenAI embeddings API, ``batch_size`` per request.

    Args:
        texts: Sequence of non-empty strings to embed.
        model: Embedding model name (defaults to ``MODEL_EMBED``).
        batch_size: Number of strings sent per API request; must be >= 1.

    Returns:
        float32 numpy array with one row per input, in the same order as
        ``texts``. An empty ``texts`` yields an array of shape ``(0, 0)`` so the
        result is always 2-D.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        # A step of 0 makes range() raise, and a negative step makes it yield
        # nothing, which would silently return no embeddings at all.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    vecs = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        resp = client.embeddings.create(model=model, input=batch)
        # Each returned item carries the position of its input in the batch.
        # Order by it instead of relying on the order of the response list.
        vecs.extend(d.embedding for d in sorted(resp.data, key=lambda d: d.index))
    if not vecs:
        return np.empty((0, 0), dtype=np.float32)
    return np.array(vecs, dtype=np.float32)

emb = embed_texts(chunks, model=MODEL_EMBED, batch_size=16)
print(f"Embedding shape: {emb.shape}")

# -------- L2-normalise + build a FAISS cosine-similarity index --------
# Once every row has unit length, inner product equals cosine similarity,
# so a flat inner-product index performs cosine search.
faiss.normalize_L2(emb)  # modifies `emb` in place
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)
print(f"FAISS ntotal: {index.ntotal}")

# -------- Persist the index and the chunk texts --------
faiss.write_index(index, INDEX_PATH)
with open(CHUNKS_PATH, mode="wb") as f:
    pickle.dump(chunks, f)
print(f"Saved: {INDEX_PATH} {CHUNKS_PATH}")


In [None]:
# Reload the FAISS index and chunk texts persisted at INDEX_PATH / CHUNKS_PATH.
index = faiss.read_index(INDEX_PATH)
# NOTE: pickle.load can execute arbitrary code; only load a chunks file that
# this notebook wrote itself.
with open(CHUNKS_PATH, "rb") as f:
    DOC_CHUNKS = pickle.load(f)
# Row i of the index must correspond to DOC_CHUNKS[i]. A size mismatch means
# the two files come from different builds, and retrieval would return the
# wrong text.
if index.ntotal != len(DOC_CHUNKS):
    raise ValueError(
        f"Index/chunk mismatch: {index.ntotal} vectors vs {len(DOC_CHUNKS)} chunks; "
        "rebuild the index."
    )

def retrieve(query, top_k=5):
    """Return the tenancy-agreement chunks most similar to ``query``.

    Embeds ``query`` with ``MODEL_EMBED``, L2-normalises it, and searches the
    global FAISS inner-product ``index``. Scores are cosine similarities
    assuming the indexed vectors were L2-normalised when the index was built.

    Args:
        query: Natural-language question.
        top_k: Maximum number of chunks to return. It is clamped to the number
            of indexed chunks; values <= 0 return nothing.

    Returns:
        (ctxs, scores): list of chunk strings from ``DOC_CHUNKS`` (best match
        first) and a 1-D float32 numpy array of their scores (same length).
    """
    # FAISS pads results with id -1 when k exceeds the number of stored
    # vectors, so never ask for more neighbours than exist.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return [], np.empty(0, dtype=np.float32)

    # Embed the query as a (1, dim) float32 matrix, the shape FAISS expects.
    q_emb = np.array(
        client.embeddings.create(model=MODEL_EMBED, input=[query]).data[0].embedding,
        dtype=np.float32,
    )[None, :]
    faiss.normalize_L2(q_emb)  # in place; makes inner product == cosine similarity

    D, I = index.search(q_emb, k)
    # Defensive: drop any -1 padding ids. In Python, a negative index would
    # silently select the LAST chunk.
    keep = I[0] >= 0
    ctxs = [DOC_CHUNKS[i] for i in I[0][keep]]
    return ctxs, D[0][keep]

def answer_query(query, top_k=5, temperature=0.2):
    """Answer ``query`` using only the retrieved contract excerpts.

    Retrieves up to ``top_k`` chunks via ``retrieve`` and joins them with
    ``---`` separators. It then asks ``MODEL_CHAT`` to answer strictly from
    those excerpts.

    Args:
        query: Natural-language question about the tenancy agreement.
        top_k: Number of chunks to retrieve as context.
        temperature: Sampling temperature for the chat model.

    Returns:
        (answer, ctxs, scores): the stripped model answer (empty string if the
        response carried no text content), the retrieved chunk strings, and
        their similarity scores as returned by ``retrieve``.
    """
    ctxs, scores = retrieve(query, top_k=top_k)
    context = "\n\n---\n\n".join(ctxs)

    system_msg = (
        "You are a contract-aware assistant for a tenancy agreement. "
        "Answer ONLY using the provided excerpts. If unsure or missing, say you don't know."
    )
    user_msg = f"Contract excerpts:\n{context}\n\nQuestion: {query}\nAnswer:"

    resp = client.chat.completions.create(
        model=MODEL_CHAT,
        temperature=temperature,
        messages=[
            {"role": "system", "content": system_msg},
            {"role": "user",   "content": user_msg},
        ],
    )
    # message.content is typed Optional[str] in the openai SDK, so guard
    # against None before calling .strip().
    content = resp.choices[0].message.content
    return (content or "").strip(), ctxs, scores

# Quick smoke test on a single question
ans, ctxs, scores = answer_query(query="What's the diplomatic clause?", top_k=6)
print(ans)


In [None]:
questions = [
    "What's the diplomatic clause?",
    "When things are spoiled/broken, who pays to repair?",
    "What to do before returning the unit?",
]

# Answer each question and print it, followed by an 80-character separator line.
for i, q in enumerate(questions, start=1):
    ans, ctxs, scores = answer_query(q, top_k=6)
    print(f"Q{i}: {q}\nA: {ans}\n", "=" * 80, sep="\n")
