In [1]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import whisper
import faiss
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import CLIPModel, CLIPProcessor

# Device selection: run on the GPU when PyTorch reports CUDA is available,
# otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the GLM-4V-9B vision-language model and its tokenizer from a local
# checkpoint directory. trust_remote_code=True lets transformers execute the
# model/tokenizer code bundled with that checkpoint.
tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/glm-4v-9b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "/root/autodl-tmp/glm-4v-9b",
    # 2-byte weights. NOTE(review): on the CPU fallback bfloat16 may be slow or
    # unsupported for some ops — confirm if CPU execution matters.
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,  # reduce peak host RAM while the weights are loaded
    trust_remote_code=True
).to(device).eval()  # inference only: eval() disables training-time behavior such as dropout

# Load the "base"-size Whisper speech-recognition model used by transcribe_audio().
whisper_model = whisper.load_model("base")

# Load the sentence-embedding model that retrieve_texts() uses to embed queries.
# NOTE(review): the text FAISS indices are presumably built with this same
# model — confirm, otherwise query and index vectors are incomparable.
text_embedding_model = SentenceTransformer("/root/autodl-tmp/all-MiniLM-L6-v2")

# Load CLIP (model + processor); retrieve_multimodal_data() uses it to embed
# the query text for the image/text index.
clip_model = CLIPModel.from_pretrained("/root/autodl-tmp/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("/root/autodl-tmp/clip-vit-base-patch32")

# 加载文本检索系统
def load_text_retrieval_system(index_file, texts_file):
    """
    Load one text database: a FAISS index plus the texts it was built over.

    Args:
        index_file: Path to a serialized FAISS index (read with faiss.read_index).
        texts_file: Path to a UTF-8 encoded JSON file with the indexed texts.

    Returns:
        A ``(index, texts)`` tuple: the FAISS index and the decoded JSON payload.
    """
    faiss_index = faiss.read_index(index_file)
    with open(texts_file, encoding="utf-8") as handle:
        payload = json.load(handle)
    return faiss_index, payload

# 加载多模态检索系统
def load_multimodal_retrieval_system(index_file, texts_file):
    """
    Load the multimodal database: a FAISS index plus its image/text records.

    On-disk handling is identical to a text database, so this delegates to
    ``load_text_retrieval_system``.

    Args:
        index_file: Path to a serialized FAISS index (read with faiss.read_index).
        texts_file: Path to a UTF-8 encoded JSON file with the image/text records.

    Returns:
        A ``(index, image_texts)`` tuple: the FAISS index and the decoded records.
    """
    return load_text_retrieval_system(index_file, texts_file)

# 检索文本（支持多个数据库）
def retrieve_texts(query, indices, texts_list, k=3, embedding_model=None):
    """
    Retrieve the texts most relevant to ``query`` from several FAISS databases.

    Args:
        query: The user question to embed and search with.
        indices: Sequence of FAISS indices, one per database.
        texts_list: Sequence of text lists parallel to ``indices``; position i
            of a text list is assumed to hold the text for vector id i.
        k: Number of neighbours requested from each index.
        embedding_model: Object with a SentenceTransformer-style ``encode``
            method. Defaults to the module-level ``text_embedding_model``.

    Returns:
        A flat list with up to ``k`` texts per database, in database order and,
        within a database, in the order FAISS ranked them.
    """
    model = text_embedding_model if embedding_model is None else embedding_model
    # The query embedding is the same for every database, so compute it once.
    # FAISS expects a C-contiguous float32 matrix.
    query_embedding = np.ascontiguousarray(
        model.encode([query], convert_to_tensor=False), dtype="float32"
    )

    all_retrieved_texts = []
    for index, texts in zip(indices, texts_list):
        _, neighbor_ids = index.search(query_embedding, k)
        # FAISS pads missing results with -1 when the index holds fewer than k
        # vectors; a negative id would silently pick texts from the end of the list.
        all_retrieved_texts.extend(texts[i] for i in neighbor_ids[0] if i >= 0)
    return all_retrieved_texts

# 分块处理长文本
def chunk_text(text, max_length=77):
    """
    Split ``text`` on whitespace into chunks of at most ``max_length`` characters.

    Words are never broken: a single word longer than ``max_length`` becomes a
    chunk of its own. Empty chunks are never produced. Text without whitespace
    (e.g. Chinese) therefore comes back as one chunk.

    NOTE(review): 77 presumably mirrors CLIP's 77-token text limit, but this
    function counts characters, not tokens; the CLIP processor's truncation
    still guards the token limit.

    Args:
        text: Input string.
        max_length: Maximum chunk length in characters, including separating spaces.

    Returns:
        List of chunk strings, each being consecutive words joined by single spaces.
    """
    chunks = []
    current_chunk = []
    current_length = 0  # always equals len(" ".join(current_chunk))

    for word in text.split():
        # Joining adds one separating space unless the chunk is still empty.
        candidate_length = current_length + len(word) + (1 if current_chunk else 0)
        if candidate_length <= max_length or not current_chunk:
            # Fits, or the chunk is empty (an over-long word still needs a home
            # and must not leave an empty chunk behind).
            current_chunk.append(word)
            current_length = candidate_length
        else:
            chunks.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(word)

    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks

# 检索多模态数据（支持分块处理）
def retrieve_multimodal_data(query, index, image_texts, k=3):
    """
    Retrieve the image/text records most relevant to ``query``.

    The query is split into word chunks (``chunk_text``), each chunk is embedded
    with the CLIP text encoder, and the chunk embeddings are averaged. The
    averaged text vector is appended to a block of zeros standing in for the
    (absent) image half of the query, and the result is searched in ``index``.

    NOTE(review): this assumes the index stores vectors laid out as
    ``[image_embedding, text_embedding]``, each of CLIP's projection size, and
    that they were not L2-normalised (the query is not) — confirm against the
    script that built the index.

    Args:
        query: The user question.
        index: FAISS index over the multimodal vectors.
        image_texts: Records parallel to the index (``generate_description``
            reads their "text" field).
        k: Number of neighbours requested.

    Returns:
        Up to ``k`` entries of ``image_texts`` in the order FAISS ranked them.
    """
    # CLIP's shared embedding size (512 for clip-vit-base-patch32); used for
    # both the zero-filled image half and the empty-query fallback.
    embed_dim = clip_model.config.projection_dim

    chunk_embeddings = []
    with torch.no_grad():
        for chunk in chunk_text(query):
            inputs = clip_processor(
                text=chunk, return_tensors="pt", padding=True, truncation=True
            ).to(device)
            chunk_embeddings.append(clip_model.get_text_features(**inputs).cpu().numpy())

    # Average the per-chunk embeddings; each has shape (1, embed_dim).
    if chunk_embeddings:
        query_text_embedding = np.mean(chunk_embeddings, axis=0)
    else:
        query_text_embedding = np.zeros((1, embed_dim))  # no chunks -> zero vector

    # The image half of the query vector is zero-filled.
    query_embedding = np.concatenate([np.zeros((1, embed_dim)), query_text_embedding], axis=1)
    _, neighbor_ids = index.search(np.ascontiguousarray(query_embedding, dtype="float32"), k)
    # FAISS pads missing results with -1; a negative id would silently return
    # records from the end of the list.
    return [image_texts[i] for i in neighbor_ids[0] if i >= 0]

# 语音转文本
def transcribe_audio(audio_path):
    """
    Convert a recorded audio file to text with the Whisper model.

    Args:
        audio_path: Filesystem path of the recording; any falsy value means
            "no audio".

    Returns:
        The transcription's "text" field, "" when no audio was given, or a
        "语音识别失败: <error>" message if transcription raises.
    """
    if audio_path:
        try:
            return whisper_model.transcribe(audio_path)["text"]
        except Exception as err:  # best effort: report the failure in the UI instead of raising
            return f"语音识别失败: {err!s}"
    return ""

# 生成描述（集成 RAG）
def _format_rag_context(retrieved_texts, retrieved_images):
    """
    Render retrieval results as the context section placed before the question.

    Args:
        retrieved_texts: Text snippets returned by ``retrieve_texts``.
        retrieved_images: Records returned by ``retrieve_multimodal_data``;
            only their "text" field is used.

    Returns:
        A string with a text-results section followed by an image-results section.
    """
    text_lines = "\n".join(retrieved_texts)
    image_lines = "\n".join(record["text"] for record in retrieved_images)
    return "检索到的文本信息：\n" + text_lines + "\n\n检索到的图片信息：\n" + image_lines


def generate_description(image, query, text_indices, text_texts_list, image_index, image_texts):
    """
    Answer ``query`` with GLM-4V, grounding it in retrieved context (RAG).

    Args:
        image: Optional PIL image; converted to RGB and attached to the prompt when given.
        query: The user question.
        text_indices: FAISS indices passed to ``retrieve_texts``.
        text_texts_list: Text lists parallel to ``text_indices``.
        image_index: FAISS index passed to ``retrieve_multimodal_data``.
        image_texts: Records parallel to ``image_index``.

    Returns:
        The decoded model answer (prompt tokens removed), or a Chinese error
        message when ``query`` is blank.
    """
    if not query.strip():
        return "错误：请输入文本或语音输入问题。"

    # Text databases are searched first, then the multimodal one.
    context = _format_rag_context(
        retrieve_texts(query, text_indices, text_texts_list),
        retrieve_multimodal_data(query, image_index, image_texts),
    )
    prompt = f"{context}\n\n用户问题：{query}"

    # One user turn; the "image" key is present only when an image was supplied.
    message = {"role": "user"}
    if image is not None:
        message["image"] = image.convert('RGB')
    message["content"] = prompt

    inputs = tokenizer.apply_chat_template(
        [message],
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)

    # do_sample=True with top_k=1 keeps only the most likely token at every
    # step, so decoding is effectively greedy.
    generation_options = dict(max_new_tokens=1000, do_sample=True, top_k=1)
    with torch.no_grad():
        generated = model.generate(**inputs, **generation_options)
    prompt_length = inputs["input_ids"].shape[1]
    # generate() echoes the prompt; decode only the newly produced tokens.
    return tokenizer.decode(generated[0, prompt_length:], skip_special_tokens=True)

# 更新查询（语音转文本）
def update_query_from_audio(audio):
    """
    Gradio ``change`` callback for the audio widget.

    Args:
        audio: Path of the recorded/uploaded audio, or a falsy value.

    Returns:
        Text for the speech-to-text box, exactly as produced by ``transcribe_audio``.
    """
    transcript = transcribe_audio(audio)
    return transcript

# Gradio 界面
def gradio_interface(image, transcribed_text, query):
    """
    Gradio submit handler: choose the question and generate an answer.

    The manually typed ``query`` takes precedence; the speech-to-text box is
    used only when the typed question is blank.

    Args:
        image: Uploaded PIL image, or None.
        transcribed_text: Contents of the speech-to-text textbox (may be None).
        query: Contents of the question textbox (may be None).

    Returns:
        The generated description, or an error message when both inputs are blank.
    """
    # Gradio textbox values can arrive as None; treat them as empty rather than
    # crashing on None.strip().
    final_query = (query or "").strip() or (transcribed_text or "").strip()
    if not final_query:
        return "Error: Please enter text or voice input problem."

    # text_indices, text_texts_list, image_index and image_texts are module
    # globals assigned by main() before the UI is launched.
    return generate_description(image, final_query, text_indices, text_texts_list, image_index, image_texts)

# 主函数
def main():
    """
    Load the retrieval databases, build the Gradio UI and launch it.

    Side effects: assigns the module globals ``text_indices``,
    ``text_texts_list``, ``image_index`` and ``image_texts``, which
    ``gradio_interface`` reads when the submit button is clicked.
    """
    global text_indices, text_texts_list, image_index, image_texts

    # Paths of the retrieval databases (FAISS index + JSON payload for each).
    text_index_file_1 = "/root/autodl-tmp/text_index.faiss"  # original text FAISS index
    text_texts_file_1 = "/root/autodl-tmp/text_texts.json"   # original text data
    text_index_file_2 = "/root/autodl-tmp/NHS_text_index.faiss"  # additional (NHS) text FAISS index
    text_texts_file_2 = "/root/autodl-tmp/NHS_text_texts.json"   # additional (NHS) text data
    image_index_file = "/root/autodl-tmp/image_index.faiss"    # image/text FAISS index
    image_texts_file = "/root/autodl-tmp/image_texts.json"     # image/text records

    # Load both text databases; the two lists are parallel (index i <-> texts i).
    text_index_1, text_texts_1 = load_text_retrieval_system(text_index_file_1, text_texts_file_1)
    text_index_2, text_texts_2 = load_text_retrieval_system(text_index_file_2, text_texts_file_2)
    text_indices = [text_index_1, text_index_2]
    text_texts_list = [text_texts_1, text_texts_2]

    # Load the multimodal (image/text) database.
    image_index, image_texts = load_multimodal_retrieval_system(image_index_file, image_texts_file)

    # Gradio UI. Components are laid out in the order they are created inside
    # the Blocks/Row contexts, so statement order here defines the page layout.
    with gr.Blocks() as interface:
        gr.Markdown("## GLM-4V Voice + Picture + Text Multimodal Description Generation (Integrated with RAG)")
        gr.Markdown("Upload a picture, enter a question, or use voice description to let AI generate the corresponding description.")

        with gr.Row():
            image_input = gr.Image(label="Upload a picture (optional)", type="pil")

        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Voice input (optional)")
            transcribed_text = gr.Textbox(label="Speech-to-text results (editable)", interactive=True)

        with gr.Row():
            query_input = gr.Textbox(label="Input question (can be edited manually)", interactive=True)
            submit_button = gr.Button("submit")

        output_text = gr.Textbox(label="Generated Description")

        # When the audio input changes, refresh the speech-to-text box.
        audio_input.change(update_query_from_audio, inputs=[audio_input], outputs=[transcribed_text])

        # On submit, combine the speech-to-text and typed question and generate the answer.
        submit_button.click(
            gradio_interface,
            inputs=[image_input, transcribed_text, query_input],
            outputs=[output_text]
        )

    # share=True also requests a temporary public *.gradio.live URL (visible in
    # the captured output); anyone holding that link can use the app.
    interface.launch(share=True)

# Script entry point.
if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

  checkpoint = torch.load(fp, map_location=device)


Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://3e901690f3d64d5b42.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
