In [1]:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, TrainingArguments
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
from datasets import load_dataset
import json

In [2]:
# Maximum sequence length handed to the processor as `max_length` (with truncation enabled)
# when the prompts are encoded.
# NOTE(review): if the installed processor expands `<image>` into per-patch placeholder tokens,
# this limit may cut into them — TODO confirm against the transformers version in use.
MAX_LENGTH = 384
# Hugging Face Hub identifier of the checkpoint used for both the processor and the model.
MODEL_ID = "llava-hf/llava-1.5-7b-hf"
# Not referenced in the cells visible here; presumably the Hub repository the fine-tuned
# model is pushed to — confirm against the training/upload cells.
REPO_ID = "thashiguchi/llava-finetuning-demo"
# Not referenced in the cells visible here; presumably the Weights & Biases project and
# run name used for experiment logging — confirm.
WANDB_PROJECT = "LLaVa"
WANDB_NAME = "llava-demo-cord"

In [3]:
# Load the "naver-clova-ix/cord-v2" dataset via `datasets.load_dataset`.
# Later cells read the "image" and "ground_truth" columns of its examples.
dataset = load_dataset("naver-clova-ix/cord-v2")

In [4]:
# Processor and base model are loaded from the same checkpoint (MODEL_ID).
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Weights are loaded in float16 and placed on devices automatically.
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# LoRA adapter settings applied to the base model.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapters are attached to modules whose names match these entries.
    # NOTE(review): these names may also match attention projections outside the
    # language model (e.g. in the vision tower) — confirm which layers get wrapped.
    target_modules=["q_proj", "v_proj"],
    # Rank and scaling of the low-rank update (see the PEFT LoraConfig reference).
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    bias="none",
)

# `model` is rebound to the PEFT-wrapped model. Re-running this cell without
# reloading the base model first would wrap the already-wrapped model again.
model = get_peft_model(model, lora_config)

In [6]:
def preprocess_function(examples):
    """Encode a batch of examples into processor inputs.

    Args:
        examples: Batch mapping with an ``image`` column and a ``ground_truth``
            column. Each ``image`` entry may be the image object itself or a
            mapping that stores the image under an ``image`` key. Each
            ``ground_truth`` entry is a JSON string containing a ``gt_parse`` key.

    Returns:
        The output of ``processor`` for the batch, requested with padding,
        truncation to ``MAX_LENGTH`` and ``return_tensors="pt"``.

    Raises:
        json.JSONDecodeError: If a ``ground_truth`` entry is not valid JSON.
        KeyError: If a parsed ground truth has no ``gt_parse`` key.
    """
    from collections.abc import Mapping

    # The column entries are normally the image objects themselves (presumably
    # decoded images when called through a batched `datasets.map` — confirm);
    # indexing those with ["image"] raises TypeError. Only unwrap entries that
    # really are mappings, which keeps the previously supported layout working.
    images = [
        entry["image"] if isinstance(entry, Mapping) else entry
        for entry in examples["image"]
    ]

    # NOTE(review): the parsed `gt_parse` object is interpolated with str(), so the
    # target text is its Python repr rather than strict JSON — confirm this is intended.
    texts = [f"USER: <image>\nExtract JSON.\nASSISTANT: {json.loads(gt)['gt_parse']}" for gt in examples["ground_truth"]]

    inputs = processor(text=texts, images=images, padding=True, truncation=True, max_length=MAX_LENGTH, return_tensors="pt")

    return inputs

In [7]:
# Formatting function that builds the training text from an example.
def formatting_func(example):
    """Build the prompt/answer text(s) for a single example or a batch.

    Args:
        example: Mapping with a ``ground_truth`` entry that is either one JSON
            string or a list/tuple of JSON strings (batched call). Every JSON
            document must contain a ``gt_parse`` key.

    Returns:
        list[str]: One formatted text per ground-truth string, in input order.
        A single string yields a one-element list.

    Raises:
        Exception: Any error hit while reading or parsing the example is
            reported via ``print`` (message and example) and then re-raised.
    """
    try:
        ground_truth = example['ground_truth']
        # A batched call delivers a list of JSON strings. Format every entry
        # instead of only the first one so no sample is silently dropped.
        if isinstance(ground_truth, (list, tuple)):
            ground_truths = list(ground_truth)
        else:
            ground_truths = [ground_truth]

        formatted_texts = []
        for raw_gt in ground_truths:
            parsed_gt = json.loads(raw_gt)
            # NOTE(review): `gt_parse` is interpolated with str(), i.e. as a Python
            # repr rather than strict JSON — confirm this is the intended target.
            formatted_texts.append(f"USER: <image>\nExtract JSON.\nASSISTANT: {parsed_gt['gt_parse']}")
        return formatted_texts  # always a list, even for a single example
    except Exception as e:
        print(f"Error processing example: {e}")
        print(f"Example structure: {example}")
        raise  # re-raise so processing stops on bad data

def custom_data_collator(data):
    """Stack per-example features into a single batch dictionary.

    Args:
        data: Sequence of feature mappings. Only the keys ``input_ids``,
            ``attention_mask`` and ``pixel_values`` are read; values may be
            tensors or anything ``torch.as_tensor`` accepts (e.g. nested lists).

    Returns:
        dict: For each of the three keys present in at least one example, the
        values stacked along a new leading batch dimension. Keys present in no
        example are omitted, so an empty ``data`` yields ``{}``.

    Note:
        ``torch.stack`` requires all values of a key to share one shape, so
        examples must already be padded to a common length.
    """
    batch = {}
    for key in ("input_ids", "attention_mask", "pixel_values"):
        # Collect the values of this key from the examples that carry it;
        # as_tensor leaves existing tensors untouched and converts list features.
        values = [torch.as_tensor(features[key]) for features in data if key in features]
        # Test the Python list, never the stacked tensor: the truth value of a
        # multi-element tensor is ambiguous and raises at runtime.
        if values:
            batch[key] = torch.stack(values)
    return batch
