<a href="https://colab.research.google.com/github/tardigrade-dot/colab-script/blob/main/%F0%9F%92%A7_LFM2_VL_SFT_with_TRL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 💧 LFM2-VL - SFT with TRL

This tutorial demonstrates how to fine-tune our LFM2 vision models [`LiquidAI/LFM2-VL-3B`](https://huggingface.co/LiquidAI/LFM2-VL-3B), [`LiquidAI/LFM2-VL-1.6B`](https://huggingface.co/LiquidAI/LFM2-VL-1.6B) and [`LiquidAI/LFM2-VL-450M`](https://huggingface.co/LiquidAI/LFM2-VL-450M) using the TRL library.

Follow along if it's your first time using TRL, or take individual code snippets for your own workflow.

## 🎯 What You'll Find:
- **SFT** (Supervised Fine-Tuning) - Basic instruction following
- **LoRA + SFT** - (Optional) Use LoRA (from PEFT) to run SFT on constrained hardware

## 📋 Prerequisites:
- **GPU Runtime**: Select GPU in `Runtime` → `Change runtime type`
- **Hugging Face Account**: For accessing models and datasets

# 📦 Installation & Setup

First, let's install all the required packages:


In [None]:
!pip install -qqq git+https://github.com/huggingface/transformers.git@93671b4444414b01ea034bd64614856644297a66 datasets trl peft --progress-bar off

  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


Let's now verify that the packages are installed correctly:


In [None]:
import torch
import transformers
import trl
import os
# Disable Weights & Biases logging for this tutorial
os.environ["WANDB_DISABLED"] = "true"

print(f"📦 PyTorch version: {torch.__version__}")
print(f"🤗 Transformers version: {transformers.__version__}")
print(f"📊 TRL version: {trl.__version__}")

📦 PyTorch version: 2.8.0+cu126
🤗 Transformers version: 5.0.0.dev0
📊 TRL version: 0.24.0


# Loading the model from Transformers 🤗


In [None]:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "LiquidAI/LFM2-VL-3B"  # or LiquidAI/LFM2-VL-1.6B | LiquidAI/LFM2-VL-450M

print("📚 Loading processor...")
processor = AutoProcessor.from_pretrained(
    model_id,
    max_image_tokens=256,
)

print("🧠 Loading model...")
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    dtype="bfloat16",  # `dtype` replaces the deprecated `torch_dtype` argument
    device_map="auto",
)

print("\n✅ Local model loaded successfully!")
print(f"📖 Vocab size: {len(processor.tokenizer)}")
print(f"🔢 Parameters: {model.num_parameters():,}")
# bfloat16 stores 2 bytes per parameter
print(f"💾 Model size: ~{model.num_parameters() * 2 / 1e9:.1f} GB (bfloat16)")

📚 Loading processor...


🧠 Loading model...



✅ Local model loaded successfully!
📖 Vocab size: 64400
🔢 Parameters: 2,998,975,216
💾 Model size: ~6.0 GB (bfloat16)


# 🎯 Supervised Fine-Tuning (SFT + LoRA)

SFT teaches the model to follow instructions by training on input-output pairs (instruction and response). This is the foundation for creating instruction-following models.
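The collate function used later in this notebook masks only padding, so the loss also covers the system prompt and question tokens. A common SFT variant computes the loss on the response alone. A minimal plain-Python sketch, assuming the prompt occupies the first `prompt_len` tokens of each sequence (the helper `mask_prompt_tokens` and the toy token ids are hypothetical):

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's CrossEntropyLoss by default


def mask_prompt_tokens(input_ids, prompt_len, pad_token_id):
    """Copy input_ids into labels, ignoring prompt and padding positions.

    Hypothetical helper: assumes the first `prompt_len` tokens are the
    instruction (system + user turns) and the remaining tokens are the response.
    """
    labels = list(input_ids)
    for i, tok in enumerate(labels):
        if i < prompt_len or tok == pad_token_id:
            labels[i] = IGNORE_INDEX
    return labels


# Toy example: 4 prompt tokens, 3 response tokens, 2 padding tokens (pad id 0)
ids = [11, 12, 13, 14, 21, 22, 23, 0, 0]
labels = mask_prompt_tokens(ids, prompt_len=4, pad_token_id=0)
print(labels)  # [-100, -100, -100, -100, 21, 22, 23, -100, -100]
```

Only the three response tokens keep their ids, so only they contribute to the training loss.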

Full SFT can be too compute-heavy for the free-tier Colab GPUs. We therefore use LoRA (Low-Rank Adaptation), which fine-tunes the model by training only a small number of additional parameters while the original weights stay frozen, making it well suited to limited compute resources.
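A rough back-of-the-envelope sketch of why full fine-tuning is heavy, assuming plain bf16 training with AdamW: 2 bytes per weight, plus 2 bytes of gradient and 8 bytes of fp32 Adam state (two moments) per *trainable* parameter. These byte counts are simplifying assumptions, and the estimate ignores activations, the CUDA context, and the 8-bit optimizer used later in this notebook, so treat the numbers as orders of magnitude only. The parameter counts are the ones printed in this notebook's outputs.

```python
def training_memory_gb(total_params, trainable_params,
                       weight_bytes=2, grad_bytes=2, optim_bytes=8):
    """Rough training memory in GB (1 GB = 1e9 bytes), excluding activations.

    Assumed byte counts: bf16 weights and gradients, two fp32 AdamW moments.
    Frozen parameters only need their weights; trainable ones also need
    gradients and optimizer state.
    """
    weights = total_params * weight_bytes
    extra = trainable_params * (grad_bytes + optim_bytes)
    return (weights + extra) / 1e9


base_params = 2_998_975_216  # LFM2-VL-3B parameter count from the model-loading output
lora_params = 3_778_304      # trainable LoRA parameters with r=8 from print_trainable_parameters()

full_ft = training_memory_gb(base_params, base_params)
lora_ft = training_memory_gb(base_params + lora_params, lora_params)
print(f"Full fine-tuning: ~{full_ft:.1f} GB")  # Full fine-tuning: ~36.0 GB
print(f"LoRA fine-tuning: ~{lora_ft:.1f} GB")  # LoRA fine-tuning: ~6.0 GB
```

Under these assumptions, LoRA's memory is dominated by the frozen bf16 weights, while full fine-tuning multiplies the cost by the gradient and optimizer state of every parameter.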

## Load an SFT Dataset

We will use [simwit/omni-med-vqa-mini](https://huggingface.co/datasets/simwit/omni-med-vqa-mini), a small medical visual question answering (VQA) dataset.

In [None]:
from datasets import load_dataset

raw_ds = load_dataset("simwit/omni-med-vqa-mini")
# The dataset ships only a "test" split, so we carve an 80/20 train/eval split from it
full_dataset = raw_ds["test"]
split = full_dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]

print("✅ SFT Dataset loaded:")
print(f"   📚 Train samples: {len(train_dataset)}")
print(f"   🧪 Eval samples: {len(eval_dataset)}")
print(f"\n📝 Single Sample: [IMAGE] {train_dataset[0]['question']} {train_dataset[0]['gt_answer']}")

✅ SFT Dataset loaded:
   📚 Train samples: 1600
   🧪 Eval samples: 400

📝 Single Sample: [IMAGE] What content appears in this image? Lungs


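Next, we convert each sample into the conversational format that `trl` expects: a system prompt, a user turn containing the image and the question, and an assistant turn containing the reference answer.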



In [None]:
system_message = (
    "You are a medical Vision Language Model specialized in analyzing medical images and providing clinical insights. "
    "Provide concise, clinically relevant answers based on the image and question."
)

def format_medical_sample(sample):
    return [
        {"role": "system", "content": [{"type": "text", "text": system_message}]},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": sample["image"]},
                {"type": "text", "text": sample["question"]},
            ],
        },
        {"role": "assistant", "content": [{"type": "text", "text": sample["gt_answer"]}]},
    ]

train_dataset = [format_medical_sample(s) for s in train_dataset]
eval_dataset = [format_medical_sample(s) for s in eval_dataset]

print("✅ SFT Dataset formatted:")
print(f"   📚 Train samples: {len(train_dataset)}")
print(f"   🧪 Eval samples: {len(eval_dataset)}")

✅ SFT Dataset formatted:
   📚 Train samples: 1600
   🧪 Eval samples: 400


## Collate function
Next, we create a collate function that uses the processor to turn a batch of conversations (text and images) into model-ready tensors. Padding positions in the labels are set to `-100` so the loss ignores them.


In [None]:
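def create_collate_fn(processor):
    """Create a collate function that turns a batch of conversations into model inputs."""
    def collate_fn(examples):
        # Render the chat template, tokenize the text, and preprocess the images in one call;
        # padding=True pads sequences of different lengths to a common length
        batch = processor.apply_chat_template(
            examples,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            padding=True,
        )
        # Train on every token except padding (-100 is ignored by the loss)
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
    return collate_fn

collate_fn = create_collate_fn(processor)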
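When a batch holds several conversations of different lengths, they must be padded to a common length before `input_ids` are copied into `labels` and the padding is masked. A plain-Python sketch of that padding-and-labelling step, assuming right padding and a pad id of `0` (the real pad id and padding side come from `processor.tokenizer`; the helper `pad_and_label` is hypothetical):

```python
IGNORE_INDEX = -100


def pad_and_label(batch_ids, pad_token_id):
    """Right-pad token id lists to a common length; build attention masks and labels."""
    max_len = max(len(ids) for ids in batch_ids)
    input_ids, attention_mask, labels = [], [], []
    for ids in batch_ids:
        n_pad = max_len - len(ids)
        padded = ids + [pad_token_id] * n_pad
        input_ids.append(padded)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
        # Like collate_fn: copy input_ids and hide padding from the loss
        labels.append([IGNORE_INDEX if tok == pad_token_id else tok for tok in padded])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = pad_and_label([[5, 6, 7], [8, 9]], pad_token_id=0)
print(batch["input_ids"])       # [[5, 6, 7], [8, 9, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
print(batch["labels"])          # [[5, 6, 7], [8, 9, -100]]
```

Masking by token value, as `collate_fn` does, also hides any real token that shares the pad id (for example when the pad and EOS tokens are the same); masking positions where `attention_mask == 0` avoids that.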

## Wrap the model with PEFT (Optional)

We specify the target modules that receive LoRA adapters, while the rest of the model's weights remain frozen. Feel free to modify the `r` (rank) value:
- Higher → closer approximation of full fine-tuning, at the cost of more trainable parameters
- Lower → fewer trainable parameters and even lower compute requirements

You can skip this part if you have a premium GPU and want to run a full fine-tune.
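To see what `r` controls, here is a minimal plain-Python sketch of a LoRA-adapted linear layer, following the standard LoRA formulation y = Wx + (alpha / r) · B(Ax), with a frozen `W` (d_out × d_in) and trainable `A` (r × d_in) and `B` (d_out × r). The layer sizes are made up for illustration and are not LFM2-VL's actual dimensions. `B` starts at zero, as in PEFT's default initialization, so the adapter initially leaves the layer's output unchanged.

```python
import random


def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """y = W x + (alpha / r) * B (A x); W is frozen, A and B are trained."""
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * d for b, d in zip(base, delta)]


def lora_param_count(d_in, d_out, r):
    """Trainable parameters added to one linear layer: A (r x d_in) + B (d_out x r)."""
    return r * (d_in + d_out)


d_in, d_out, r, alpha = 4, 3, 2, 16
random.seed(0)
W = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]  # zero init: the adapter starts as a no-op
x = [1.0, 2.0, 3.0, 4.0]
assert lora_forward(W, A, B, x, alpha, r) == matvec(W, x)

# Trainable parameters for a hypothetical 2048 x 2048 projection at several ranks
for rank in (4, 8, 16, 64):
    print(rank, lora_param_count(2048, 2048, rank), "vs full:", 2048 * 2048)
# r=8 adds 32768 trainable parameters next to 4194304 frozen ones (~0.8%)
```

The trainable count grows linearly with `r`, which is why lowering the rank saves memory. With this notebook's `lora_alpha=16` and `r=8`, the low-rank update is scaled by 16 / 8 = 2.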

In [None]:
from peft import LoraConfig, get_peft_model

target_modules = [
    "q_proj", "v_proj", "fc1", "fc2", "linear",
    "gate_proj", "up_proj", "down_proj",
]

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules=target_modules,
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 3,778,304 || all params: 3,002,753,520 || trainable%: 0.1258


## Launch Training

We are now ready to launch an SFT run with `SFTTrainer`. Feel free to modify `SFTConfig` to experiment with different configurations.
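The number of optimizer steps follows from simple arithmetic on the `SFTConfig` values. A sketch assuming a single GPU and the 1,600 training samples printed earlier; the helper `schedule` is hypothetical, and exact step rounding can differ slightly between library versions:

```python
import math


def schedule(num_samples, per_device_batch, grad_accum, num_devices=1,
             epochs=1, warmup_ratio=0.0):
    """Effective batch size, total optimizer steps, and warmup steps for a run."""
    effective_batch = per_device_batch * grad_accum * num_devices
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps


# per_device_train_batch_size=1, gradient_accumulation_steps=16,
# num_train_epochs=1, warmup_ratio=0.1
eff, total, warmup = schedule(1600, per_device_batch=1, grad_accum=16,
                              epochs=1, warmup_ratio=0.1)
print(eff, total, warmup)  # 16 100 10
```

With `logging_steps=10`, this run would log the training loss roughly ten times.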


In [None]:
from trl import SFTConfig, SFTTrainer

sft_config = SFTConfig(
    output_dir="lfm2-vl-med",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=5e-4,
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=10,
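    optim="adamw_torch_8bit",  # 8-bit AdamW optimizer states to save memory
    gradient_checkpointing=True,  # recompute activations during backward to save memory
    max_length=512,
    dataset_kwargs={"skip_prepare_dataset": True},  # collate_fn handles tokenization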
)

print("🏗️  Creating SFT trainer...")
sft_trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_dataset,
    # Evaluation only runs during training if `eval_strategy` is set in SFTConfig;
    # otherwise call sft_trainer.evaluate() after training
    eval_dataset=eval_dataset,
    data_collator=collate_fn,
    processing_class=processor.tokenizer,
)

print("\n🚀 Starting SFT training...")
sft_trainer.train()

print("🎉 SFT training completed!")

print(f"💾 Saving to: {sft_config.output_dir}")
sft_trainer.save_model()

The model is already on multiple devices. Skipping the move to device specified in `args`.


🏗️  Creating SFT trainer...

🚀 Starting SFT training...






## (Optional) Save merged model

If you used LoRA, merge the learned adapter weights back into the base model to obtain a standard model checkpoint.
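Merging folds the low-rank update into the base weight, W_merged = W + (alpha / r) · BA, so the merged layer produces the same outputs as the base layer plus adapter, without the extra matrices at inference time. A small plain-Python check of that identity with made-up matrices (the helper names are illustrative, not PEFT's API):

```python
import random


def matmul(X, Y):
    """Multiply matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def matvec(M, x):
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def merge_lora(W, A, B, alpha, r):
    """Return W + (alpha / r) * (B @ A), the weight a merged layer would use."""
    delta = matmul(B, A)
    scale = alpha / r
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


random.seed(42)
d_in, d_out, r, alpha = 5, 3, 2, 16
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(r)]
B = [[random.gauss(0, 1) for _ in range(r)] for _ in range(d_out)]  # a "trained" adapter
x = [random.gauss(0, 1) for _ in range(d_in)]

# Adapter path: base output plus scaled low-rank correction
unmerged = [w + (alpha / r) * d for w, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]
# Merged path: a single matrix-vector product with the folded weight
merged = matvec(merge_lora(W, A, B, alpha, r), x)
assert all(abs(u - m) < 1e-9 for u, m in zip(unmerged, merged))
print("max difference:", max(abs(u - m) for u, m in zip(unmerged, merged)))
```

In exact arithmetic the two paths are identical; in bf16 the merged and unmerged outputs can differ slightly because of rounding.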

In [None]:
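if hasattr(model, "peft_config"):
    print("🔄 Merging LoRA weights...")
    model = model.merge_and_unload()

# Save to a separate directory: ./lfm2-vl-med already holds the LoRA adapter
# written by sft_trainer.save_model(), and its adapter_config.json could make
# from_pretrained treat the folder as an adapter instead of a merged model
merged_dir = "./lfm2-vl-med-merged"
model.save_pretrained(merged_dir)
processor.save_pretrained(merged_dir)
print(f"💾 Model saved to: {merged_dir}")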