In [1]:
from PIL import Image
import torch
import json
import numpy as np
import faiss
from transformers import CLIPProcessor, CLIPModel

# Load the CLIP model and its processor; both are read from the same local
# checkpoint directory, so the path is spelled out only once.
CLIP_CHECKPOINT_DIR = "/root/autodl-tmp/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT_DIR)
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT_DIR)

# Extract the CLIP embedding of an image
def get_image_embedding(image_path):
    """Return the CLIP image embedding for the image at ``image_path``.

    Args:
        image_path: Location of the image, passed straight to
            ``PIL.Image.open``.

    Returns:
        A NumPy array with the output of ``clip_model.get_image_features``
        for this single image.
    """
    # Open through a context manager so the underlying file handle is
    # released immediately instead of lingering until garbage collection;
    # ``convert`` yields an independent in-memory RGB copy that stays valid
    # after the file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    inputs = clip_processor(images=image, return_tensors="pt", padding=True)
    # Inference only: skip autograd bookkeeping so ``.numpy()`` is allowed.
    with torch.no_grad():
        image_embedding = clip_model.get_image_features(**inputs)
    return image_embedding.numpy()

# Extract the CLIP embedding of a text
def get_text_embedding(text):
    """Return the CLIP text embedding for ``text``.

    Args:
        text: A string (or list of strings) handed to the CLIP processor.

    Returns:
        A NumPy array with the output of ``clip_model.get_text_features``.
    """
    # CLIP's text encoder only accepts a bounded number of tokens (77 for the
    # standard checkpoints). Without truncation, a long "Q: ... A: ..." string
    # produces more tokens than the model has position embeddings for and the
    # forward pass fails, so clip the input to the tokenizer's maximum length.
    inputs = clip_processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    # Inference only: skip autograd bookkeeping so ``.numpy()`` is allowed.
    with torch.no_grad():
        text_embedding = clip_model.get_text_features(**inputs)
    return text_embedding.numpy()

# Load the image / question-answer records
def load_image_text_data(json_file):
    """Read ``json_file`` and return its ``image_text`` records.

    Every entry whose ``"type"`` equals ``"image_text"`` is turned into a
    dict with two keys: ``"image_path"`` (the entry's ``"image"`` value) and
    ``"text"`` (the question and answer joined as ``"Q: ... A: ..."``).
    Entries of any other type are skipped; input order is preserved.

    Args:
        json_file: Path of a UTF-8 JSON file containing a list of entries.

    Returns:
        A list of ``{"image_path": ..., "text": ...}`` dicts.
    """
    with open(json_file, "r", encoding="utf-8") as handle:
        entries = json.load(handle)
    return [
        {
            "image_path": entry["image"],
            "text": f"Q: {entry['question']} A: {entry['answer']}",
        }
        for entry in entries
        if entry["type"] == "image_text"
    ]

# Build the multimodal FAISS index
def build_multimodal_faiss_index(image_texts):
    """Build a flat L2 FAISS index over combined image + text embeddings.

    For each record the image embedding and the text embedding are
    concatenated along axis 1 into one row; all rows are stacked and added
    to a ``faiss.IndexFlatL2`` in the same order as ``image_texts``.

    Args:
        image_texts: Sequence of dicts with ``"image_path"`` and ``"text"``
            keys, as produced by ``load_image_text_data``.

    Returns:
        A ``(index, image_texts)`` tuple: the populated FAISS index and the
        unchanged input records.

    Raises:
        ValueError: If ``image_texts`` is empty, since there is nothing to
            index and the embedding dimension cannot be determined.
    """
    # Fail early with a clear message; otherwise ``np.vstack`` raises an
    # opaque "need at least one array" ValueError further down.
    if not image_texts:
        raise ValueError(
            "image_texts is empty; cannot build a FAISS index without records"
        )
    rows = [
        np.concatenate(
            [
                get_image_embedding(item["image_path"]),
                get_text_embedding(item["text"]),
            ],
            axis=1,
        )
        for item in image_texts
    ]
    # FAISS operates on float32 matrices, so cast once up front.
    embeddings = np.vstack(rows).astype("float32")
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index, image_texts

# Persist the multimodal retrieval system
def save_multimodal_retrieval_system(index, image_texts, index_file, texts_file):
    """Write the FAISS index and its companion records to disk.

    Args:
        index: FAISS index to serialize with ``faiss.write_index``.
        image_texts: JSON-serializable records stored alongside the index.
        index_file: Destination path for the FAISS index.
        texts_file: Destination path for the UTF-8 JSON records file.
    """
    faiss.write_index(index, index_file)
    # Keep non-ASCII characters readable in the output file.
    serialized = json.dumps(image_texts, ensure_ascii=False, indent=4)
    with open(texts_file, "w", encoding="utf-8") as out:
        out.write(serialized)

# Main entry point: build and save the multimodal retrieval system
def build_and_save_multimodal_retrieval_system(json_file, index_file, texts_file):
    """Load the records, build the FAISS index, and save both to disk.

    Args:
        json_file: Path of the source JSON file with the image/QA entries.
        index_file: Destination path for the FAISS index.
        texts_file: Destination path for the JSON records file.
    """
    records = load_image_text_data(json_file)
    faiss_index, records = build_multimodal_faiss_index(records)
    save_multimodal_retrieval_system(
        faiss_index,
        records,
        index_file,
        texts_file,
    )
    print("多模态检索系统构建并保存完成！")

# Example run; the three paths below are placeholders
image_text_data_file = "/path/to/your/image_text_data.json"
image_index_file = "/path/to/save/image_index.faiss"
image_texts_file = "/path/to/save/image_texts.json"
build_and_save_multimodal_retrieval_system(
    json_file=image_text_data_file,
    index_file=image_index_file,
    texts_file=image_texts_file,
)

FileNotFoundError: Failed to download file (open_clip_pytorch_model.bin) for laion/CLIP-ViT-H-14-laion2B-s32B-b79K. Last error: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.