1. set up the environment

In [None]:
# https://huggingface.co/docs/transformers/en/model_doc/llava


from transformers import AutoProcessor, AutoTokenizer, LlavaOnevisionForConditionalGeneration
import torch
import os

# Restrict the CUDA devices visible to this process to device 0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Identifier of the LLaVA-OneVision checkpoint to fine-tune
model_id = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
# Load the processor that belongs to the checkpoint
processor = AutoProcessor.from_pretrained(model_id)

2. Create and prepare the dataset

Once you have determined that fine-tuning is the right solution, you need to create a dataset to fine-tune the model on. The dataset has to be prepared in a format that the model can understand.

TRL supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats, and TRL will take care of the rest.
```
"messages": [
  {"role": "system", "content": [{"type": "text", "text": "You are a helpful...."}]},
  {"role": "user", "content": [
    {"type": "text", "text": "How many dogs are in the image?"},
    {"type": "image", "image": <PIL.Image>}
  ]},
  {"role": "assistant", "content": [{"type": "text", "text": "There are 3 dogs in the image."}]}
],
```
In our example we are going to load our dataset using the Datasets library, apply our prompt, and convert it into the conversational format.

Let's start!

In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import io

def format_data(sample):
    """Convert one raw dataset record into a chat-style ``messages`` dict.

    Every entry of ``sample["image"]`` (encoded image bytes) is opened as a
    PIL image and placed, in order, in the user turn ahead of the text
    instruction. The first entry of ``sample["summaries"]`` becomes the
    assistant reply.

    Returns:
        ``{"messages": [system_turn, user_turn, assistant_turn]}``.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": "You are a professional academic paper review assistant."}],
    }

    # One image part per page, followed by the single text instruction.
    user_parts = [
        {"type": "image", "image": Image.open(io.BytesIO(raw_bytes))}
        for raw_bytes in sample["image"]
    ]
    user_parts.append(
        {"type": "text", "text": "Please help me on reviewing this paper by given those images"}
    )
    user_turn = {"role": "user", "content": user_parts}

    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["summaries"][0]}],
    }

    return {"messages": [system_turn, user_turn, assistant_turn]}


from datasets import load_dataset

# Load the "train" split of the DetionDX/neurips_openreview_v1 dataset
dataset = load_dataset("DetionDX/neurips_openreview_v1", split="train")

# Print the dataset object to inspect what was loaded
print(dataset)

def display_example(example):
    """Print a short overview of one raw record and show its first image.

    Reads the ``id``, ``page_number``, ``image`` and ``summaries`` fields of
    ``example``; the first ``image`` entry is decoded from bytes and rendered
    with matplotlib.
    """
    print(f"ID: {example['id']}")
    print(f"Page numbers: {example['page_number']}")
    print(f"Number of images: {len(example['image'])}")
    print(f"Number of summaries: {len(example['summaries'])}")
    print(f"First summary: {example['summaries'][0]}")
    print("First image:")

    # Decode the first page from raw bytes into a PIL image
    first_page = Image.open(io.BytesIO(example['image'][0]))

    # Render it on a square canvas without axes
    plt.figure(figsize=(10, 10))
    plt.imshow(first_page)
    plt.axis('off')
    plt.show()

# Preview the first raw record before any conversion
display_example(dataset[0])

# Convert every record to the chat "messages" format; the result is a plain list
dataset = list(map(format_data, dataset))


3. Fine-tune VLM using trl and the SFTTrainer

In [None]:
# https://huggingface.co/docs/transformers/en/model_doc/llava

from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration, BitsAndBytesConfig
import torch

model_id = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
from transformers import BitsAndBytesConfig

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model



# Adapter switches: USE_LORA wraps the model with LoRA adapters,
# USE_QLORA additionally loads the base weights 4-bit quantized.
USE_LORA = True
USE_QLORA = False

# 4-bit NF4 quantization settings (only applied when USE_QLORA is True).
# The compute dtype is bfloat16 so that it agrees with the dtype the model is
# loaded in below and with the bf16 training precision used in this notebook.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
)

# Load the base model in bfloat16 and let `device_map="auto"` place it.
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config if USE_QLORA else None,
)

def find_all_linear_names(model):
    """Return the leaf names of the ``torch.nn.Linear`` layers to target with LoRA.

    Modules whose qualified name contains ``multi_modal_projector`` or
    ``vision_model`` are skipped, and ``lm_head`` is always excluded.

    Args:
        model: A module exposing ``named_modules()``.

    Returns:
        A sorted list of unique leaf names (the last dotted component of each
        matching module name). Sorting keeps the result identical between runs.
    """
    multimodal_keywords = ('multi_modal_projector', 'vision_model')
    # NOTE(review): only leaf names are returned, so a consumer that matches
    # modules by name suffix may also match same-named layers inside the
    # skipped submodules -- confirm against how target_modules is resolved.
    lora_module_names = {
        name.split('.')[-1]
        for name, module in model.named_modules()
        if isinstance(module, torch.nn.Linear)
        and not any(keyword in name for keyword in multimodal_keywords)
    }
    lora_module_names.discard('lm_head')  # needed for 16-bit
    return sorted(lora_module_names)

# Default to "no adapter" so that later cells can pass `peft_config=lora_config`
# without a NameError when USE_LORA is switched off.
lora_config = None

if USE_LORA:
    # Rank-8 LoRA on every linear layer reported by find_all_linear_names
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.05,
        target_modules=find_all_linear_names(model),
        init_lora_weights="gaussian",
    )
    if USE_QLORA:
        # Prepare the quantized model for training before adding adapters
        model = prepare_model_for_kbit_training(model)

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()


# Processor (tokenizer + image processor) matching the checkpoint
processor = AutoProcessor.from_pretrained(model_id)

In [4]:
# Create a data collator to encode text and image pairs
from transformers import Qwen2VLProcessor
from qwen_vl_utils import process_vision_info

def collate_fn(examples):
    """Turn a list of ``{"messages": ...}`` examples into one training batch.

    The chat template is applied to each conversation, the images are pulled
    out of the messages, and both are passed through the module-level
    ``processor``. ``labels`` is a copy of ``input_ids`` in which padding and
    image placeholder tokens are replaced by -100.
    """
    conversations = [example["messages"] for example in examples]

    # Render the prompts first, then collect the images of each conversation
    prompts = [
        processor.apply_chat_template(conversation, tokenize=False)
        for conversation in conversations
    ]
    images = [process_vision_info(conversation)[0] for conversation in conversations]

    # Tokenize the texts and process the images
    batch = processor(text=prompts, images=images, return_tensors="pt", padding=True)

    # Image placeholder ids are model specific
    if isinstance(processor, Qwen2VLProcessor):
        image_token_ids = [151652, 151653, 151655]
    else:
        image_token_ids = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]

    # Start from the input ids and blank out padding, then image tokens
    labels = batch["input_ids"].clone()
    for masked_id in [processor.tokenizer.pad_token_id, *image_token_ids]:
        labels[labels == masked_id] = -100
    batch["labels"] = labels

    return batch

In [None]:
# Run the collator on a single-example batch to inspect its output
collate_fn([dataset[0]])

In [None]:
from trl import (
    ModelConfig,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

# from transformers import TrainingArguments

# Training configuration for supervised fine-tuning
args = SFTConfig(
    output_dir="llava-onevision-qwen2-7b-ov-neurips-openreview-v1", # directory to save and repository id
    num_train_epochs=10,                     # number of training epochs
    per_device_train_batch_size=1,          # batch size per device during training
    gradient_accumulation_steps=8,          # number of steps before performing a backward/update pass
    gradient_checkpointing=True,            # use gradient checkpointing to save memory
    optim="adamw_hf",              # optimizer name passed to the trainer ("adamw_hf")
    logging_steps=5,                       # log every 5 steps
    save_strategy="epoch",                  # save checkpoint every epoch
    learning_rate=2e-4,                     # learning rate, based on QLoRA paper
    # fp16=False,
    bf16=True,                              # use bfloat16 precision
    tf32=True,                              # use tf32 precision
    max_grad_norm=0.3,                      # max gradient norm based on QLoRA paper
    warmup_ratio=0.1,                      # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",           # constant learning rate; NOTE(review): warmup_ratio presumably has no effect with this scheduler - confirm
    push_to_hub=False,                       # do not push the model to the hub
    report_to="tensorboard",                # report metrics to tensorboard
    gradient_checkpointing_kwargs = {"use_reentrant": False}, # use non-reentrant checkpointing
    dataset_text_field="", # need a dummy field for collator
    dataset_kwargs = {"skip_prepare_dataset": True} # important for collator
)
# Keep every column so the custom collator still receives the "messages" field
args.remove_unused_columns=False


# Trainer wired to the custom collator defined earlier
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    data_collator=collate_fn,
    dataset_text_field="", # needs dummy value
    peft_config=lora_config,
    tokenizer=processor.tokenizer,
)



In [None]:
# start training; push_to_hub=False in the SFTConfig, so nothing is uploaded to the hub here
trainer.train()

# save the final model to the output directory
trainer.save_model(args.output_dir)

In [8]:
# Drop the references to the model and the trainer, then release cached CUDA memory
del model, trainer
torch.cuda.empty_cache()