# ShareGPT4V - LoRA Fine-tuning

1. Set up the required libraries
2. Prepare the model and dataset
3. Configure LoRA
4. Fine-tune the model
5. Save the LoRA weights
6. Test the fine-tuned model

## 1. Setup and Requirements

In [None]:
import os
import torch
# import transformers
from transformers import AutoTokenizer, TrainingArguments
from peft import LoraConfig, get_peft_model, PeftModel

from share4v.model import Share4VLlamaForCausalLM
from share4v.constants import DEFAULT_IMAGE_TOKEN
from share4v.mm_utils import tokenizer_image_token
from share4v.model.builder import load_pretrained_model

from share4v.train.train import get_peft_state_maybe_zero_3, get_peft_state_non_lora_maybe_zero_3

# Basic runtime configuration: use the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device: {}".format(device))

## 2. Load Base Model

In [None]:
# Load the pretrained ShareGPT4V-7B checkpoint along with its tokenizer,
# image processor and context length.
model_path = "Lin-Chen/ShareGPT4V-7B"
model_name = "share4v-7b"

# Positional arguments after model_path: no separate base checkpoint (None),
# the model name, then two False flags -- presumably the 8-bit / 4-bit loading
# switches of the LLaVA-style builder; TODO confirm in share4v.model.builder.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
    False,
    False,
)

# Freeze every base-model parameter; the LoRA adapters are added later.
model.requires_grad_(False)

## 3. Prepare Dataset

In [None]:
from share4v.train.train import (
    LazySupervisedDataset,
    DataArguments,
    DataCollatorForSupervisedDataset,
)

# Data / preprocessing configuration consumed by LazySupervisedDataset.
# The image root is machine-specific: set SHARE4V_IMAGE_FOLDER to point at your
# local copy. Without it, the original path is used unchanged.
data_args = DataArguments(
    data_path="./data/example_training.json",
    lazy_preprocess=False,  # typically True for full training runs
    is_multimodal=True,
    image_folder=os.environ.get(
        "SHARE4V_IMAGE_FOLDER", "/home/justas/ShareGPT4V/data"
    ),
    image_aspect_ratio="square",
)
# The dataset below receives `data_args`, so attach the image processor that
# load_pretrained_model returned.
data_args.image_processor = image_processor

In [None]:
# Build the supervised dataset from the JSON annotations named in data_args.
dataset = LazySupervisedDataset(
    tokenizer=tokenizer,
    data_path=data_args.data_path,
    data_args=data_args,
)
print("Dataset created with {} examples".format(len(dataset)))

### Inspect one training example

In [None]:
# Pull the first sample and report which fields it has plus their shapes.
example = dataset[0]
print("\nExample data:")
print("-" * 50)
print(f"Keys in example: {list(example.keys())}")
for _field, _caption in (
    ("input_ids", "Input IDs shape"),
    ("labels", "Labels shape"),
    ("image", "Image tensor shape"),
):
    # Only fields actually present in the sample are reported.
    if _field in example:
        print(f"{_caption}: {example[_field].shape}")
print("-" * 50)

In [None]:
# example of 1 dataset entry with image and the conversation

In [None]:
# Use the original DataCollatorForSupervisedDataset from train.py.
# It is given the tokenizer so it can batch tokenized examples
# (presumably padding input_ids/labels -- TODO confirm in share4v.train.train).
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

## 4. Configure LoRA

In [None]:
# Select the linear layers that LoRA should adapt.
# Copied from train.py.

def find_all_linear_names(model):
    """Return the short names of the ``torch.nn.Linear`` layers to target with LoRA.

    Any module whose qualified name contains ``mm_projector``, ``vision_tower``
    or ``vision_resampler`` is skipped, and ``lm_head`` is removed from the
    result. Each entry is the last dotted component of a module name, e.g.
    ``"q_proj"`` for ``"model.layers.0.self_attn.q_proj"``.
    """
    excluded = ('mm_projector', 'vision_tower', 'vision_resampler')

    short_names = {
        qualified.split('.')[-1]
        for qualified, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
        and not any(word in qualified for word in excluded)
    }

    # needed for 16-bit
    short_names.discard('lm_head')

    return list(short_names)

In [None]:
def prepare_model_for_lora(
    model,
    lora_r=64,            # LoRA rank: higher means more trainable parameters / capacity
    lora_alpha=16,        # LoRA scaling factor (see PEFT LoraConfig docs)
    lora_dropout=0.05,    # dropout applied inside the LoRA branch
    bias="none",          # which biases to train: "none", "all" or "lora_only"
    target_modules=None,  # module-name suffixes to adapt; None -> find_all_linear_names(model)
    task_type="CAUSAL_LM" # PEFT task type passed to LoraConfig
):
    """Wrap ``model`` with LoRA adapters and return the PEFT model.

    The hyper-parameters are forwarded to ``peft.LoraConfig``. When
    ``target_modules`` is None, every eligible linear layer reported by
    ``find_all_linear_names`` is adapted.

    NOTE(review): PEFT's ``get_peft_model`` injects the adapter layers into
    ``model`` itself, so calling this repeatedly on the same model stacks
    adapters -- unload between calls if that is not intended.
    """
    if target_modules is not None:
        chosen_modules = target_modules
    else:
        chosen_modules = find_all_linear_names(model)

    config = LoraConfig(
        task_type=task_type,
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias=bias,
        target_modules=chosen_modules,
    )
    return get_peft_model(model, config)

def print_trainable_parameters(model):
    """Print and return the trainable / total parameter counts of ``model``.

    Returns:
        tuple[int, int]: ``(trainable_params, all_params)``. Callers (such as
        the configuration-comparison cell) unpack this pair, so it must be
        returned rather than only printed.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # A module without parameters would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}%"
    )
    return trainable_params, all_param

In [None]:
# Compare how many parameters several LoRA configurations would train.
#
# get_peft_model injects LoRA layers into `model` in place. Each variant is
# therefore unloaded again after it is measured. Without that, every call would
# stack new adapters on the previous ones, and find_all_linear_names would start
# matching the LoRA sub-layers. That corrupts both these counts and the model
# that is trained later.
config_results = []

lora_variants = [
    ("Default (r=64, alpha=16)", {"lora_r": 64, "lora_alpha": 16}),
    ("Low rank (r=16, alpha=32)", {"lora_r": 16, "lora_alpha": 32}),
    (
        "Attention only (r=64, alpha=16)",
        {"lora_r": 64, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]},
    ),
    (
        "Attention only (r=16, alpha=16)",
        {"lora_r": 16, "lora_alpha": 16, "target_modules": ["q_proj", "v_proj"]},
    ),
]


def _count_parameters(module):
    """Return ``(trainable, total)`` parameter counts of ``module``."""
    trainable = total = 0
    for param in module.parameters():
        total += param.numel()
        if param.requires_grad:
            trainable += param.numel()
    return trainable, total


for variant_name, lora_kwargs in lora_variants:
    peft_model = prepare_model_for_lora(model, **lora_kwargs)
    print_trainable_parameters(peft_model)
    trainable, total = _count_parameters(peft_model)
    config_results.append({"name": variant_name, "trainable": trainable, "total": total})

    # Remove the LoRA layers (without merging) to get the clean base model back.
    model = peft_model.unload()
    # PEFT may leave its config dict on the base model; drop it so the next
    # get_peft_model call does not reuse the previous configuration.
    if hasattr(model, "peft_config"):
        del model.peft_config
    del peft_model
    print("\n")


# Print comparison table
print("## Configuration Comparison ##")
print("-" * 66)
print(f"{'Configuration':<40} {'Trainable':<12} {'% of Model':<12}")
print("-" * 66)
for config in config_results:
    pct = config["trainable"] / config["total"] * 100 if config["total"] else 0.0
    print(f"{config['name']:<40} {config['trainable']:<12,d} {pct:.4f}%")

## 5. Training Setup

In [None]:
# Train with configuration 4 from the comparison: LoRA on the attention
# query/value projections only, rank 16, alpha 16.
target_modules = ["q_proj", "v_proj"]
model = prepare_model_for_lora(
    model,
    lora_r=16,
    lora_alpha=16,
    target_modules=target_modules,
)

In [None]:
from share4v.train.share4v_trainer import Share4VTrainer
# transformers.TrainingArguments has no `lora_enable` / `group_by_modality_length`
# fields, so constructing it with them raises TypeError. share4v's train module
# (a LLaVA fork) defines a TrainingArguments subclass that adds these fields,
# and Share4VTrainer reads them. TODO confirm the subclass name if upstream changes.
from share4v.train.train import TrainingArguments as Share4VTrainingArguments

# Define training arguments
training_args = Share4VTrainingArguments(
    output_dir="./lora_share4v_output",  # Output directory
    num_train_epochs=3,                  # Number of training epochs
    per_device_train_batch_size=4,       # Batch size per device
    gradient_accumulation_steps=4,       # Number of update steps to accumulate gradients for
    learning_rate=2e-5,                  # Learning rate
    weight_decay=0.01,                   # Weight decay
    save_steps=5,                        # Save a checkpoint every 5 steps
    save_total_limit=3,                  # Keep only the 3 most recent checkpoints
    report_to="none",                    # Disable reporting (None would enable all installed integrations)
    remove_unused_columns=False,         # Keep all columns
    log_level="info",                    # Logging level
    logging_steps=10,                    # Log every 10 steps
    fp16=True,                           # Use mixed precision
    lora_enable=True,                    # Enable LoRA training
    group_by_modality_length=False,      # Don't group by modality length
)

trainer = Share4VTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)

## 6. Train

Now we'll start the training process and save the LoRA weights after training.

In [None]:
# Run the training loop configured by `training_args` in the previous cell.
print("Starting training...")
trainer.train()
print("Training completed!")

In [None]:
# Save LoRA weights into a standalone directory (separate from the trainer's
# checkpoint directory).
output_dir = "share4v_lora_weights"
os.makedirs(output_dir, exist_ok=True)

# Extract and save LoRA state dict.
# named_parameters() returns a one-shot generator, so it is called afresh for
# each helper below. bias="none" matches the LoRA config used for training.
lora_state_dict = get_peft_state_maybe_zero_3(
    model.named_parameters(), bias="none"
)

# Extract non-LoRA trainable weights (like special token embeddings).
# NOTE(review): presumably empty here, since the base model was frozen and only
# LoRA adapters were trainable -- confirm against the helper in share4v.train.train.
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
    model.named_parameters()
)

# Save the model configuration and weights: the model config, the adapter
# weights via save_pretrained with the extracted state dict, and the non-LoRA
# tensors as non_lora_trainables.bin.
model.config.save_pretrained(output_dir)
model.save_pretrained(output_dir, state_dict=lora_state_dict)
torch.save(non_lora_state_dict, os.path.join(output_dir, 'non_lora_trainables.bin'))

print(f"LoRA weights saved to {output_dir}")