## Setup

In [1]:
!pip install  -U -q transformers==4.46.3 trl==0.12.1 datasets bitsandbytes peft accelerate

In [2]:
!pip install -q flash-attn --no-build-isolation

In [3]:
!pip install -q tensorboard

In [4]:
import torch 
import time

## Load data

In [5]:
from datasets import Dataset
from pathlib import Path
import json

root = Path("./chart_llama_dataset")
image_root = root

examples = []

for json_file in root.glob("*_simplified_qa.json"):
    chart_type = json_file.stem.split("_chart")[0] + "_chart"
    with open(json_file, "r", encoding="utf-8") as f:
        qa_list = json.load(f)

    for qa in qa_list:
        try:
            question = qa["conversations"][0]["value"].replace("<image>", "").strip()
            answer = qa["conversations"][1]["value"].strip()
            image_path = image_root / qa["image"]  # 相対パスを解決
            examples.append({
                "question": question,
                "answer": answer,
                "image_path": str(image_path),
                "chart_type": chart_type
            })
        except Exception as e:
            print(f"スキップ: {qa.get('id', '')} due to error: {e}")

# Datasetに変換
dataset = Dataset.from_list(examples)
print(dataset[0])

{'question': 'What is the fast food consumption for the age group 65+?', 'answer': '55', 'image_path': 'chart_llama_dataset/ours/scatter_chart/png/scatter_chart_100examples_52.png', 'chart_type': 'scatter_chart'}


In [6]:
from datasets import Image

dataset = dataset.cast_column("image_path", Image())
print(dataset[0])

{'question': 'What is the fast food consumption for the age group 65+?', 'answer': '55', 'image_path': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=640x480 at 0x7FCD55DD6BD0>, 'chart_type': 'scatter_chart'}


In [7]:
dataset_split = dataset.train_test_split(test_size=0.2, seed=42)

train_dataset = dataset_split["train"]
test_dataset = dataset_split["test"]

In [8]:
system_message = "You are a helpful assistant that answers questions about charts."

from PIL import Image as PILImage

def format_data(example):
    img_info = example["image_path"]

    if isinstance(img_info, dict) and "path" in img_info:
        pil_img = PILImage.open(img_info["path"])
    elif isinstance(img_info, PILImage.Image):
        pil_img = img_info
    else:
        raise ValueError("Unknown image format:", img_info)

    if pil_img.mode != "RGB":
        pil_img = pil_img.convert("RGB")

    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": pil_img},
                    {"type": "text", "text": example["question"]},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example["answer"]}],
            },
        ]
    }


In [9]:
# Apply to dataset
train_formatted_dataset = [format_data(sample) for sample in train_dataset]
test_formatted_dataset = [format_data(sample) for sample in test_dataset]

In [10]:
train_formatted_dataset[0]

{'messages': [{'role': 'system',
   'content': [{'type': 'text',
     'text': 'You are a helpful assistant that answers questions about charts.'}]},
  {'role': 'user',
   'content': [{'type': 'image',
     'image': <PIL.Image.Image image mode=RGB size=640x480>},
    {'type': 'text',
     'text': 'What is the water pollution level for Europe in 2020?'}]},
  {'role': 'assistant', 'content': [{'type': 'text', 'text': '40'}]}]}

## Model

In [11]:
from transformers import Idefics3ForConditionalGeneration, AutoProcessor

model_id = "HuggingFaceTB/SmolVLM-256M-Base"

2025-04-10 10:36:26.795882: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-10 10:36:26.808094: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-10 10:36:26.811927: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-10 10:36:26.823542: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
from transformers import BitsAndBytesConfig

# BitsAndBytesConfig int-4 config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load model and tokenizer
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    _attn_implementation="flash_attention_2",
)
processor = AutoProcessor.from_pretrained(model_id)

Some kwargs in processor config are unused and will not have any effect: image_seq_len. 


In [13]:
from peft import LoraConfig, get_peft_model

# Configure LoRA
peft_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)

# Apply PEFT model adaptation
peft_model = get_peft_model(model, peft_config)

# Print trainable parameters
peft_model.print_trainable_parameters()

trainable params: 3,067,776 || all params: 259,552,704 || trainable%: 1.1819


In [14]:
from trl import SFTConfig

# Configure training arguments using SFTConfig
training_args = SFTConfig(
    output_dir="smolvlm-base-chartllama",
    max_seq_length=1024,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=25,
    save_total_limit=1,
    optim="adamw_torch_fused",
    bf16=True,
    #push_to_hub=True,
    #report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [15]:
from PIL import Image as PILImage

MAX_SIZE = 384

def resize_if_needed(img, max_size=MAX_SIZE):
    w, h = img.size
    if max(w, h) > max_size:
        scale = max_size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), PILImage.LANCZOS)
    return img

def collate_fn(examples):
    texts = [processor.apply_chat_template(ex["messages"], tokenize=False) for ex in examples]

    images = []
    for ex in examples:
        for msg in ex["messages"]:
            if msg["role"] == "user":
                for item in msg["content"]:
                    if item["type"] == "image":
                        img = item["image"]
                        if isinstance(img, PILImage.Image):  # 念のため確認
                            img = resize_if_needed(img)
                            if img.mode != "RGB":
                                img = img.convert("RGB")
                            images.append([img])  # SmolVLM expects list of lists
                        break


    batch = processor(text=texts, images=images, return_tensors="pt", padding=True, do_resize=False)

    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    image_token_id = processor.tokenizer.additional_special_tokens_ids[
        processor.tokenizer.additional_special_tokens.index("<image>")
    ]
    labels[labels == image_token_id] = -100

    batch["labels"] = labels
    return batch




In [16]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_formatted_dataset,
    eval_dataset=test_formatted_dataset,
    data_collator=collate_fn,
    peft_config=peft_config,
    tokenizer=processor.tokenizer,
)



In [17]:
print(processor.image_processor.size) 
print(processor.image_processor.max_image_size) 

{'longest_edge': 2048}
{'longest_edge': 512}


In [None]:
start_time = time.time()
trainer.train()
end_time = time.time()
print("Time taken:", (end_time - start_time)/60)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
  return fn(*args, **kwargs)
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss
25,5.677
50,3.8142
75,1.8872
100,1.3
125,1.1284
150,1.0597
175,0.9996
200,0.9486
225,0.9343


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  return fn(*args, **kwargs)
  with torch.enab