In [None]:
# 导入必要的库
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
from IPython.display import display, HTML
from ipywidgets import FileUpload, Text, Button, HBox, VBox, Output
import os
from io import BytesIO

# Load the model and its chat processor from a local checkpoint directory.
model_path = "/root/autodl-tmp/Janus-1/20250127"

# Switch the language-model sub-config to eager attention before the model
# is instantiated from it.
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'

vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)

# Decide the target device once: bfloat16 on the GPU when CUDA is available,
# otherwise the model stays on the CPU in float16.
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
if cuda_device == 'cuda':
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

# The processor bundles the tokenizer used later for decoding.
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer

@torch.inference_mode()
def multimodal_understanding(image, question: str, top_p: float = 0.95,
                             temperature: float = 0.1) -> str:
    """Answer a text question about an image using the loaded model.

    Args:
        image: Encoded image data as a bytes-like object (it is wrapped in
            ``BytesIO`` and opened with PIL).
        question: The question text; it is placed after the image
            placeholder in the user turn.
        top_p: Nucleus-sampling threshold. Only forwarded to ``generate``
            when sampling is enabled.
        temperature: Sampling temperature. ``0`` disables sampling (greedy
            decoding); any other value enables sampling.

    Returns:
        The decoded answer text, or a user-facing error message when image
        loading, input preparation, or generation fails. Failures are also
        printed with the underlying exception.
    """
    # Release unused cached GPU memory from earlier calls; nothing to do on CPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Decode the byte stream into an RGB PIL image.
    try:
        pil_image = Image.open(BytesIO(image)).convert("RGB")
    except Exception as e:
        print(f"Error loading image: {e}")
        return "无法加载图片，请检查图片格式是否正确。"

    # Build the two-turn conversation template expected by the processor.
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{question}",
            # Only a placeholder: the real image is handed to the processor
            # through its `images` argument below.
            "images": ["image_placeholder"],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    # Prepare the batched inputs and their embeddings. Both steps live in the
    # same guard so a failure in either yields the friendly message instead of
    # an unhandled traceback.
    input_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
    try:
        prepare_inputs = vl_chat_processor(
            conversations=conversation, images=[pil_image], force_batchify=True
        ).to(cuda_device, dtype=input_dtype)
        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
    except Exception as e:
        print(f"Error processing inputs: {e}")
        return "输入处理失败，请检查模型和处理器配置。"

    # Sampling parameters are only meaningful (and only accepted without
    # warnings) when sampling is on; greedy decoding gets none of them.
    if temperature == 0:
        sampling_kwargs = {"do_sample": False}
    else:
        sampling_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
        }

    # Generate the answer tokens.
    try:
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=prepare_inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
            use_cache=True,
            **sampling_kwargs,
        )
    except Exception as e:
        print(f"Error generating response: {e}")
        return "生成回答失败，请检查模型配置。"

    # Decode the first (only) sequence in the batch back to text.
    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer

# Notebook UI widgets.
# Upload control restricted to a single image file.
upload = FileUpload(
    accept="image/*",
    multiple=False,
)
# Text box where the user types the question / prompt.
question_input = Text(description="Question")
# Output area that the click handler clears and writes into.
output = Output()

def on_button_clicked(b):
    """Handle a click on the Chat button.

    Reads the uploaded image and the question from the widgets, runs
    ``multimodal_understanding`` on them, and renders the answer inside the
    ``output`` widget. All messages are written to ``output``.

    Args:
        b: The button instance passed by ``Button.on_click``; unused.
    """
    import html  # local import: only needed to escape the answer for display

    with output:
        output.clear_output()

        uploaded = upload.value
        if not uploaded:
            # Give feedback instead of silently ignoring the click.
            print("Please upload an image first.")
            return

        # Extract the raw bytes of the first uploaded file.
        # ipywidgets 7.x exposes a dict keyed by filename, while 8.x exposes
        # a tuple of per-file dicts whose 'content' is a memoryview.
        try:
            if isinstance(uploaded, dict):
                first_file = next(iter(uploaded.values()))
            else:
                first_file = uploaded[0]
            image_bytes = bytes(first_file['content'])
        except (KeyError, IndexError, TypeError) as e:
            print("Error extracting image bytes:", e)
            return

        question = question_input.value
        answer = multimodal_understanding(image_bytes, question)
        # Escape the model output so characters such as '<' or '&' are shown
        # literally rather than being interpreted as HTML markup.
        display(HTML(f"<h3>Response:</h3><p>{html.escape(answer)}</p>"))

# Chat button: a click invokes on_button_clicked.
button = Button(description="Chat")
button.on_click(on_button_clicked)

# Render the interface as a vertical stack of the widgets.
display(
    VBox([
        upload,
        question_input,
        button,
        output,
    ])
)