In [1]:
from transformers import Mistral3ForConditionalGeneration, FineGrainedFP8Config, AutoProcessor, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model = "/tmp2/share_data/mistralai--Ministral-3-14B-Instruct-2512/"
processor = AutoProcessor.from_pretrained(model)
tokenizer = AutoTokenizer.from_pretrained(model)
model = Mistral3ForConditionalGeneration.from_pretrained(
    model,
    device_map="auto",
    quantization_config=FineGrainedFP8Config(dequantize=True)
)

OSError: Error no file named model.safetensors found in directory ./models/mistral_input_size=500.

In [4]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=16,  # LoRA rank
    lora_alpha=32,  # LoRA scaling factor
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply LoRA to model
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.14.1             Please see https://github.com/pytorch/ao/issues/2919 for more info


trainable params: 69,992,448 || all params: 14,015,024,128 || trainable%: 0.4994


### Dataset

In [43]:
from datasets import load_dataset

# Load dataset - expected format with images
raw_train_dataset, raw_val_dataset = load_dataset(
    "json",
    data_files="data/input/example.json",
    split=["train[:90%]", "train[90%:]"],
)


In [44]:

instruction = "請給我 OCR 結果"

def convert_to_conversation(sample):
    conversation = [
        { "role": "user",
          "content" : [
            {"type" : "text",  "text"  : instruction},
            {"type" : "image", "image" : sample["image_path"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["ocr_text"]} ]
        },
    ]
    return { "messages" : conversation }

train_dataset = [convert_to_conversation(sample) for sample in raw_train_dataset]
val_dataset = [convert_to_conversation(sample) for sample in raw_val_dataset]

In [45]:
train_dataset[0]

{'messages': [{'role': 'user',
   'content': [{'type': 'text', 'text': '請給我 OCR 結果'},
    {'type': 'image',
     'image': 'data/pdf2png/mohw/(檔案下載)衛部心字第1131763377號發布令/(檔案下載)衛部心字第1131763377號發布令_page_0001.png'}]},
  {'role': 'assistant',
   'content': [{'type': 'text',
     'text': '檔號：保存年限：\n衛生福利部 令\n發文日期：中華民國114年1月17日\n發文字號：衛部心字第1131763377號\n附件：「精神疾病嚴重病人強制處置費用標準」條文1份\n\n訂定「精神疾病嚴重病人強制處置費用標準」。\n附「精神疾病嚴重病人強制處置費用標準」部長 邱泰源\n\n第1页 共1页'}]}]}

In [None]:
import torch
'''
image = raw_train_dataset[0]["image_path"]
instruction = "請給我 OCR 結果"

messages = [
    {"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": instruction}
    ]}
]
# template -> string & 加入一些 tag
inputs = processor.apply_chat_template(
    messages, 
    add_generation_prompt=True, 
    tokenize=True, return_dict=True, 
    return_tensors="pt").to(model.device, dtype=torch.bfloat16)

generate_ids = model.generate(**inputs, max_new_tokens=2000)
decoded_output = tokenizer.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)

decoded_output
'''

'以下是這張圖片的 OCR 結果（根據可見文字）：\n\n---\n衛生福利部 令\n\n發文日期：中華民國114年1月17日\n發文字號：衛部心字第1131763377號\n附件：「精神疾病嚴重病人強制處置費用標準」條文1份\n\n訂定「精神疾病嚴重病人強制處置費用標準」\n附「精神疾病嚴重病人強制處置費用標準」\n\n部長 部長 蔡源\n\n---\n 備註：\n- 圖章部分無法完全識別，但可見為官方印章。\n- 右上角有「保存年限」字樣，但文字不完整。'

### Training

In [None]:
'''
def preprocess_function(examples):
    """預處理數據,將 messages 轉換成 model input"""
    messages = examples["messages"]
    
    # template -> 加入一些 tag string -> token_id
    # 關鍵：不要在這裡做 padding，讓 DataCollator 動態處理
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        truncation=True,
        max_length=4096,  # 只設置 max_length 用於截斷
        # 移除 padding="max_length" - 讓 collator 動態處理
    )
    
    # Labels are the same as input_ids for causal LM
    inputs["labels"] = inputs["input_ids"].clone()
    
    return inputs

# 方法 1: 使用 HuggingFace Dataset + map (推薦)
from datasets import Dataset

train_hf_dataset = Dataset.from_list(train_dataset)
val_hf_dataset = Dataset.from_list(val_dataset)

# 注意：這裡不能用 remove_columns，因為我們需要保留所有欄位
# 且 preprocess_function 現在返回的是 tensor，需要轉換
def preprocess_and_convert(example):
    result = preprocess_function(example)
    # 將 tensor 轉換為 list (Dataset 需要可序列化的格式)
    converted = {}
    for key, value in result.items():
        if hasattr(value, 'squeeze'):
            converted[key] = value.squeeze(0).tolist()
        else:
            converted[key] = value
    return converted

train_hf_dataset = train_hf_dataset.map(preprocess_and_convert, remove_columns=train_hf_dataset.column_names)
val_hf_dataset = val_hf_dataset.map(preprocess_and_convert, remove_columns=val_hf_dataset.column_names)
'''

In [48]:
import numpy as np

from trl import SFTConfig, SFTTrainer

max_seq_length = 4096

from trl import DataCollatorForVisionLanguageModeling

data_collator = DataCollatorForVisionLanguageModeling(
    processor=processor,
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,      # 直接使用原始 list
    eval_dataset=val_dataset,         # 直接使用原始 list
    processing_class=processor,
    data_collator=data_collator,      # 使用 VLM collator
    args=SFTConfig(
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        num_train_epochs=5,
        eval_strategy="steps",
        save_strategy="steps",
        save_steps=1000,
        metric_for_best_model="eval_loss",
        learning_rate=2e-4,
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.001,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="models/vision_1215",
        report_to="none",
        
        # You MUST put the below items for vision finetuning:
        remove_unused_columns=False,
        dataset_text_field="",
        dataset_kwargs={"skip_prepare_dataset": True},
        max_length=max_seq_length,
    ),
)

ImportError: cannot import name 'DataCollatorForVisionLanguageModeling' from 'trl' (/tmp2/howard/venv_manager/vl-sft/lib/python3.11/site-packages/trl/__init__.py)

In [11]:
trainer_stats = trainer.train()

  input_ids = [torch.tensor(example["input_ids"]) for example in examples]
  labels = [torch.tensor(example["labels"]) for example in examples]


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,5.684688,5.551738,6.978872,262144.0,0.623687
2,5.593156,5.487171,6.893362,270336.0,0.622222
3,5.436441,5.153737,6.815319,286720.0,0.623932
4,5.214163,5.058553,6.740965,294912.0,0.62906
5,5.034243,4.839242,6.660617,311296.0,0.634432
6,4.892753,4.798687,6.621404,319488.0,0.63956
7,4.819975,4.715858,6.592567,335872.0,0.650061
8,4.708533,4.658279,6.547239,344064.0,0.653968
9,4.645241,4.612432,6.517872,360448.0,0.65641
10,4.607719,4.553523,6.504981,368640.0,0.657387


### Inference

In [None]:
import torch

image = raw_train_dataset[0]["image_path"]
instruction = "請給我 OCR 結果"

messages = [
    {"role": "user", "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": instruction}
    ]}
]
# template -> 加入一些 tag string -> token_id
inputs = processor.apply_chat_template(
    messages, 
    add_generation_prompt=True, 
    tokenize=True, return_dict=True, 
    return_tensors="pt").to(model.device, dtype=torch.bfloat16)

generate_ids = model.generate(**inputs, max_new_tokens=2000)
decoded_output = tokenizer.decode(generate_ids[0, inputs["input_ids"].shape[1] :], skip_special_tokens=True)

decoded_output

'以下是這張圖片的 OCR 結果（根據可見文字）：\n\n---\n衛生福利部 令\n\n發文日期：中華民國114年1月17日\n發文字號：衛部心字第1131763377號\n附件：「精神疾病嚴重病人強制處置費用標準」條文1份\n\n訂定「精神疾病嚴重病人強制處置費用標準」\n附「精神疾病嚴重病人強制處置費用標準」\n\n部長 部長 蔡源\n\n---\n 備註：\n- 圖章部分無法完全識別，但可見為官方印章。\n- 右上角有「保存年限」字樣，但文字不完整。'

In [41]:
decoded_output

NameError: name 'decoded_output' is not defined