### 自定义Trainer类，将idcor的计算也考虑到损失函数中

In [1]:
import torch
# Dummy tensor of shape (4, 3, 2048) sampled from a standard normal
# distribution; the following cells use it to try out the averaging steps.
# NOTE(review): the axes presumably mean (batch, images, hidden) - confirm.
image_features = torch.randn(4, 3, 2048)
# Bare expression: echoes the tensor and its shape as the cell output.
image_features, image_features.shape

(tensor([[[-0.5019,  0.9098, -0.3257,  ...,  1.2817,  0.0495, -1.5661],
          [-0.2381,  0.6739,  0.4183,  ..., -1.1270, -0.1086, -0.1199],
          [ 0.1480,  1.4203, -0.2630,  ...,  0.1286,  0.7086,  0.0828]],
 
         [[ 0.9037, -0.6607,  0.5651,  ..., -2.1822,  0.1593,  0.7980],
          [-1.1816, -0.9557,  1.1454,  ...,  0.7013,  0.0301,  1.1178],
          [ 1.6244,  0.2782, -0.4068,  ...,  0.7782,  0.1591, -0.9041]],
 
         [[ 2.9414,  0.9978, -0.6884,  ..., -0.7702,  1.2302, -0.4109],
          [-1.3048,  0.6791, -1.3369,  ...,  0.5799, -0.1845, -0.2855],
          [ 0.3539, -1.7124, -2.0652,  ...,  1.4186,  0.2590, -0.7188]],
 
         [[ 1.3619, -0.9040, -0.6858,  ..., -0.4627,  0.1842,  0.3898],
          [ 1.3845, -0.1301, -0.2596,  ..., -0.2234,  0.0809,  0.0944],
          [-0.2858, -2.5773, -1.1297,  ..., -0.4675, -2.7245,  0.7337]]]),
 torch.Size([4, 3, 2048]))

In [2]:
# Average over dim=1, collapsing (4, 3, 2048) into (4, 2048).
image_features_average = torch.mean(image_features, dim=1)
# Bare expression: echoes the resulting shape as the cell output.
image_features_average.shape

torch.Size([4, 2048])

In [6]:
# Re-check the shape of the source tensor (recorded output: (4, 3, 2048)).
image_features.shape

torch.Size([4, 3, 2048])

In [None]:
# Take the first averaged row and add back a leading axis: (2048,) -> (1, 2048).
single_average_feature = image_features_average[0].unsqueeze(0)
# Bare expression: echoes the tensor and its shape as the cell output.
single_average_feature, single_average_feature.shape

(tensor([[-0.1973,  1.0013, -0.0568,  ...,  0.0944,  0.2165, -0.5344]]),
 torch.Size([1, 2048]))

In [None]:
import copy
import logging
import os
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List, Optional, Sequence

import torch
import transformers

from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    LlavaForConditionalGeneration,
    LlavaProcessor,
    Trainer,
    TrainingArguments,
)

# 导入读取数据集和处理数据集为向量的工具类
from show_llava.data import LlavaDataset, TrainLLavaModelCollator
from show_llava.util import print_trainable_parameters

logger = logging.getLogger(__name__)

In [None]:
# Arguments that say where the training dataset lives.
@dataclass
class DataArguments:
    """Arguments that locate the training data.

    Attributes:
        data_path: Path to the training data; ``None`` until one is supplied.
    """

    # Annotated as Optional because the default value is None.
    data_path: Optional[str] = field(
        default=None, metadata={"help": "Path to the training data."}
    )
    # Not currently used:
    # source_length: int = field(default=128)
    # target_length: int = field(default=512)

# Arguments that pick the checkpoint to train and the training mode.
@dataclass
class ModelArguments:
    """Arguments selecting the checkpoint and the fine-tuning strategy.

    Attributes:
        model_name_or_path: Location of the checkpoint to load.
        train_type: Training mode; the help text lists ``use_lora``,
            ``none``, ``freeze_vision`` and ``freeze_vision_and_language``.
    """

    model_name_or_path: Optional[str] = field(default="./show_model/model001")
    train_type: Optional[str] = field(
        default="none",
        metadata=dict(
            help="""
            1. use_lora:使用lora训练,
            2. none:全量参数训练;
            3. freeze_vision:只冻结vision_tower进行训练
            4. freeze_vision_and_language:冻结vision_tower和language_model进行训练
            """
        ),
    )

In [None]:
# Load the model and processor, then apply the requested training mode.
def load_model_processor(modelargs: ModelArguments):
    """Load the LLaVA model and processor and set which parameters train.

    Args:
        modelargs: Supplies ``model_name_or_path`` (checkpoint to load) and
            ``train_type`` (``use_lora``, ``none``, ``freeze_vision`` or
            ``freeze_vision_and_language``).

    Returns:
        A ``(model, processor)`` tuple. With ``use_lora`` the model is the
        object returned by ``peft.get_peft_model``.

    Raises:
        ValueError: If ``modelargs.train_type`` is not a supported mode.
    """
    supported_train_types = (
        "use_lora",
        "none",
        "freeze_vision",
        "freeze_vision_and_language",
    )
    # Validate before the expensive checkpoint load: previously an unknown
    # value (e.g. a typo) matched no branch and silently trained all weights.
    if modelargs.train_type not in supported_train_types:
        raise ValueError(
            f"Unsupported train_type {modelargs.train_type!r}; "
            f"expected one of {supported_train_types}"
        )

    # Load the model weights in bfloat16.
    model = LlavaForConditionalGeneration.from_pretrained(
        modelargs.model_name_or_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    # Load the matching processor from the same checkpoint directory.
    processor = LlavaProcessor.from_pretrained(modelargs.model_name_or_path)

    if modelargs.train_type == "use_lora":
        # LoRA training: build the adapter configuration and wrap the model.
        logging.warning("Loading model to Lora")

        # Imported lazily so peft is only needed for this mode.
        from peft import LoraConfig, get_peft_model

        # TODO: LoRA may be unnecessary here because there are few
        # parameters; training for a few more rounds could be enough.
        # Adding the LoRA configuration may introduce an extra variable and
        # make the experiment less rigorous.

        LORA_R = 32
        # LORA_ALPHA = 16
        LORA_DROPOUT = 0.05
        TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]

        # lora_alpha is intentionally left unset (see the commented constant
        # above), so the library default applies.
        config = LoraConfig(
            r=LORA_R,
            # lora_alpha=LORA_ALPHA,
            target_modules=TARGET_MODULES,
            lora_dropout=LORA_DROPOUT,
            bias="none",
            task_type="CAUSAL_LM",
            # Explicitly train the MLP that projects from the vision side to
            # the text side.
            modules_to_save=["multi_modal_projector"],
        )
        model = get_peft_model(model, config)
        # model.print_trainable_parameters()

    elif modelargs.train_type == "none":
        # Full-parameter training: nothing is frozen.
        logging.warning("使用全量参数进行训练")

    elif modelargs.train_type == "freeze_vision":
        # Freeze only the vision tower; everything else stays trainable.
        logging.warning("冻结vision_tower网络层，剩下的网络权重进行训练")

        for param in model.vision_tower.parameters():
            param.requires_grad = False

    elif modelargs.train_type == "freeze_vision_and_language":
        logging.warning("llava stage1 冻结vision_tower和language_model网络层, 剩下的网络权重进行训练")

        # Freeze the vision tower.
        for param in model.vision_tower.parameters():
            param.requires_grad = False

        # Freeze the language model.
        for param in model.language_model.parameters():
            param.requires_grad = False

        # Explicitly keep the multi-modal projector in the gradient update.
        for param in model.multi_modal_projector.parameters():
            param.requires_grad = True

    # Report the trainable parameters (helper from show_llava.util).
    print_trainable_parameters(model)
    return model, processor


def load_dataset_collator(processor, dataargs: DataArguments):
    """Create the training dataset and the collator that batches it.

    Args:
        processor: Processor object handed to ``TrainLLavaModelCollator``.
        dataargs: Supplies ``data_path``, the dataset location.

    Returns:
        A ``(dataset, collator)`` tuple.
    """
    # Example path: "data/liuhaotian/LLaVA-CC3M-Pretrain-595K"
    dataset = LlavaDataset(dataargs.data_path)
    # NOTE(review): -100 is presumably the label ignore index - confirm
    # against TrainLLavaModelCollator.
    collator = TrainLLavaModelCollator(processor, -100)
    return dataset, collator

In [None]:
def train():
    """Load the model, processor, dataset and collator for a training run.

    The Trainer wiring at the bottom is still commented out, so this function
    currently stops after loading everything and splitting the dataset.
    """
    model_cfg = ModelArguments(
        model_name_or_path="./qwen2.5_3B_Instruct_clipvL14_model/model001",
        train_type="freeze_vision_and_language",
    )
    data_cfg = DataArguments(
        data_path="/home/lsy/shared_data/liuhaotian/LLaVA-CC3M-Pretrain-595K"
    )

    model, processor = load_model_processor(model_cfg)
    train_dataset, data_collator = load_dataset_collator(processor, data_cfg)

    # NOTE(review): the split result is never used, and train_test_split is
    # not defined in the visible code - confirm LlavaDataset provides it.
    eval_dataset = train_dataset.train_test_split(test_size=0.1)

    # NOTE(review): training_args is not defined anywhere in the visible
    # code; it must be created before this section is re-enabled.
    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_dataset,
    #     eval_dataset=None,
    #     data_collator=data_collator,
    # )

    # trainer.train()
    # trainer.save_state()
    # trainer.save_model(output_dir=training_args.output_dir)

In [None]:
# Configure logging at INFO level with timestamped records, then run train().
logging.basicConfig(
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
train()