In [None]:
import os
import time
import base64
import requests
import pandas as pd
import spacy

# ====== Configuration ======
# Gemini API key. Prefer supplying it through the GEMINI_API_KEY environment
# variable so the secret is not stored in the source file; the placeholder
# below is only used when that variable is not set (replace it with your own
# key if you do not want to use the environment variable).
API_KEY = os.environ.get("GEMINI_API_KEY", "YOUR Gemini API Key")
original_csv = r"C:\Users\Public\projects\601\final_data_sum.csv"
output_csv = r"C:\Users\Public\projects\601\final_data_sum_4.csv"
error_log = r"C:\Users\Public\projects\601\gemini_recovery_error_log.txt"
checkpoint_file = r"C:\Users\Public\projects\601\gemini_checkpoint.txt"
sleep_sec = 4.0  # pause after each processed row, in seconds
retry_sleep_sec = 7.0  # pause between failed request attempts, in seconds
save_interval = 100  # save the output CSV and checkpoint every this many rows
# Noun lemmas that are never used as tags.
exclude_words = {
    "stamp", "china", "postage", "prc", "中华人民共和国", "中国邮政", "邮票",
    "design", "element", "image", "border", "column", "text", "ink",
    "style", "section", "marking", "character", "*", "layout", "print",
    "color", "background", "value", "denomination", "information",
    "seal", "cloud", "line", "mark", "center", "overprint", "calligraphy",
    "place", "illustration", "motif", "lilac", "bottom", "feature", "depiction",
    "cancellation", "symbol", "font", "type", "rectangle","side", "top","position",
    "description","impression","feel",":*","manner","format","date","location"
}
# ===========================

# Load the spaCy "en_core_web_sm" pipeline once at module level; it is used by
# extract_top_n_nouns for part-of-speech tags and lemmas.
nlp = spacy.load("en_core_web_sm")

# Gemini request helper
def query_gemini_description(image_path, retries=3):
    """Ask Gemini to describe the stamp image stored at *image_path*.

    The image is sent inline (base64-encoded) to the ``gemini-1.5-flash``
    ``generateContent`` endpoint together with a fixed English prompt.

    Args:
        image_path: Path of the image file to send.
        retries: Maximum number of attempts before giving up.

    Returns:
        The text of the first part of the first candidate (``""`` when the
        response carries no candidate text), or the string
        ``"ERROR: 多次连接失败"`` when every attempt failed. Any ``Exception``
        raised during an attempt is caught, printed, and counted as a
        failed attempt rather than propagated.
    """
    import mimetypes  # local import: only this helper needs it

    url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent"
    # Image types listed as supported by the Gemini docs; any other guess
    # falls back to the original hard-coded "image/jpeg".
    supported_mime_types = {
        "image/png", "image/jpeg", "image/webp", "image/heic", "image/heif",
    }

    b64_img = None
    mime_type = "image/jpeg"
    for attempt in range(retries):
        try:
            # Read and encode the file only once; later attempts reuse it.
            if b64_img is None:
                guessed, _ = mimetypes.guess_type(image_path)
                if guessed in supported_mime_types:
                    mime_type = guessed
                with open(image_path, "rb") as img_file:
                    b64_img = base64.b64encode(img_file.read()).decode("utf-8")

            # Send the key as a header instead of a URL query parameter so it
            # cannot leak through exception messages, which include the URL.
            headers = {
                "Content-Type": "application/json",
                "x-goog-api-key": API_KEY,
            }
            payload = {
                "contents": [
                    {
                        "parts": [
                            {"text": "Describe the visual elements and symbols on the stamp."},
                            {"inline_data": {"mime_type": mime_type, "data": b64_img}}
                        ]
                    }
                ]
            }

            response = requests.post(url, headers=headers, json=payload, timeout=60)
            # Without this, an HTTP error (e.g. 429 rate limit, 5xx) would be
            # parsed as a normal body and silently produce "" with no retry.
            response.raise_for_status()
            result = response.json()
            return result.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "")
        except Exception as e:
            print(f"⚠️ 第 {attempt+1} 次失败: {e}")
            # No point in waiting after the final attempt.
            if attempt + 1 < retries:
                time.sleep(retry_sleep_sec)
    return "ERROR: 多次连接失败"

# Keyword extraction
def extract_top_n_nouns(text, exclude=exclude_words, top_k=3):
    """Return the most frequent noun lemmas of *text*.

    *text* is converted to ``str``, lower-cased and parsed with the
    module-level spaCy pipeline ``nlp``. Lemmas of tokens tagged ``NOUN``
    that are not in *exclude* are counted, and the ``top_k`` most frequent
    ones are returned as a list, padded with empty strings when fewer than
    ``top_k`` distinct lemmas were found.
    """
    parsed = nlp(str(text).lower())

    kept_lemmas = []
    for tok in parsed:
        if tok.pos_ != "NOUN":
            continue
        lemma = tok.lemma_
        if lemma in exclude:
            continue
        kept_lemmas.append(lemma)

    frequencies = pd.Series(kept_lemmas).value_counts()
    best = frequencies.head(top_k)
    padding = [""] * (top_k - len(best))
    return list(best.index) + padding

# Load the source table and the (possibly partially filled) output table.
df_base = pd.read_csv(original_csv, encoding="utf-8")
resuming = os.path.exists(output_csv)
if resuming:
    # utf-8-sig matches how this script saves the file; the codec also reads
    # files that have no BOM.
    df_out = pd.read_csv(output_csv, encoding="utf-8-sig")
else:
    # No output file yet: start from a copy of the source rows instead of
    # crashing with FileNotFoundError.
    df_out = df_base.copy()

# Make sure the result columns exist and can hold strings.
for col in ["description_gemini", "tag_1", "tag_2", "tag_3", "tags_all"]:
    if col not in df_out.columns:
        df_out[col] = ""
    # Empty cells are reloaded from CSV as NaN, which makes an all-empty
    # column float64; cast to object so string results can be assigned.
    df_out[col] = df_out[col].astype(object)

# Read the resume point. A checkpoint is only meaningful when the output
# file it refers to exists.
start_row = 0
if resuming and os.path.exists(checkpoint_file):
    with open(checkpoint_file, "r") as f:
        start_row = int(f.read().strip())


def _save_progress(next_row):
    """Write df_out to output_csv and store next_row as the resume point."""
    df_out.to_csv(output_csv, index=False, encoding="utf-8-sig")
    with open(checkpoint_file, "w") as f:
        f.write(str(next_row))
    print(f"💾 已保存进度至第 {next_row} 行")


# Process every remaining row.
total_rows = len(df_base)
rows_done = start_row  # number of rows finished so far == next row to process
saved_upto = start_row  # value of rows_done at the last save
try:
    for idx in range(start_row, total_rows):
        image_path = df_base.at[idx, "image_path"]
        print(f"\n⏳ 正在处理第 {idx+1}/{total_rows} 行...")

        made_request = False
        try:
            if not isinstance(image_path, str) or not os.path.isfile(image_path):
                raise ValueError("图像路径无效")

            made_request = True
            desc = query_gemini_description(image_path)
            # Stored before the error check, so a failed row keeps the
            # "ERROR..." marker in its description column.
            df_out.at[idx, "description_gemini"] = desc

            if desc.startswith("ERROR"):
                raise ValueError(desc)

            tags = extract_top_n_nouns(desc)
            df_out.at[idx, "tag_1"], df_out.at[idx, "tag_2"], df_out.at[idx, "tag_3"] = tags
            df_out.at[idx, "tags_all"] = ", ".join([t for t in tags if t])

        except Exception as e:
            # A failing row is logged and skipped; processing continues.
            print(f"❌ 第 {idx+1} 行失败: {e}")
            with open(error_log, "a", encoding="utf-8") as log:
                log.write(f"{idx}, {image_path}, {str(e)}\n")

        rows_done = idx + 1

        # Save every save_interval rows and after the last row.
        if rows_done % save_interval == 0 or rows_done == total_rows:
            _save_progress(rows_done)
            saved_upto = rows_done

        # Throttle only when an API request was actually attempted.
        if made_request:
            time.sleep(sleep_sec)
finally:
    # On interruption (Ctrl-C, unexpected crash) keep the rows finished since
    # the last periodic save instead of losing them.
    if rows_done > saved_upto:
        _save_progress(rows_done)

print("✅ 全部完成，所有数据已保存。")
