In [1]:
from datasets import load_dataset
from PIL import Image
import os
from functools import partial

# Input locations: the directory that each record's `image_path` is resolved
# against, and the JSON file that is loaded as the dataset.
image_dir = "/home/linux/yyj/colpali/finetune/pdf2images"
json_data_path = "/home/linux/yyj/colpali/finetune/img_cap_pairs_t.json"

# Read the whole JSON file as a single "train" split; it is re-split further down.
ds = load_dataset(path="json", data_files=json_data_path, split="train")

def load_image(example, image_dir):
    """Attach the opened PIL image for one dataset row.

    Args:
        example: dict-like dataset row; its ``"image_path"`` value is joined
            onto ``image_dir`` to locate the file.
        image_dir: directory that ``image_path`` is resolved against.

    Returns:
        The same ``example`` object, with the result of ``Image.open`` stored
        under the ``"image"`` key.
    """
    example["image"] = Image.open(os.path.join(image_dir, example["image_path"]))
    return example

# Add an "image" column by running load_image over every row; the image
# directory is forwarded to it as a keyword argument.
ds = ds.map(load_image, fn_kwargs={"image_dir": image_dir})

# Hold out 20% of the rows for evaluation, using a fixed seed for the split.
ds = ds.train_test_split(test_size=0.2, seed=42)
train_ds, test_ds = ds["train"], ds["test"]


  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 1103 examples [00:00, 40758.71 examples/s]
Map: 100%|██████████| 1103/1103 [00:00<00:00, 9142.45 examples/s]


In [2]:
import torch
from transformers import Trainer, TrainingArguments, BitsAndBytesConfig, \
    ColPaliForRetrieval, ColPaliProcessor, EarlyStoppingCallback
from colpali_engine.loss import ColbertPairwiseCELoss
from colpali_engine.trainer.contrastive_trainer import ContrastiveTrainer
from peft import LoraConfig, get_peft_model

# Seed torch's RNG before the model and LoRA adapters are created.
torch.manual_seed(42)

model_dir = "/home/linux/yyj/colpali/finetune/colpali-v1.2-hf"

# 4-bit NF4 quantisation settings with bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Load the quantised base model onto the first GPU.
model = ColPaliForRetrieval.from_pretrained(
    model_dir,
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    device_map="cuda:0",
)
# NOTE(review): the base model is put in eval mode here although it is
# trained below; presumably Trainer switches modes itself -- confirm.
model = model.eval()

# LoRA adapter settings for the listed projection modules; the config is
# explicitly marked as not being in inference mode.
lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=['down_proj','o_proj','k_proj','q_proj','gate_proj','up_proj','v_proj'],
    init_lora_weights="gaussian",
    inference_mode=False,
)
model = get_peft_model(model, lora_config)

# The processor is loaded from the same directory as the base model.
processor = ColPaliProcessor.from_pretrained(model_dir)

def collate_fn(examples):
    """Turn a list of dataset rows into a ``(queries, images)`` batch pair.

    Args:
        examples: iterable of rows, each providing a ``"caption"`` value and
            an ``"image"`` object that supports ``.convert("RGB")``.

    Returns:
        A 2-tuple ``(batch_queries, batch_images)``: the processor output for
        the captions (padded/truncated to 512) and for the RGB images, each
        moved to ``model.device``.
    """
    captions = [row["caption"] for row in examples]
    rgb_images = [row["image"].convert("RGB") for row in examples]

    # Images are processed first, then the captions, as two separate calls.
    batch_images = processor(images=rgb_images, return_tensors="pt").to(model.device)
    batch_queries = processor(
        text=captions,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    ).to(model.device)
    return batch_queries, batch_images


class ContrastiveTrainer(Trainer):
    """Trainer that optimises a contrastive loss between query and document embeddings.

    Each batch is the ``(query_inputs, doc_inputs)`` pair produced by the data
    collator. Both halves are run through the same model and ``loss_func`` is
    applied to the two ``.embeddings`` outputs.

    NOTE: this definition shadows the ``ContrastiveTrainer`` imported from
    ``colpali_engine.trainer.contrastive_trainer`` earlier in this cell.
    """

    def __init__(self, loss_func, *args, **kwargs):
        """Store ``loss_func`` and forward every other argument to ``Trainer``.

        Args:
            loss_func: callable taking ``(query_embeddings, doc_embeddings)``
                and returning the loss.
            *args, **kwargs: passed unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func

    def _forward_pair(self, model, inputs):
        """Run both halves of a batch through ``model`` and score them.

        Returns:
            ``(loss, query_outputs, doc_outputs)``.
        """
        query_inputs, doc_inputs = inputs  # unpack from data collator
        query_outputs = model(**query_inputs)
        doc_outputs = model(**doc_inputs)
        loss = self.loss_func(query_outputs.embeddings, doc_outputs.embeddings)
        return loss, query_outputs, doc_outputs

    def compute_loss(self, model, inputs, num_items_in_batch=4, return_outputs=False):
        """Compute the contrastive loss for one ``(queries, docs)`` batch.

        Args:
            model: the model to call on each half of the batch.
            inputs: ``(query_inputs, doc_inputs)`` as built by the collator.
            num_items_in_batch: accepted but not used by this loss.
            return_outputs: when true, also return the raw model outputs.

        Returns:
            ``loss``, or ``(loss, (query_outputs, doc_outputs))`` when
            ``return_outputs`` is true.
        """
        loss, query_outputs, doc_outputs = self._forward_pair(model, inputs)
        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=True, ignore_keys=None):
        """Evaluate one batch without gradients and report only its loss.

        This trainer has no logits or labels to report, so the second and
        third elements of the result are always ``None``. The previous
        ``return loss, None, None if prediction_loss_only else loss`` parsed as
        ``(loss, None, (None if ... else loss))`` and therefore placed the
        loss in the labels slot whenever ``prediction_loss_only`` was false.

        Args:
            model: the model to evaluate.
            inputs: ``(query_inputs, doc_inputs)`` as built by the collator.
            prediction_loss_only: accepted for interface compatibility; the
                return value does not depend on it.
            ignore_keys: unused.

        Returns:
            ``(loss, None, None)``.
        """
        with torch.no_grad():
            loss, _, _ = self._forward_pair(model, inputs)
        return loss, None, None

# Training configuration. Evaluation runs every 10 steps and the best
# checkpoint (lowest eval loss) is meant to be restored when training ends.
training_args = TrainingArguments(
    output_dir="./colpali_city_0529",
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=10,
    # Save on the same schedule as evaluation. With the default save interval
    # no checkpoint is written during a run this short (the logged run ends at
    # step 100), so load_best_model_at_end had no best checkpoint to restore
    # and the last-step weights were the ones that got saved.
    save_strategy="steps",
    save_steps=10,
    warmup_steps=20,
    learning_rate=5e-5,
    save_total_limit=1,
    report_to="tensorboard",
    dataloader_pin_memory=False,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    # The collator reads the raw "caption" and "image" columns, so the
    # Trainer must not drop dataset columns.
    remove_unused_columns=False,
)


# Early stopping with a patience of 3 (the tracked metric is configured in
# the training arguments).
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = ContrastiveTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=test_ds,
    data_collator=collate_fn,
    loss_func=ColbertPairwiseCELoss(),
    callbacks=[early_stopping],
)

# Keep every dataset column: the collator reads "caption" and "image" itself.
trainer.args.remove_unused_columns = False


pynvml not found. GPU stats will not be printed.


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.91s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
No label_names provided for model class `PeftModel`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [3]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell result.
trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss
10,15.8281,1.261523
20,1.5475,0.391962
30,0.5994,0.2543
40,0.0834,0.243903
50,0.5017,0.160004
60,0.3195,0.140658
70,0.134,0.103096
80,0.3721,0.128103
90,0.1198,0.155246
100,0.1732,0.150843


TrainOutput(global_step=100, training_loss=1.9678859949111938, metrics={'train_runtime': 1554.1382, 'train_samples_per_second': 2.838, 'train_steps_per_second': 0.116, 'total_flos': 0.0, 'train_loss': 1.9678859949111938, 'epoch': 2.761904761904762})

In [4]:
# Write the trained model and the processor into the same directory.
# (The original note says the model save includes the model and tokenizer
# information -- NOTE(review): confirm what save_model writes for this setup.)
save_dir = "/home/linux/yyj/colpali/finetune/wiky_city_zh_0528"
trainer.save_model(save_dir)
# Save the preprocessor, i.e. tokenizer + image processor.
processor.save_pretrained(save_dir)


[]