In [1]:
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 1. Configure 4-bit quantization: weights are loaded as 4-bit NF4, while the
#    actual computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the forward/backward math
    bnb_4bit_use_double_quant=False,  # double quantization explicitly disabled
)

In [None]:
# 2. Load the Llama-3.2-11B-Vision-Instruct checkpoint, quantized with `bnb_config`.
model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = AutoModelForVision2Seq.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let the loader place weights on the available device(s)
)
# The generation KV cache is not wanted while training.
model.config.use_cache = False

In [4]:
# 3. Load the processor for the same checkpoint; it provides both the chat
#    template / tokenizer and the image preprocessing used by the collator.
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path=model_name)

In [5]:
# 4. Load and preprocess the amazon-product-descriptions-vlm dataset
# Instruction template; `{product_name}` and `{category}` are filled per sample.
prompt = """Create a Short Product description based on the provided ##PRODUCT NAME## and ##CATEGORY## and image.
Only return description. The description should be SEO optimized and for a better mobile search experience.

##PRODUCT NAME##: {product_name}
##CATEGORY##: {category}"""

def format_data(sample):
    """Convert one raw dataset row into a two-turn chat example.

    Args:
        sample: mapping with the keys "Product Name", "Category", "image"
            and "description".

    Returns:
        ``{"messages": [user_turn, assistant_turn]}`` where the user turn holds
        the filled-in ``prompt`` text part followed by an image part, and the
        assistant turn holds the reference description as a single text part.
    """
    instruction = prompt.format(
        product_name=sample["Product Name"],
        category=sample["Category"],
    )
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": instruction},
            {"type": "image", "image": sample["image"]},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": sample["description"]}],
    }
    return {"messages": [user_turn, assistant_turn]}

# Load the training split and convert every row into chat format.
# NOTE(review): this materialises all formatted rows (presumably including
# decoded images) in memory at once.
dataset_id = "philschmid/amazon-product-descriptions-vlm"
dataset = list(map(format_data, load_dataset(dataset_id, split="train")))

In [6]:
# 5. Configure LoRA: rank-8 adapters on every module named `q_proj` / `v_proj`.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",  # bias parameters are not trained
)

In [7]:
# 6. Define SFT training arguments, grouped by concern.
sft_args = SFTConfig(
    # Output, checkpointing and logging
    output_dir="./llama-3.2-11b-vision-finetuned",
    save_strategy="epoch",
    logging_steps=100,
    report_to="tensorboard",
    # Batch schedule: 4 samples x 2 accumulation steps per optimizer step (per device)
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    # Optimizer and learning-rate schedule
    optim="paged_adamw_32bit",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    max_grad_norm=1.0,
    # Precision and memory
    bf16=True,
    fp16=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    use_liger=False,
    # Data handling: batches are built by the custom collator, so the raw
    # "messages" column must be kept and TRL's own dataset preparation skipped.
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
    # NOTE(review): with skip_prepare_dataset and a custom collator that does not
    # truncate, this limit is likely never applied — confirm.
    max_seq_length=512,
)

In [8]:
# 7. Define data collator
def collate_fn(examples):
    """Turn a list of chat-formatted examples into a model-ready training batch.

    Args:
        examples: list of dicts, each with a "messages" list of chat turns whose
            "content" parts are ``{"type": "text", ...}`` or
            ``{"type": "image", "image": ...}`` entries.

    Returns:
        The processor output for the whole batch, plus "labels": a copy of
        "input_ids" where padding and image-placeholder tokens are set to -100
        so the loss ignores them.
    """
    texts = [
        processor.apply_chat_template(example["messages"], tokenize=False)
        for example in examples
    ]
    # One inner list of images per sample. A flat list would be read by the
    # Mllama processor as a single sample holding every image in the batch,
    # which no longer matches the per-text image-token count once the batch
    # has more than one example.
    image_inputs = [
        [
            part["image"]
            for message in example["messages"]
            for part in message["content"]
            if isinstance(part, dict) and part.get("type") == "image"
        ]
        for example in examples
    ]

    batch = processor(
        text=texts,
        images=image_inputs,
        return_tensors="pt",
        # presumably the chat template already inserts the special tokens — confirm
        add_special_tokens=False,
        padding=True,
    )

    # Mask padding and image tokens in labels
    labels = batch["input_ids"].clone()
    pad_token_id = processor.tokenizer.pad_token_id
    if pad_token_id is not None:
        labels[labels == pad_token_id] = -100
    labels[labels == processor.image_token_id] = -100
    batch["labels"] = labels

    return batch

In [10]:
# 8. Initialize the Trainer: LoRA adapters come from `peft_config`, and batches
#    are produced by `collate_fn` instead of TRL's default collator.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=sft_args,
    train_dataset=dataset,
    data_collator=collate_fn,
    # NOTE(review): newer TRL releases expect `processing_class=` rather than
    # `tokenizer=`; confirm against the installed TRL version.
    tokenizer=processor.tokenizer,
)


In [None]:
# 9. Fine-tune the model
# Report the device, then train unconditionally: the save below must never
# write out (and report success for) an adapter that was not trained.
if torch.cuda.is_available():
    print("CUDA available. Starting training...")
else:
    # NOTE(review): the 4-bit bitsandbytes / paged optimizer setup may not
    # support CPU at all, in which case this path fails before or during training.
    print("CUDA not available. Training on CPU. This may take longer.")
trainer.train()

# 10. Save the fine-tuned model
# NOTE(review): this path differs from sft_args.output_dir
# ("./llama-3.2-11b-vision-finetuned") used for per-epoch checkpoints —
# confirm the "l2-" prefix is intended.
trainer.save_model("./l2-llama-3.2-11b-vision-finetuned")
print("Model saved successfully.")