# Fine-tuning SmolVLM on ChartLlama Dataset

This notebook demonstrates how to fine-tune the SmolVLM model on the ChartLlama dataset using parameter-efficient fine-tuning (LoRA).

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
os.chdir('/content/drive/MyDrive/code')

In [3]:
!pip install datasets
!pip install -U bitsandbytes



In [4]:
import torch
import gc
import os
import time
import numpy as np
from PIL import Image
import io
import json
from pathlib import Path
import pandas as pd
from IPython.display import display, Image as IPyImage # Added IPyImage for display

from transformers import (
    Idefics3Processor,
    Idefics3ForConditionalGeneration,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
    default_data_collator # Use the default collator
)
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training
from torch.utils.data import random_split # Keep random_split if needed for custom splits
import wandb
from accelerate.utils import set_seed
from accelerate import Accelerator
from tqdm.notebook import tqdm
# Corrected import: Explicitly import Dataset along with others
from datasets import load_dataset, DatasetDict, concatenate_datasets, Image as ds_Image, Dataset

import os
from datasets import load_from_disk, DatasetDict, concatenate_datasets
import gc
import traceback
import io
import torch # Needed for empty tensor return

from chartllama_load import load_chartllama_data, create_chat_prompt

In [5]:
# --- CHOOSE TRAINING METHOD ---
USE_LORA = False    # Set to True for LoRA or QLoRA, False for Full Fine-Tuning
USE_QLORA = False   # Set to True for QLoRA (requires USE_LORA=True)

In [6]:
# --- Basic Configuration ---
WANDB_PROJECT = "smolvlm-chartllama" # Project name for Weights & Biases (set to None to disable)
MODEL_ID = "HuggingFaceTB/SmolVLM-Base"
DATA_DIR = Path("./chartllama_data") # Directory containing ChartLlama JSON files and image folder
CHARTQA_CACHE_DIR = "./chartqa_cache" # Cache directory for ChartQA dataset
SEED = 42
SAMPLE_LIMIT = None # Limit number of samples for faster testing (e.g., 100). None for all data.
EVAL_SAMPLE_LIMIT = 100 # Limit number of ChartQA samples for evaluation
SYSTEM_MESSAGE = "You are a helpful assistant that analyzes charts and answers questions about them." # Optional system prompt
PROCESSED_DATA_DIR = "./processed_data"

In [7]:
# --- Output Directories ---
BASE_OUTPUT_DIR = "./smolvlm-chartllama"
if not USE_LORA:
    TRAINING_TYPE = "full-tuned"
    print("Configuring for Full Fine-Tuning.")
elif USE_QLORA:
    TRAINING_TYPE = "qlora-tuned"
    print("Configuring for QLoRA.")
    if not USE_LORA:
        print("Warning: USE_QLORA=True requires USE_LORA=True. Setting USE_LORA=True.")
        USE_LORA = True
else:
    TRAINING_TYPE = "lora-tuned"
    print("Configuring for standard LoRA.")

OUTPUT_DIR = f"{BASE_OUTPUT_DIR}-{TRAINING_TYPE}"
FINAL_ADAPTER_DIR = os.path.join(OUTPUT_DIR, "final_adapter")
FINAL_PROCESSOR_DIR = os.path.join(OUTPUT_DIR, "final_processor") # Consistent processor saving


Configuring for Full Fine-Tuning.


In [8]:
# --- Hardware & Precision ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# BF16 is preferred on Ampere+ GPUs for stability and speed
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
print(f"Using device: {DEVICE}")
print(f"Using computational dtype: {DTYPE}")

Using device: cuda
Using computational dtype: torch.bfloat16


In [9]:
# --- LoRA Specific Configuration ---
if USE_LORA:
    LORA_R = 16
    LORA_ALPHA = 32
    LORA_DROPOUT = 0.1
    # Common target modules for Llama-like architectures used in Idefics3
    LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

In [10]:
# --- Setup ---
os.makedirs(OUTPUT_DIR, exist_ok=True)
set_seed(SEED)
torch.cuda.empty_cache()
gc.collect()
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [11]:
# --- Weights & Biases Setup ---
if WANDB_PROJECT:
    wandb.init(project=WANDB_PROJECT, config={
        "model_id": MODEL_ID,
        "training_type": TRAINING_TYPE,
        "use_lora": USE_LORA,
        "use_qlora": USE_QLORA,
        "dtype": str(DTYPE),
        "seed": SEED,
        "lora_r": LORA_R if USE_LORA else None,
        "lora_alpha": LORA_ALPHA if USE_LORA else None,
        "sample_limit": SAMPLE_LIMIT,
        "output_dir": OUTPUT_DIR,
    }, job_type="fine-tuning")
else:
    os.environ["WANDB_DISABLED"] = "true" # Disable wandb if project name is None

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mryan-seet467[0m ([33mryan-seet467-georgia-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# https://wandb.ai/authorize
# 0469802d14d997b8dad4d23a7ba212e0a8d8f197

In [13]:
print(f"\n--- Configuration Summary ---")
print(f"Model ID: {MODEL_ID}")
print(f"Training Type: {TRAINING_TYPE}")
print(f"Using LoRA: {USE_LORA}")
print(f"Using QLoRA: {USE_QLORA}")
print(f"Output Directory: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")
print(f"Dtype: {DTYPE}")
print(f"Seed: {SEED}")
if USE_LORA:
    print(f"LoRA R: {LORA_R}, Alpha: {LORA_ALPHA}, Dropout: {LORA_DROPOUT}")
print("---------------------------\n")


--- Configuration Summary ---
Model ID: HuggingFaceTB/SmolVLM-Base
Training Type: full-tuned
Using LoRA: False
Using QLoRA: False
Output Directory: ./smolvlm-chartllama-full-tuned
Device: cuda
Dtype: torch.bfloat16
Seed: 42
---------------------------



## Check Hardware and Clear Memory

In [14]:
# --- Check Hardware ---
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"GPU memory: {gpu_memory_gb:.2f} GB")
    print(f"Using compute dtype: {DTYPE}")
    if USE_LORA and not USE_QLORA and gpu_memory_gb < 20: # Rough estimate for standard LoRA
         print("Warning: Standard LoRA without quantization might require significant VRAM (>20GB). Consider using USE_QLORA=True if you encounter memory issues.")
    elif not USE_LORA and gpu_memory_gb < 40: # Rough estimate for FFT
         print("Warning: Full Fine-Tuning requires substantial VRAM (often >40GB for models of this size). Monitor memory usage closely.")

PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version: 12.4
GPU: NVIDIA A100-SXM4-40GB
GPU memory: 42.47 GB
Using compute dtype: torch.bfloat16


### Load Processor (Before Data Loading)


In [15]:
# --- Load Processor ---
# Load the processor first, as it's needed for preprocessing
print("Loading processor...")
processor = Idefics3Processor.from_pretrained(MODEL_ID)
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token # Set pad token if not defined
print("Processor loaded.")

Loading processor...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Processor loaded.


### Data Loading and Preprocessing (Replaces custom Dataset and Collate)

In [16]:

def preprocess_data_batched(batch, processor, system_message=""):
    """
    Preprocesses a BATCH of data using the Idefics3Processor.
    Handles potential missing images or template errors within the batch.
    """
    batch_size = len(batch["id"]) # Assumes 'id' key exists and determines batch size
    all_prompts_text = []
    all_images = []
    valid_indices_for_batch = [] # Keep track of samples processed correctly in this batch

    for i in range(batch_size):
        sample_id = batch.get("id", ["Unknown"]*batch_size)[i] # Get sample ID safely
        try:
            # --- Image Handling ---
            image = batch["image"][i]
            if image is None:
                # print(f"Debug: Skipping sample {sample_id} - no image.") # Keep commented unless debugging
                continue

            # Basic PIL conversion check (add more robust conversion if needed based on your data)
            if not isinstance(image, Image.Image):
                if isinstance(image, dict) and 'bytes' in image and image['bytes']:
                     image = Image.open(io.BytesIO(image['bytes'])).convert("RGB")
                else:
                     # Add other necessary conversions here if needed
                     # print(f"Debug: Could not convert image type {type(image)} for {sample_id}")
                     raise TypeError(f"Unhandled image type: {type(image)}")

            if image.mode != 'RGB':
                image = image.convert("RGB")

            # --- Text Handling ---
            # Reconstruct single sample dict for create_chat_prompt
            single_sample = {key: values[i] for key, values in batch.items() if isinstance(values, list) and i < len(values)}
            chat_messages = create_chat_prompt(single_sample, system_message)
            prompt_text = processor.apply_chat_template(chat_messages, add_generation_prompt=False)

            # --- Add valid data ---
            all_prompts_text.append(prompt_text)
            all_images.append(image)
            valid_indices_for_batch.append(i) # Mark as valid for this batch

        except Exception as e:
            print(f"Warning: Error processing sample {sample_id} in batch: {e}. Skipping sample.")
            # Optionally uncomment traceback for deep debugging:
            # traceback.print_exc()

    # --- Processor Call ---
    if not all_prompts_text:
        # Return structure matching expected output but with empty tensors/lists
        # Ensure keys match what the Trainer/model expects!
        return {'input_ids': torch.tensor([]), 'attention_mask': torch.tensor([]), 'pixel_values': torch.tensor([]), 'labels': torch.tensor([])}

    try:
        inputs = processor(
            text=all_prompts_text,
            images=all_images,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=getattr(getattr(processor, 'tokenizer', None), 'model_max_length', 2048)
        )
        # Set labels
        inputs['labels'] = inputs['input_ids'].clone()
        return inputs

    except Exception as proc_err:
        print(f"Error during processor call for batch (containing IDs around {sample_id}): {proc_err}")
        # Return empty structure on processor error
        return {'input_ids': torch.tensor([]), 'attention_mask': torch.tensor([]), 'pixel_values': torch.tensor([]), 'labels': torch.tensor([])}

# --- Corrected Batched Filter Function ---
def batched_filter_fn(batch):
    required_keys = ['input_ids', 'attention_mask', 'pixel_values', 'labels']
    key_for_length = None
    for k in required_keys: # Find a reliable key to determine batch size
        if k in batch and isinstance(batch[k], list) and batch[k]:
            key_for_length = k
            break

    if key_for_length is None: # Handle empty or malformed batch dictionary
         list_keys = [k for k,v in batch.items() if isinstance(v, list)]
         if not list_keys or not batch[list_keys[0]]: return []
         batch_size = len(batch[list_keys[0]])
         # print(f"Debug: Using backup key {list_keys[0]} for filter batch size.")
         return [False] * batch_size # Filter out if primary keys are missing/empty

    batch_size = len(batch[key_for_length])
    keep_mask = [True] * batch_size

    # Check structure once per batch
    keys_present_and_correct_length = all(
        k in batch and isinstance(batch[k], list) and len(batch[k]) == batch_size
        for k in required_keys
    )
    if not keys_present_and_correct_length:
        # print(f"Debug: Batch structure mismatch in filter. Filtering out batch.")
        return [False] * batch_size

    # Iterate samples within the batch
    for i in range(batch_size):
        # Check individual elements for validity
        input_ids_ok = batch['input_ids'][i] is not None and hasattr(batch['input_ids'][i], '__len__') and len(batch['input_ids'][i]) > 0
        attn_mask_ok = batch['attention_mask'][i] is not None
        pixel_values_ok = batch['pixel_values'][i] is not None
        labels_ok = batch['labels'][i] is not None

        if not (input_ids_ok and attn_mask_ok and pixel_values_ok and labels_ok):
            keep_mask[i] = False
    return keep_mask


### Load and Preprocess

In [17]:
# --- Main Loading/Preprocessing/Saving Block ---

processed_data_exists = os.path.isdir(PROCESSED_DATA_DIR)
train_dataset, val_dataset, test_dataset = None, None, None # Initialize

if processed_data_exists:
    print(f"\nLoading preprocessed dataset from: {PROCESSED_DATA_DIR}")
    try:
        processed_dataset_dict = load_from_disk(PROCESSED_DATA_DIR)
        train_dataset = processed_dataset_dict["train"]
        val_dataset = processed_dataset_dict["validation"]
        test_dataset = processed_dataset_dict["test"]
        print("Preprocessed dataset loaded successfully.")
        # Basic validation after load
        if not train_dataset or not val_dataset or not test_dataset:
             raise ValueError("Loaded dataset splits are missing or empty.")
        print(f"Loaded splits - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    except Exception as e:
        print(f"Error loading preprocessed data from {PROCESSED_DATA_DIR}: {e}")
        print("Will attempt to re-process from raw data.")
        processed_data_exists = False # Force re-processing

if not processed_data_exists:
    print("\nStarting raw data loading and preprocessing...")
    try:
        # Ensure dependencies are available
        if 'processor' not in globals() or processor is None:
             raise NameError("Processor not defined or loaded.")
        if 'SYSTEM_MESSAGE' not in globals():
             raise NameError("SYSTEM_MESSAGE not defined.")
        # Assuming chartllama_load is imported and works
        # from chartllama_load import load_chartllama_data

        # 1. Load Raw Data
        print("Loading raw ChartLlama data...")
        raw_dataset = load_chartllama_data(DATA_DIR, sample_limit=SAMPLE_LIMIT)
        print(f"Raw data loaded: {len(raw_dataset)} samples.")

        # 2. Preprocess using .map()
        print("Preprocessing raw data (batch_size=1)...")
        keep_cols = ["image", "question", "answer", "id"] # Columns needed by preprocess_data_batched
        remove_cols = [col for col in raw_dataset.column_names if col not in keep_cols]
        processed_dataset = raw_dataset.map(
            lambda batch: preprocess_data_batched(batch, processor, SYSTEM_MESSAGE),
            batched=True,
            batch_size=1,             # Keep batch_size=1 for stability during map
            remove_columns=remove_cols,
            desc="Preprocessing data"
        )
        print("Preprocessing map step complete.")

        # 3. Filter using batched filter function
        print("Filtering processed data...")
        original_len = len(processed_dataset)
        processed_dataset = processed_dataset.filter(
            batched_filter_fn,
            batched=True,
            batch_size=10,            # Batching filter is efficient, use 10 or more
            desc="Filtering data"
        )
        filtered_len = len(processed_dataset)
        print(f"Filtering complete. Kept {filtered_len}/{original_len} samples.")

        if filtered_len == 0:
            raise ValueError("Dataset is empty after filtering.")

        # 4. Split into Train/Validation/Test
        print("Splitting dataset...")
        processed_dataset = processed_dataset.shuffle(seed=SEED) # Shuffle before split
        if filtered_len < 10:
             print("Warning: Less than 10 samples. Using all for train/val/test.")
             train_dataset = processed_dataset
             val_dataset = processed_dataset
             test_dataset = processed_dataset
        else:
            train_val_split = processed_dataset.train_test_split(test_size=0.2, seed=SEED)
            train_dataset = train_val_split["train"]
            test_val_combined = train_val_split["test"]
            if len(test_val_combined) >= 2:
                 val_test_split = test_val_combined.train_test_split(test_size=0.5, seed=SEED)
                 val_dataset = val_test_split["train"]
                 test_dataset = val_test_split["test"]
            else:
                 val_dataset = test_val_combined
                 test_dataset = test_val_combined
        print(f"Splits created - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")


        # 5. Save the Processed Splits
        print(f"Saving preprocessed dataset splits to: {PROCESSED_DATA_DIR}")
        os.makedirs(PROCESSED_DATA_DIR, exist_ok=True)
        final_splits = DatasetDict({
            "train": train_dataset,
            "validation": val_dataset,
            "test": test_dataset
        })
        final_splits.save_to_disk(PROCESSED_DATA_DIR)
        print("Preprocessed dataset saved.")

        # Clean up intermediate objects
        del raw_dataset
        del processed_dataset
        del final_splits
        gc.collect()

    except Exception as e:
        print(f"\nAn error occurred during data loading/processing: {e}")
        traceback.print_exc()
        # Ensure datasets are None if processing failed
        train_dataset, val_dataset, test_dataset = None, None, None

# --- Final Check ---
if train_dataset is None or val_dataset is None or test_dataset is None:
    print("\nERROR: Dataset preparation failed. Cannot proceed.")
    # Handle error appropriately, e.g., exit() or raise Exception
else:
    print("\nDataset preparation complete. Proceeding to next steps (Model Loading/Trainer).")
    # ... (rest of your notebook code) ...


Starting raw data loading and preprocessing...
Loading raw ChartLlama data...
Found 7 JSON files in chartllama_data.


Reading JSONs: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7/7 [00:01<00:00,  6.45it/s]


Initial processing complete.
Loaded 980 valid samples.
Skipped 0 samples due to missing data or images.
Encountered 0 errors during file processing.

Created final dataset with 980 samples.
Dataset features: {'id': Value(dtype='string', id=None), 'image': Image(mode=None, decode=True, id=None), 'question': Value(dtype='string', id=None), 'answer': Value(dtype='string', id=None)}
Raw data loaded: 980 samples.
Preprocessing raw data (batch_size=1)...





Preprocessing data:   0%|          | 0/980 [00:00<?, ? examples/s]

Preprocessing map step complete.
Filtering processed data...


Filtering data:   0%|          | 0/980 [00:00<?, ? examples/s]

Filtering complete. Kept 980/980 samples.
Splitting dataset...
Splits created - Train: 784, Val: 98, Test: 98
Saving preprocessed dataset splits to: ./processed_data


Saving the dataset (0/62 shards):   0%|          | 0/784 [00:00<?, ? examples/s]

Saving the dataset (0/8 shards):   0%|          | 0/98 [00:00<?, ? examples/s]

Saving the dataset (0/8 shards):   0%|          | 0/98 [00:00<?, ? examples/s]

Preprocessed dataset saved.

Dataset preparation complete. Proceeding to next steps (Model Loading/Trainer).


## 7. Model Loading and Preparation


In [18]:
# --- Configure Model Loading ---
model_load_kwargs = {}
quantization_config = None

if USE_QLORA:
    print("Configuring model for QLoRA (8-bit quantization)...")
    # Note: SmolVLM example uses 8-bit, 4-bit might also work but requires careful testing
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        # Optional: These can sometimes improve 8-bit performance/stability
        # bnb_8bit_use_double_quant=True,
        # bnb_8bit_quant_type="nf8",
        # bnb_8bit_compute_dtype=DTYPE # Compute in higher precision
    )
    model_load_kwargs["quantization_config"] = quantization_config
    model_load_kwargs["device_map"] = "auto" # Recommended for quantization
    print("Using device_map='auto' for QLoRA loading.")
elif not USE_LORA: # Full Fine-Tuning
     print(f"Configuring model for Full Fine-Tuning (dtype: {DTYPE})...")
     model_load_kwargs["torch_dtype"] = DTYPE
     # device_map="auto" can sometimes cause issues with FFT DataParallel/DDP,
     # manually moving to DEVICE later might be safer if issues arise.
     # model_load_kwargs["device_map"] = "auto"
else: # Standard LoRA
    print(f"Configuring model for standard LoRA (dtype: {DTYPE})...")
    model_load_kwargs["torch_dtype"] = DTYPE
    # device_map="auto" can be useful here too, but manual placement also works
    # model_load_kwargs["device_map"] = "auto"


# --- Load Model ---
print(f"\nLoading model '{MODEL_ID}' with config: {model_load_kwargs}")
start_time = time.time()
model = Idefics3ForConditionalGeneration.from_pretrained(
    MODEL_ID,
    **model_load_kwargs,
     # attn_implementation="flash_attention_2" # Optional: Use if available and desired for speed
     trust_remote_code=True # Might be needed for Idefics3 architecture specifics
)

# Manually move model to device if device_map wasn't used
if "device_map" not in model_load_kwargs:
     print(f"Moving model to device: {DEVICE}")
     model.to(DEVICE)

print(f"Model loaded in {time.time() - start_time:.2f} seconds")
print(f"Model device: {model.device}")


# --- Prepare Model for Fine-tuning ---
if USE_LORA:
    print("\nPreparing model for LoRA/QLoRA training...")
    if USE_QLORA:
        print("Applying prepare_model_for_kbit_training for QLoRA...")
        # Gradient checkpointing is implicitly enabled by this function
        model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

    # Configure LoRA
    print("Applying LoRA configuration...")
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=LORA_TARGET_MODULES,
        lora_dropout=LORA_DROPOUT,
        bias="none",
        task_type="CAUSAL_LM",
        # modules_to_save = ["lm_head", "embed_tokens"] # Optional: Train output layer? Needs testing.
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
else:
      print("\nModel configured for Full Fine-tuning.")
      # Enable gradient checkpointing manually if not using LoRA helpers
      if hasattr(model, "gradient_checkpointing_enable"):
           print("Enabling gradient checkpointing for Full Fine-Tuning.")
           model.gradient_checkpointing_enable()


# Disable cache for training
model.config.use_cache = False

Configuring model for Full Fine-Tuning (dtype: torch.bfloat16)...

Loading model 'HuggingFaceTB/SmolVLM-Base' with config: {'torch_dtype': torch.bfloat16}
Moving model to device: cuda
Model loaded in 3.81 seconds
Model device: cuda:0

Model configured for Full Fine-tuning.
Enabling gradient checkpointing for Full Fine-Tuning.


### Training Arguments

In [19]:
# --- Configure Training Arguments ---
# Adjust hyperparameters based on whether it's FFT, LoRA, or QLoRA and available VRAM

per_device_train_batch_size = 2 # Decrease if OOM
per_device_eval_batch_size = 4  # Decrease if OOM during eval
gradient_accumulation_steps = 8  # Increase to simulate larger batch size if VRAM is limited
learning_rate = 2e-5 if not USE_LORA else 1e-4 # LoRA often benefits from higher LR
num_train_epochs = 3
weight_decay = 0.01
warmup_ratio = 0.1
lr_scheduler_type = "cosine"
logging_steps = 25
save_steps = 500 # Or adjust based on epoch strategy
eval_steps = 500 # Or adjust based on epoch strategy

print(f"\nSetting Training Arguments for {TRAINING_TYPE}...")
print(f"Batch Size (Train): {per_device_train_batch_size}, Grad Accum: {gradient_accumulation_steps}")
print(f"Effective Batch Size: {per_device_train_batch_size * gradient_accumulation_steps * torch.cuda.device_count() if torch.cuda.is_available() else per_device_train_batch_size * gradient_accumulation_steps}")
print(f"Learning Rate: {learning_rate}")

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,

    weight_decay=weight_decay,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,

    logging_strategy="steps",
    logging_steps=logging_steps,

    save_strategy="epoch",         # Save checkpoints every epoch
    evaluation_strategy="epoch",   # Evaluate every epoch
    # save_strategy="steps",
    # evaluation_strategy="steps",
    # save_steps=save_steps,
    # eval_steps=eval_steps,

    load_best_model_at_end=True,      # Load the best model found during training
    metric_for_best_model="eval_loss",# Use eval loss to determine the best model
    greater_is_better=False,          # Lower eval loss is better
    save_total_limit=2,               # Keep only the best and the latest checkpoint

    fp16=(DTYPE == torch.float16),    # Use FP16 if selected
    bf16=(DTYPE == torch.bfloat16),   # Use BF16 if selected

    # Gradient checkpointing is enabled via prepare_model_for_kbit_training for QLoRA
    # or manually enabled for FFT/standard LoRA in the model prep cell.
    # Set False here if manually enabled, True otherwise (though manual is often clearer).
    gradient_checkpointing=False if USE_QLORA or not USE_LORA else True,
    gradient_checkpointing_kwargs={"use_reentrant": False}, # Recommended setting

    seed=SEED,
    optim="adamw_torch",              # Recommended optimizer
    report_to="wandb" if WANDB_PROJECT else "none", # Report to W&B if configured
    remove_unused_columns=False,      # Important: processor adds columns needed by model
    dataloader_num_workers=2,         # Adjust based on system
    dataloader_pin_memory=True,

    # Potential args for speeding up / memory saving (use with caution)
    # dataloader_drop_last=True,      # Can speed up if last batch is small
    # group_by_length=True,           # Can make padding more efficient, requires length column
)


Setting Training Arguments for full-tuned...
Batch Size (Train): 2, Grad Accum: 8
Effective Batch Size: 16
Learning Rate: 2e-05




### Initialize trasiner using default collator

In [20]:
# --- Initialize Trainer ---
if train_dataset and val_dataset:
    print("Initializing Trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        processor=processor, # Pass processor for potential saving/logging
        data_collator=default_data_collator, # Use the default collator
        # Removed callbacks for simplicity, can add EarlyStoppingCallback back if needed
        # callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
    )
    print("Trainer initialized.")
else:
    print("\nSkipping Trainer initialization due to dataset loading/processing errors.")
    trainer = None

Initializing Trainer...


TypeError: Trainer.__init__() got an unexpected keyword argument 'processor'

In [None]:
# --- Training ---
if trainer:
    print("\nStarting training...")
    start_time = time.time()

    # Clear cache before training
    torch.cuda.empty_cache()
    gc.collect()

    try:
        train_result = trainer.train()

        training_time = (time.time() - start_time) / 60
        print(f"\nTraining completed in {training_time:.2f} minutes")

        # --- Save Results and Final Model/Adapter ---
        print("\nSaving training results, final model/adapter, and processor...")
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()

        # Saving logic depends on the training type
        if USE_LORA:
            print(f"Saving LoRA adapter weights to: {FINAL_ADAPTER_DIR}")
            model.save_pretrained(FINAL_ADAPTER_DIR) # Saves only the adapter
        else: # Full Fine-Tuning
            print(f"Saving full fine-tuned model to: {OUTPUT_DIR}")
            # Trainer saves the best model automatically based on `load_best_model_at_end`
            # If you want to explicitly save the *final* state regardless of best:
            # model.save_pretrained(os.path.join(OUTPUT_DIR, "final_model_state"))
            # Note: Be mindful that Trainer might save the *best* model to OUTPUT_DIR itself.
            # Saving the adapter here might be redundant or overwrite if paths are the same.
            # It's generally safer to rely on Trainer's saving for FFT best model.
            # We'll save the processor separately for consistency.

        # Save the processor in both cases
        print(f"Saving processor to: {FINAL_PROCESSOR_DIR}")
        # Ensure the directory exists before saving
        os.makedirs(FINAL_PROCESSOR_DIR, exist_ok=True)
        if processor:
             processor.save_pretrained(FINAL_PROCESSOR_DIR)
        else:
             print("Warning: Processor object not found, cannot save.")

        print(f"\nModel {'adapter' if USE_LORA else 'weights (best checkpoint)'} and processor saved.")

    except Exception as e:
        print(f"\nAn error occurred during training: {e}")
        # Optionally try to save state even on error
        try:
            trainer.save_state()
            print("Attempted to save trainer state after error.")
        except:
            print("Could not save trainer state after error.")
else:
    print("\nTraining skipped due to errors in data loading, preprocessing, or trainer initialization.")

### Evaluation Setup (Load ChartQA, Define Accuracy)

In [None]:
# --- Evaluation Setup ---

print("\n--- Preparing for Evaluation ---")

# Clear some memory before loading evaluation models/datasets
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 1. Download ChartQA Dataset
print("\nDownloading ChartQA dataset...")
chartqa_test_dataset = None
try:
    # Use trust_remote_code=True if required by the dataset script
    chartqa_test_dataset = load_dataset("HuggingFaceM4/ChartQA", split="test", cache_dir=CHARTQA_CACHE_DIR, trust_remote_code=True)
    # Limit evaluation samples if needed
    if EVAL_SAMPLE_LIMIT and EVAL_SAMPLE_LIMIT < len(chartqa_test_dataset):
        print(f"Limiting evaluation to {EVAL_SAMPLE_LIMIT} samples.")
        chartqa_test_dataset = chartqa_test_dataset.select(range(EVAL_SAMPLE_LIMIT))
    print(f"Loaded ChartQA test split with {len(chartqa_test_dataset)} examples.")
    print("\nFirst ChartQA test sample structure:")
    print(chartqa_test_dataset[0])
except Exception as e:
    print(f"Error downloading/loading ChartQA dataset: {e}")


# 2. Define Relaxed Accuracy Function (Modified for list label)
def calculate_relaxed_accuracy(prediction, ground_truths):
    """
    Checks if the prediction string is contained within any of the ground truth strings (case-insensitive).
    Handles cases where ground_truths is a list (like in ChartQA).
    """
    if not prediction or not ground_truths:
        return False

    # Ensure ground_truths is treated as a list, even if it's a single string initially
    if isinstance(ground_truths, str):
        ground_truths = [ground_truths]
    elif not isinstance(ground_truths, list):
         # Try converting to string if it's neither list nor string (e.g., number)
         try:
              ground_truths = [str(ground_truths)]
         except:
              print(f"Warning: Could not convert ground_truth '{ground_truths}' to list/string.")
              return False # Cannot compare if format is unknown

    prediction_lower = prediction.lower().strip()
    if not prediction_lower: # Handle empty prediction string
        return False

    for gt in ground_truths:
        gt_str = str(gt) if gt is not None else ""
        gt_lower = gt_str.lower().strip()
        if not gt_lower:
            continue
        # Check if prediction is a substring of ground truth
        if prediction_lower in gt_lower:
            return True
        # Optional: Check if ground truth is a substring of prediction (more relaxed)
        # if gt_lower in prediction_lower:
        #     return True

    return False

EVAL_ACCURACY_FUNC = calculate_relaxed_accuracy
print(f"\nUsing accuracy function: {EVAL_ACCURACY_FUNC.__name__}")


In [None]:
# 3. Helper function for inference (Updated prompt format)
@torch.no_grad()
def generate_eval_answer(model, processor, image, question, device, max_new_tokens=128):
    """Generates an answer for evaluation using the chat template."""
    if image is None or question is None:
        print("Warning: Skipping inference due to None image or question.")
        return "Error: Missing image or question"

    # Ensure image is PIL
    if not isinstance(image, Image.Image):
         try:
              # Handle datasets.Image format
              if isinstance(image, dict) and 'bytes' in image and image['bytes']:
                   image = Image.open(io.BytesIO(image['bytes'])).convert("RGB")
              elif isinstance(image, dict) and 'path' in image and image['path']:
                   image = Image.open(image['path']).convert("RGB")
              elif isinstance(image, bytes):
                   image = Image.open(io.BytesIO(image)).convert("RGB")
              else:
                   # Attempt conversion as a last resort
                   image = Image.fromarray(np.array(image)).convert("RGB")
         except Exception as e:
              print(f"Warning: Could not process/convert image of type {type(image)}. Error: {e}. Skipping inference.")
              return "Error: Image processing failed"

    # Ensure RGB
    if image.mode != 'RGB':
        image = image.convert("RGB")

    # Format prompt using the chat template structure for user turn
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": question}
            ]
        }
    ]
    # Apply template, adding the generation prompt for the assistant's turn
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Process inputs
        inputs = processor(text=[prompt], images=[image], return_tensors="pt", padding=True).to(device)

        # Generate
        # Set pad_token_id to eos_token_id for open-ended generation
        generation_args = {
            "max_new_tokens": max_new_tokens,
            "pad_token_id": processor.tokenizer.eos_token_id,
            "eos_token_id": processor.tokenizer.eos_token_id,
            "do_sample": False, # Greedy decoding for deterministic eval
            "num_beams": 1,
        }

        generated_ids = model.generate(**inputs, **generation_args)

        # Decode only generated tokens
        input_len = inputs['input_ids'].shape[1]
        generated_texts = processor.batch_decode(generated_ids[:, input_len:], skip_special_tokens=True)

        # Basic cleanup
        generated_text = generated_texts[0].strip()
        return generated_text

    except Exception as e:
        print(f"Error during model inference for question '{question[:50]}...': {e}")
        # import traceback
        # traceback.print_exc() # Uncomment for detailed traceback
        return "Error: Inference failed"

print("\nEvaluation setup complete.")