In [1]:
import os

# Expose only the GPU with index 2 to this process. This is done in the very
# first cell, before any later cell imports the modelling libraries.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})

# Dataset

In [2]:
from training_utils import resize_image


def transform_item(item):
    """Convert one raw dataset record into a text/image training pair.

    Args:
        item: A record exposing ``item["annotations"]["caption"]`` (an
            indexable collection of captions) and ``item["image"]``.

    Returns:
        A dict with two keys, in this order: ``"text"`` holds the first
        caption of the record and ``"pos_image"`` holds the result of
        ``resize_image(item["image"], 224)``.
    """
    # Only the first caption of each record is used.
    first_caption = item["annotations"]["caption"][0]
    resized_image = resize_image(item["image"], 224)
    return {"text": first_caption, "pos_image": resized_image}

In [None]:
import datasets
from datasets import DatasetDict

N = 1000  # number of leading examples kept from the training split

# Load the "train" split of STAIR Captions (config "v1.2.0").
ds = datasets.load_dataset("shunk031/STAIR-Captions", "v1.2.0", split="train")

# Keep the first N rows, then convert every row with transform_item. All of the
# original columns are removed so that only the keys returned by
# transform_item remain in the mapped dataset.
original_columns = ds.column_names
ds = ds.select(range(N))
ds = ds.map(transform_item, num_proc=12, remove_columns=original_columns)

# Wrap the dataset in a DatasetDict under the name "stair captions"; this is
# the object later handed to the trainer as its training data.
train_dd = DatasetDict({"stair captions": ds})

# Model

In [4]:
from modeling_clip_qwen2vl import CLIPQwen2VLWrapper

# The model path given here should be the one written out by initialize_model.
# NOTE(review): enable_text_grad=True presumably keeps the text side trainable
# (gradients enabled) — confirm against modeling_clip_qwen2vl.
clip = CLIPQwen2VLWrapper("./clip", enable_text_grad=True)

In [5]:
from peft import get_peft_model, LoraConfig

from peft import PeftModel

# LoRA adapter configuration used for training (inference_mode=False).
# - r / lora_alpha with use_rslora=True: rank-stabilized LoRA, i.e. the adapter
#   output is scaled by lora_alpha / sqrt(r) rather than lora_alpha / r.
# - lora_dropout: dropout applied inside the adapter layers.
# - target_modules: names of the sub-modules that receive an adapter.
# - modules_to_save: "vision_projection" is kept fully trainable and is stored
#   together with the adapter weights.
peft_config = LoraConfig(
    inference_mode=False,
    r=128,
    lora_alpha=16,
    lora_dropout=0.05,
    use_rslora=True,
    target_modules=["attn.qkv", "attn.proj", "fc1", "fc2", "mlp.0", "mlp.2"],
    modules_to_save=["vision_projection"],
)

# get_peft_model wraps whatever model it is given. Because the result is
# assigned back to clip.model, running this cell a second time would wrap an
# already wrapped model and stack another adapter on top of the first one.
# Guard the call so the cell can be re-executed safely.
if isinstance(clip.model, PeftModel):
    print(
        "clip.model is already wrapped by PEFT; re-run the model loading cell "
        "first if a different LoraConfig should be applied."
    )
else:
    clip.model = get_peft_model(clip.model, peft_config)

In [6]:
from sentence_transformers import SentenceTransformer

# Build a SentenceTransformer whose only module is the CLIP wrapper, placed on
# the CUDA device.
model = SentenceTransformer(device="cuda", modules=[clip])

# Train

In [7]:
from PIL import Image
from PIL import PngImagePlugin

# Prevent image-decoding errors during training:
# - raise the limit for PNG text chunks to 100 MiB;
# - turn off Pillow's maximum-pixel-count check (setting it to None disables
#   the decompression-bomb guard, so only do this for trusted datasets).
PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024
Image.MAX_IMAGE_PIXELS = None

In [None]:
from training_utils import ImageTextCachedMultipleNegativesRankingLoss
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformerTrainer,
)

# Training objective for the text/image pairs.
# NOTE(review): mini_batch_size=1 presumably bounds how many samples are
# embedded per forward pass inside the cached loss — confirm in training_utils.
loss = ImageTextCachedMultipleNegativesRankingLoss(
    mini_batch_size=1,
    model=model,
)

# Hyper-parameters for the fine-tuning run.
#
# Warm-up is expressed as a fraction of the run instead of a fixed step count.
# With the 1000-example subset and a per-device batch size of 64, one epoch is
# only ceil(1000 / 64) = 16 optimizer steps. A fixed warmup_steps=100 would
# keep the entire run inside the linear warm-up ramp, so the learning rate
# would never reach the configured 1e-4. warmup_ratio scales with the dataset
# size and stays valid when N is increased.
args = SentenceTransformerTrainingArguments(
    report_to="none",  # do not log to an external experiment tracker
    output_dir="./outputs",
    learning_rate=1e-4,
    num_train_epochs=1,
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    logging_steps=10,
    save_strategy="steps",
    # Intermediate checkpoints are written every 100 steps; a run shorter than
    # that produces no intermediate checkpoint.
    save_steps=100,
    eval_strategy="no",  # no evaluation dataset is supplied
    bf16=True,
    per_device_train_batch_size=64,
    run_name="v0.1",
)


# Assemble the trainer from the model, training arguments, dataset dictionary
# and loss defined earlier in the notebook, then run the fine-tuning loop.
trainer = SentenceTransformerTrainer(
    args=args,
    loss=loss,
    model=model,
    train_dataset=train_dd,
)
trainer.train()

In [9]:
# Write the fine-tuned SentenceTransformer to ./outputs, the same directory
# that is configured as the trainer's output_dir.
model.save("./outputs")