In [1]:
from transformers import Mistral3ForConditionalGeneration, FineGrainedFP8Config, AutoProcessor, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "/tmp2/share_data/mistralai--Ministral-3-14B-Instruct-2512/"
processor = AutoProcessor.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
    local_files_only=True,
    trust_remote_code=True,
    quantization_config=FineGrainedFP8Config(dequantize=True)
)

Loading weights: 100%|█| 585/585 [00:03<00:00, 160.71it/s, Materializing param=model.vision_tower.transformer.layers.23.ffn_nor
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


In [4]:
from peft import PeftConfig, PeftModel
peft_model_id = "/tmp2/howard/vl-sft-ocr/models/mistral_input_size_500/checkpoint-318"
lora_model = PeftModel.from_pretrained(model, peft_model_id)



### Dataset

In [4]:
from datasets import load_dataset

# Load dataset - expected format with images
raw_train_dataset, raw_val_dataset = load_dataset(
    "json",
    data_files="data/input/ocr_test_data=100.json",
    split=["train[:90%]", "train[90%:]"],
)


### Inference

In [8]:
import torch

topk = 1

for i, sample in enumerate(raw_train_dataset):
    if i >= topk:
        break
    image = sample["image_path"]
    instruction = "請給我 OCR 結果"

    messages = [
        {"role": "user", "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": instruction}
        ]}
    ]
    # template -> 加入一些 tag string -> token_id
    inputs = processor.apply_chat_template(
        messages, 
        add_generation_prompt=True, 
        tokenize=True, return_dict=True, 
        return_tensors="pt").to(model.device, dtype=torch.bfloat16)

    generate_ids = model.generate(**inputs, max_new_tokens=2000)
    decoded_output = tokenizer.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)
    print(decoded_output)

以下是這張圖片的 OCR 結果：

---

**衛生福利部 令**

登文日期：中華民國112年5月10日
登文字號：衛部醫字第1121663253號
附件：「公共場所應設置急救設備管理辦法」修正條文1份

修正「公共場所應設置急救設備管理辦法」。

附修正「公共場所應設置急救設備管理辦法」。

部長 薛蕙芳

---


### Saving for VLLM

In [None]:
import torch

output_dir = "/tmp2/howard/vl-sft-ocr/models/merged--mistral-sft_input_size=500"
merged_lora_model = lora_model.merge_and_unload()

# 清除量化配置以避免序列化問題
if hasattr(merged_lora_model.config, 'quantization_config'):
    merged_lora_model.config.quantization_config = None

# 將模型轉換為 float16 或 bfloat16 以確保兼容性
merged_lora_model = merged_lora_model.to(dtype=torch.bfloat16)

# 保存模型、tokenizer 和 processor
merged_lora_model.save_pretrained(output_dir, safe_serialization=True)
tokenizer.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")

TypeError: Object of type dtype is not JSON serializable