# Fine-tuning SmolVLM on ChartLlama Dataset

This notebook demonstrates how to fine-tune the SmolVLM model on the ChartLlama dataset using parameter-efficient fine-tuning (LoRA).

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/code')

In [81]:
import torch
import gc
import os
import time
import numpy as np
from PIL import Image
# Use the specific model and processor
from transformers import Idefics3Processor, Idefics3ForConditionalGeneration, BitsAndBytesConfig
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback
from peft import LoraConfig, get_peft_model
from torch.utils.data import random_split, Dataset
from chartllama_load import ChartDataset
import wandb
from accelerate.utils import set_seed
from accelerate import Accelerator
import io # Ensure io is imported
from tqdm.notebook import tqdm # Use notebook-friendly progress bar
from datasets import load_dataset

In [4]:
# 1. Setup Configuration
WANDB_PROJECT = "smolvlm-chartllama"
MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Base"
OUTPUT_DIR = "./smolvlm-chartllama-lora-tuned"
DATA_DIR = "./chartllama_data"
SAMPLE_LIMIT = None # Use all data
IMAGE_SIZE = 384
SEED = 42
# Decide on precision (bfloat16 preferred on A100/H100, float16 otherwise)
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
USE_LORA = False # Set to False to do full finetuning (requires much more memory)
USE_QUANTIZATION = False # <<< Set to False for standard LoRA, True for QLoRA

In [82]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # Ensure consistent dtype
print(f"Using device: {DEVICE}")

Using device: cuda


In [None]:
# https://wandb.ai/authorize
# 0469802d14d997b8dad4d23a7ba212e0a8d8f197

In [5]:
# # Configure wandb
if WANDB_PROJECT:
    wandb.init(project=WANDB_PROJECT)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mryan-seet467[0m ([33mryan-seet467-georgia-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Check Hardware and Clear Memory

In [6]:
# --- Check Hardware and Clear Memory (Keep as is) ---
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU memory: {gpu_memory_gb:.2f} GB")
    print(f"Using dtype: {DTYPE}")
    if not USE_QUANTIZATION and gpu_memory_gb < 20: # Rough estimate
         print("Warning: Standard LoRA without quantization might require significant VRAM (>20-24GB). Consider using USE_QUANTIZATION=True if you encounter memory issues.")


PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU: NVIDIA A100-SXM4-40GB
GPU memory: 42.47 GB
Using dtype: torch.bfloat16


In [7]:
torch.cuda.empty_cache()
gc.collect()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.makedirs(OUTPUT_DIR, exist_ok=True)
set_seed(SEED) # Set seed early

## Load Model and Processor

In [8]:
# --- Load Model and Processor ---
print("Loading processor...")
# Ensure processor matches the model
processor = Idefics3Processor.from_pretrained(MODEL_ID)

Loading processor...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/486 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/11.1k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/801k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/466k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/3.53M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.07k [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/68.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/424 [00:00<?, ?B/s]

In [9]:
# --- Configure Model Loading ---
model_load_kwargs = {} # Start with an empty dict


In [10]:
if USE_QUANTIZATION:
    print("Configuring model for QLoRA (8-bit quantization)...")
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_8bit_use_double_quant=True,
        bnb_8bit_quant_type="nf8",
        bnb_8bit_compute_dtype=DTYPE
    )
    model_load_kwargs["quantization_config"] = bnb_config
    from peft import prepare_model_for_kbit_training # Needed for QLoRA
else:
    print(f"Configuring model for standard LoRA (dtype: {DTYPE})...")
    model_load_kwargs["torch_dtype"] = DTYPE


Configuring model for standard LoRA (dtype: torch.bfloat16)...


In [11]:
# --- Load Model ---
print(f"Loading model '{MODEL_ID}'...")
start_time = time.time()
model = Idefics3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    **model_load_kwargs
)
print(f"Model loaded in {time.time() - start_time:.2f} seconds")

Loading model 'HuggingFaceTB/SmolVLM-256M-Base'...


config.json:   0%|          | 0.00/7.30k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/513M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

Model loaded in 24.97 seconds


In [None]:
# !pip install -U bitsandbytes

## Prepare Model for Fine-tuning

In [12]:
# --- Prepare Model for Fine-tuning ---
if USE_LORA:
    print("Preparing model for LoRA training...")
    if USE_QUANTIZATION:
        print("Applying prepare_model_for_kbit_training for QLoRA...")
        model = prepare_model_for_kbit_training(model)

    # Configure LoRA (same as before)
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
     print("Configured for Full Fine-tuning.")

model.config.use_cache = False


Configured for Full Fine-tuning.


## Load Dataset

In [13]:
# check whether files present in current directory
import os
from pathlib import Path

data_dir_path = Path(DATA_DIR)
print(f"Checking data directory: {data_dir_path.resolve()}") # Print absolute path
print(f"Does directory exist? {data_dir_path.exists()}")
if data_dir_path.exists():
    print("Files in directory:")
    for item in data_dir_path.iterdir():
        print(f"  - {item.name}")

Checking data directory: /content/drive/MyDrive/code/chartllama_data
Does directory exist? True
Files in directory:
  - candlestick_chart_100examples_simplified_qa.json
  - funnel_chart_100examples_simplified_qa.json
  - gantt_chart_100examples_simplified_qa.json
  - polar_chart_100examples_simplified_qa.json
  - heatmap_chart_100examples_simplified_qa.json
  - box_chart_100examples_simplified_qa.json
  - .gitattributes
  - ours.zip
  - scatter_chart_100examples_simplified_qa.json
  - extracted_images
  - .cache


In [14]:
print("Loading dataset using custom ChartDataset...")
# Ensure the prompt_format matches what the model expects for QA
dataset = ChartDataset(
    data_dir=DATA_DIR,
    processor=processor,
    image_size=IMAGE_SIZE,
    max_answer_length=128,
    prompt_format="USER: {question}\nASSISTANT:", # Adjust if needed for Idefics3
    sample_limit=SAMPLE_LIMIT,
    cache_images=True
)

Loading dataset using custom ChartDataset...


Loading data: 100%|██████████| 7/7 [05:17<00:00, 45.33s/it]


In [15]:
# --- Split Dataset ---
total_size = len(dataset)
if total_size == 0:
    raise ValueError("Dataset loaded 0 examples. Check data directory and loading logic.")

train_size = int(0.8 * total_size)
val_test_size = total_size - train_size
val_size = int(0.5 * val_test_size) # Split remaining 20% into 10% val, 10% test
test_size = val_test_size - val_size

train_dataset, remaining = random_split(
    dataset, [train_size, val_test_size],
    generator=torch.Generator().manual_seed(SEED)
)

val_dataset, test_dataset = random_split(
    remaining, [val_size, test_size],
    generator=torch.Generator().manual_seed(SEED)
)

print(f"Total samples: {total_size}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")


Total samples: 980
Training samples: 784
Validation samples: 98
Test samples: 98


In [40]:
from IPython.display import display
import random

# --- Choose an item index ---
item_index = random.randint(0, len(dataset.data_list) - 1)

# --- Retrieve the raw data and the PIL image ---
if 0 <= item_index < len(dataset.data_list):
    raw_item = dataset.data_list[item_index]
    pil_image = raw_item.get('image')
    if pil_image:
        # --- Display information about the item ---
        print(f"--- Displaying Item Index: {item_index} ---")
        print(f"Original ID: {raw_item.get('original_id', 'N/A')}")
        print(f"Chart Type: {raw_item.get('chart_type', 'N/A')}")
        print(f"Question: {raw_item.get('question', 'N/A')[:200]}...")
        print(f"Answer: {raw_item.get('answer', 'N/A')}")
        print(f"Image Path (from JSON): {raw_item.get('image_path', 'N/A')}")
        print("--------------------------------------")

        # --- Display the image ---
        display(pil_image)

    else:
        print(f"Item {item_index} does not have a loaded image object (This should not happen).")
else:
    print(f"Invalid index: {item_index}. Dataset size is {len(dataset.data_list)}")

AttributeError: 'ChartDataset' object has no attribute 'data_list'

In [None]:
# --- Define Custom Collator (Manual Label Construction) ---
def collate_fn(examples):
    # 1. Filter out potential errors from __getitem__
    examples = [e for e in examples if e is not None]
    if not examples:
        raise ValueError("Collate function received an empty list of examples.")

    # 2. Extract Data Components
    prompts = [ex["prompt"] for ex in examples]
    answers = [ex["answer"] for ex in examples] 
    images = [ex["image"] for ex in examples]   
    ids = [ex.get("id", "N/A") for ex in examples] 

    # --- Step 1: Process Prompts and Images ---
    try:
        prompt_image_inputs = processor(
            text=prompts,
            images=images,
            return_tensors="pt",
            padding="longest",   
                                 
            truncation=True,    
            max_length=getattr(getattr(processor, 'tokenizer', None), 'model_max_length', None) 
        )
    except Exception as e:
        print(f"Error processing prompts/images for batch IDs {ids[:5]}: {e}")
        raise

    # --- Step 2: Tokenize Answers Separately ---
    answer_encodings = processor.tokenizer(
        answers,                 
        add_special_tokens=False, 
        padding=False,           
        truncation=True,          
        max_length=128            
    )
    # --- Step 3: Combine and Find Max Length ---
    max_len = 0                 
    combined_inputs_list = []   

    for i in range(len(examples)):
        prompt_ids = prompt_image_inputs['input_ids'][i]
        prompt_len = prompt_image_inputs['attention_mask'][i].sum().item()
        prompt_ids_unpadded = prompt_ids[:prompt_len]

        answer_ids = torch.tensor(answer_encodings['input_ids'][i], dtype=torch.long)

        combined = torch.cat([prompt_ids_unpadded, answer_ids], dim=0)
        combined_inputs_list.append(combined) 

        if combined.size(0) > max_len:
            max_len = combined.size(0)

    model_max_len = getattr(getattr(processor, 'tokenizer', None), 'model_max_length', None)
    if model_max_len is not None:
        max_len = min(max_len, model_max_len) 

    # --- Step 4: Pad combined sequences and create final tensors ---
    pad_token_id = processor.tokenizer.pad_token_id if hasattr(processor, 'tokenizer') and processor.tokenizer.pad_token_id is not None else 0
    final_input_ids = []        
    final_attention_mask = []   
    final_labels = []           

    for i in range(len(examples)):
        prompt_len = prompt_image_inputs['attention_mask'][i].sum().item()
        answer_ids = torch.tensor(answer_encodings['input_ids'][i], dtype=torch.long)
        answer_len = answer_ids.size(0)

        combined_ids = combined_inputs_list[i]
        current_len = combined_ids.size(0)

        # --- Truncation Logic (Applied AFTER combination) ---
        if current_len > max_len:
            combined_ids = combined_ids[:max_len] #
            if prompt_len >= max_len:
                 prompt_len_final = max_len
                 answer_len_final = 0 
            else:
                 prompt_len_final = prompt_len
                 answer_len_final = max_len - prompt_len
            current_len = max_len 
        else:
             prompt_len_final = prompt_len
             answer_len_final = answer_len 
        pad_len = max_len - current_len

        # --- Pad input_ids ---
        final_input_ids.append(torch.cat([
            combined_ids, 
            torch.full((pad_len,), pad_token_id, dtype=torch.long)
        ]))

        # --- Pad attention mask ---
        final_attention_mask.append(torch.cat([
            torch.ones(current_len, dtype=torch.long), 
            torch.zeros(pad_len, dtype=torch.long)    
        ]))

        # --- Create and pad labels ---
        labels_core = torch.cat([
            torch.full((prompt_len_final,), -100, dtype=torch.long), 
            answer_ids[:answer_len_final]                             
        ])
        labels_pad_len = max_len - labels_core.size(0)
        final_labels.append(torch.cat([
            labels_core,
            torch.full((labels_pad_len,), -100, dtype=torch.long)
        ]))

    # --- Assemble the final batch dictionary ---
    batch = {
        'input_ids': torch.stack(final_input_ids),
        'attention_mask': torch.stack(final_attention_mask),
        'labels': torch.stack(final_labels),
        'pixel_values': prompt_image_inputs['pixel_values'],
    }
    if 'pixel_attention_mask' in prompt_image_inputs:
         batch['pixel_attention_mask'] = prompt_image_inputs['pixel_attention_mask']

    # --- Final Sanity Check (Optional Debugging) ---
    if batch['input_ids'].shape != batch['labels'].shape:
         print(f">>> FATAL ERROR in Collator: Final input_ids shape {batch['input_ids'].shape} != labels shape {batch['labels'].shape}")
         raise ValueError("Final shape mismatch in collator")

    return batch

## Configure Training

In [None]:

# Adjust output directory name
FULL_FT_OUTPUT_DIR = "./smolvlm-chartllama-256-full-tuned"
os.makedirs(FULL_FT_OUTPUT_DIR, exist_ok=True) 

training_args = TrainingArguments(
    output_dir=FULL_FT_OUTPUT_DIR, 
    num_train_epochs=3,             

    # --- CRITICAL MEMORY/COMPUTE CHANGES ---
    per_device_train_batch_size=2, 
    per_device_eval_batch_size=4,   
    gradient_accumulation_steps=8, 
    # --- CRITICAL LEARNING RATE CHANGE ---
    learning_rate=2e-5,             

    weight_decay=0.01,              
    warmup_ratio=0.1,               
    lr_scheduler_type="cosine",     

    logging_strategy="steps",
    logging_steps=25,              
    save_strategy="epoch",
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,

    fp16=(DTYPE == torch.float16),  
    bf16=(DTYPE == torch.bfloat16),


    gradient_checkpointing=True,  # NEEDED to save memory

    seed=SEED,
    optim="adamw_torch",
    remove_unused_columns=False,
)



In [None]:
# --- Initialize Trainer ---
print("Initializing Trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn
)

Initializing Trainer...


## Training

In [66]:
# --- Training ---
print("Starting training...")
start_time = time.time()

train_result = trainer.train()

training_time = (time.time() - start_time) / 60
print(f"Training completed in {training_time:.2f} minutes")

Starting training...


Epoch,Training Loss,Validation Loss
1,4.9097,2.656473
2,2.1421,2.595716
3,1.992,2.605402


Training completed in 37.12 minutes


In [67]:
# --- Save Results and Model ---
print("Saving training results and final model/adapter...")
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()


Saving training results and final model/adapter...
***** train metrics *****
  epoch                    =        3.0
  total_flos               =  3449970GF
  train_loss               =     2.6639
  train_runtime            = 0:37:06.95
  train_samples_per_second =      1.056
  train_steps_per_second   =      0.066


In [68]:
# Save the FFT adapter weights
model.save_pretrained(os.path.join(FULL_FT_OUTPUT_DIR, "final_adapter"))

In [69]:

# Save the processor
processor.save_pretrained(os.path.join(FULL_FT_OUTPUT_DIR, "final_processor")) # Save with the adapter/model

print(f"Model adapter and processor saved to {FULL_FT_OUTPUT_DIR}")

Model adapter and processor saved to ./smolvlm-chartllama-256-full-tuned


## Evaluation with Relaxed Accuracy

In [83]:
# Clear some memory before loading potentially large models/datasets
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Download ChartQA Dataset
print("Downloading ChartQA dataset...")
try:
    chartqa_test_dataset = load_dataset("HuggingFaceM4/ChartQA", split="test", trust_remote_code=True)
    print(f"Loaded ChartQA test split with {len(chartqa_test_dataset)} examples.")
    print("\nFirst ChartQA test sample structure:")
    print(chartqa_test_dataset[0])
except Exception as e:
    print(f"Error downloading/loading ChartQA dataset: {e}")
    chartqa_test_dataset = None

Downloading ChartQA dataset...


README.md:   0%|          | 0.00/852 [00:00<?, ?B/s]

(…)-00000-of-00003-49492f364babfa44.parquet:   0%|          | 0.00/219M [00:00<?, ?B/s]

(…)-00001-of-00003-7302bae5e425bbc7.parquet:   0%|          | 0.00/311M [00:00<?, ?B/s]

(…)-00002-of-00003-194c9400785577a2.parquet:   0%|          | 0.00/315M [00:00<?, ?B/s]

(…)-00000-of-00001-0f11003c77497969.parquet:   0%|          | 0.00/50.2M [00:00<?, ?B/s]

(…)-00000-of-00001-e2cd0b7a0f9eb20d.parquet:   0%|          | 0.00/68.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/28299 [00:00<?, ? examples/s]

Generating val split:   0%|          | 0/1920 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2500 [00:00<?, ? examples/s]

Loaded ChartQA test split with 2500 examples.

First ChartQA test sample structure:
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGBA size=850x600 at 0x7C3F01AE76D0>, 'query': 'How many food item is shown in the bar graph?', 'label': ['14'], 'human_or_machine': 0}
