<a href="https://colab.research.google.com/github/testuayun/NWU-nCov-auto-report/blob/master/PaliGemma_vqav2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PaliGemma Fine-tuning

In this notebook, we will fine-tune [pretrained PaliGemma](https://huggingface.co/google/paligemma2-3b-pt-448) on a small split of [VQAv2](https://huggingface.co/datasets/HuggingFaceM4/VQAv2) dataset. Let's get started by installing necessary libraries.

We will authenticate to access the model using `notebook_login()`.

In [4]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

Let's load the dataset.

In [None]:
from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, device_map="auto")#, quantization_config=bnb_config)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 11,876,352 || all params: 3,044,118,768 || trainable%: 0.3901


We need to take tokens to same dtype as model so need to store it as a variable.

In [None]:
import torch

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
def collate_fn(examples):
  texts = ["<image>answer en " + example["question"] for example in examples]
  labels= [example['multiple_choice_answer'] for example in examples]
  images = [example["image"].convert("RGB") for example in examples]
  tokens = processor(text=texts, images=images, suffix=labels,
                    return_tensors="pt", padding="longest")

  tokens = tokens.to(DTYPE).to(device)
  return tokens


We will now initialize the `TrainingArguments`.

In [None]:
from transformers import TrainingArguments
args=TrainingArguments(
            num_train_epochs=2,
            remove_unused_columns=False,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            learning_rate=2e-5,
            weight_decay=1e-6,
            adam_beta2=0.999,
            logging_steps=100,
            optim="adamw_hf", # you can use paged optimizers like paged_adamw_8bit for QLoRA
            save_strategy="steps",
            save_steps=1000,
            save_total_limit=1,
            output_dir="paligemma_vqav2",
            bf16=True,
            report_to=["tensorboard"],
            dataloader_pin_memory=False
        )


We can now start training.

In [None]:
from transformers import Trainer

trainer = Trainer(
        model=model,
        train_dataset=train_ds ,
        data_collator=collate_fn,
        args=args
        )


LoRA with bsz of 2 works on A100 Colab. You can apply gradient accumulation (which is enabled in this notebook) to simulate larger batch sizes.
Currently there's an issue with QLoRA, we are investigating and will solve soon.

In [None]:
trainer.train()

In [None]:
trainer.push_to_hub()

In [None]:
!pip install -q transformers datasets evaluate accelerate Pillow pandas
!sudo apt-get install git-lfs  # 大文件下载支持

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/484.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m484.9/484.9 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
from PIL import Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from datasets import load_dataset
from tqdm.auto import tqdm
import evaluate
import numpy as np

ImportError: cannot import name 'PaliGemmaForConditionalGeneration' from 'transformers' (/usr/local/lib/python3.11/dist-packages/transformers/__init__.py)

In [1]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
model_id = "google/paligemma-3b-ft-vqav2-448"

# 优化加载配置（节省显存）
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # 使用 BF16 节省显存
    device_map="auto",
    revision="bfloat16"  # 指定量化版本
)
processor = PaliGemmaProcessor.from_pretrained(model_id)

config.json:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/862M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

In [None]:
# 自动从 HuggingFace 加载已处理的数据集（无需手动下载）
dataset = load_dataset('merve/vqav2-small', split="validation")


# 创建易于处理的子集（可选，全量验证需删除）
# dataset = dataset.select(range(2000))  # 测试用 500 样本

README.md:   0%|          | 0.00/403 [00:00<?, ?B/s]

validation-00000-of-00007.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

validation-00001-of-00007.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00002-of-00007.parquet:   0%|          | 0.00/483M [00:00<?, ?B/s]

validation-00003-of-00007.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00004-of-00007.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

validation-00005-of-00007.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

validation-00006-of-00007.parquet:   0%|          | 0.00/479M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/21435 [00:00<?, ? examples/s]

In [None]:
dataset.column_names
dataset[0]

{'multiple_choice_answer': 'carnival ride',
 'question': 'Where are the kids riding?',
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x424>}

In [None]:
def clean_answer(answer: str) -> str:
    return answer.strip().lower().replace(".", "").split()[0]

In [None]:
def batch_predict(dataset, batch_size=4):
    all_answers = []

    for i in tqdm(range(0, len(dataset), batch_size)):
        batch = dataset.select(range(i, min(i+batch_size, len(dataset))))

        # 构建 prompts
        prompts = [f"answer en {ex['question']}" for ex in batch]

        # 预处理
        inputs = processor(
            text=prompts,
            images=[ex["image"].convert("RGB") for ex in batch],
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(model.device)

        # 生成答案
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            do_sample=False,
            num_beams=3
        )

        # 解码
        batch_answers = processor.batch_decode(outputs, skip_special_tokens=True)
        all_answers.extend([clean_answer(a) for a in batch_answers])

    return all_answers

In [None]:
predicted_answers = batch_predict(dataset, batch_size=8)

# 准备评估输入（注意简化版数据集的结构）
references = [{"answers": [ex["answer"]]} for ex in dataset]  # 单答案格式
predictions = [{"answer": ans} for ans in predicted_answers]

# 计算指标
vqa_metric = evaluate.load("vqa")
metrics = vqa_metric.compute(
    predictions=predictions,
    references=references
)

print(f"VQA 准确率: {metrics['overall']['accuracy']*100:.2f}%")

  0%|          | 0/2680 [00:00<?, ?it/s]

You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to 

KeyError: 'answer'

In [None]:
vqa_metric = evaluate.load("vqa")
metrics = vqa_metric.compute(
    predictions=formatted_predictions,
    references=references
)

print(f"VQA 准确率: {metrics['overall']['accuracy']*100:.2f}%")
print(f"是/否类准确率: {metrics['yes/no']['accuracy']*100:.2f}%")
print(f"数字类准确率: {metrics['number']['accuracy']*100:.2f}%")
print(f"其他类准确率: {metrics['other']['accuracy']*100:.2f}%")

In [None]:
import torch
from PIL import Image
import requests
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# 指定使用的模型ID，这里以 google/paligemma-3b-mix-224 为例
model_id = "google/paligemma-3b-mix-224"

# 加载模型与处理器
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# 设置设备（可选，如果有GPU则使用GPU）
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

PaliGemmaForConditionalGeneration(
  (vision_tower): SiglipVisionModel(
    (vision_model): SiglipVisionTransformer(
      (embeddings): SiglipVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
        (position_embedding): Embedding(256, 1152)
      )
      (encoder): SiglipEncoder(
        (layers): ModuleList(
          (0-26): 27 x SiglipEncoderLayer(
            (self_attn): SiglipSdpaAttention(
              (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (v_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (q_proj): Linear(in_features=1152, out_features=1152, bias=True)
              (out_proj): Linear(in_features=1152, out_features=1152, bias=True)
            )
            (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
            (mlp): SiglipMLP(
              (activation_fn): PytorchGELUTanh()
              (fc1): Linear(in_features

In [None]:
image_file = "https://www.ilankelman.org/stopsigns/australia.jpg"
raw_image = Image.open(requests.get(image_file, stream=True).raw)


In [None]:
prompt = "What does this traffic sign instruct drivers to do? Answer:"
prompt_with_token = "<image> " + prompt
# 使用处理器对文本和图像进行预处理，得到模型输入
inputs = processor(raw_image, prompt_with_token, return_tensors="pt").to(device)
with torch.no_grad():
  output = model.generate(**inputs, max_new_tokens=20)

result = processor.decode(output[0], skip_special_tokens=True)
# 使用分隔符"Answer:"分割并提取后半部分，然后去除多余空格
answer = result.split("Answer:")[-1].strip()
answer

'stop'

In [2]:
!pip install -q transformers datasets evaluate accelerate Pillow pandas regex
!sudo apt-get install git-lfs
from huggingface_hub import notebook_login
notebook_login()
from PIL import Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from datasets import load_dataset
from tqdm.auto import tqdm
import evaluate
import numpy as np
import re  # 新增正则模块

# 配置参数
model_id = "google/paligemma-3b-ft-vqav2-448"
batch_size = 4  # 根据显存调整（24G显存可设为8）
device = "cuda" if torch.cuda.is_available() else "cpu"

# 初始化模型（增加异常处理）
try:
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        revision="bfloat16"
    )
    processor = PaliGemmaProcessor.from_pretrained(model_id)
except Exception as e:
    print(f"初始化失败: {str(e)}")
    raise

# 验证图像尺寸处理
print("图像处理器尺寸:", processor.image_processor.size)  # 应输出 {'height': 448, 'width': 448}

# 加载数据集（注意检查数据集结构）
dataset = load_dataset('merve/vqav2-small', split="validation")

# 检查数据集结构
print("\n数据集样本结构:", dataset[0].keys())  # 应包含 'answers' 字段

# 改进的答案清洗函数
def clean_answer(answer: str) -> str:
    """处理大小写、标点和多余空格"""
    answer = re.sub(r"[^\w\s]", "", answer.strip().lower())  # 移除所有标点
    return answer.split()[0]  # 根据任务需求取第一个单词

# 带异常处理的批量预测
def batch_predict(dataset, batch_size=4):
    all_answers = []

    for i in tqdm(range(0, len(dataset), batch_size)):
        try:
            batch = dataset.select(range(i, min(i+batch_size, len(dataset))))

            # 构建 prompts
            prompts = [f"answer en {ex['question']}" for ex in batch]

            # 预处理（显式转换图像尺寸）
            inputs = processor(
                text=prompts,
                images=[ex["image"].convert("RGB") for ex in batch],
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(device)

            # 生成答案（优化参数）
            outputs = model.generate(
                **inputs,
                max_new_tokens=20,      # 增加生成长度
                num_beams=5,           # 与论文设置一致
                early_stopping=True,   # 提前终止
                do_sample=False
            )

            # 解码并清洗
            batch_answers = processor.batch_decode(outputs, skip_special_tokens=True)
            all_answers.extend([clean_answer(a) for a in batch_answers])

        except Exception as e:
            print(f"处理批次 {i//batch_size} 时出错: {str(e)}")
            all_answers.extend([""] * batch_size)  # 填充空答案

    return all_answers

# 执行预测
predicted_answers = batch_predict(dataset, batch_size=batch_size)

# 构建评估输入（处理多答案）
references = [{
    "answers": [a["answer"] for a in ex["answers"]]  # 使用所有参考答案
} for ex in dataset]

predictions = [{"answer": ans} for ans in predicted_answers]

# 使用VQAv2专用评估
vqa_metric = evaluate.load("vqa", config_name="vqa_v2")  # 明确指定v2版本
metrics = vqa_metric.compute(
    predictions=predictions,
    references=references
)

# 打印完整指标
print("\n评估结果:")
print(f"- 总体准确率: {metrics['overall']['accuracy']*100:.2f}%")
print(f"- 是/否类准确率: {metrics['yes/no']['accuracy']*100:.2f}%")
print(f"- 数字类准确率: {metrics['number']['accuracy']*100:.2f}%")
print(f"- 其他类准确率: {metrics['other']['accuracy']*100:.2f}%")

# 测试单个样本
test_sample = dataset[0]
test_prompt = f"answer en {test_sample['question']}"
inputs = processor(
    text=test_prompt,
    images=test_sample["image"].convert("RGB"),
    return_tensors="pt"
).to(device)

output = model.generate(**inputs, max_new_tokens=20)
print("\n单样本测试:")
print("问题:", test_sample['question'])
print("预测:", clean_answer(processor.decode(output[0], skip_special_tokens=True)))
print("参考答案:", [a["answer"] for a in test_sample["answers"]])

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/485.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/116.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m12.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/862M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

图像处理器尺寸: {'height': 448, 'width': 448}


README.md:   0%|          | 0.00/403 [00:00<?, ?B/s]

validation-00000-of-00007.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

validation-00001-of-00007.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00002-of-00007.parquet:   0%|          | 0.00/483M [00:00<?, ?B/s]

validation-00003-of-00007.parquet:   0%|          | 0.00/486M [00:00<?, ?B/s]

validation-00004-of-00007.parquet:   0%|          | 0.00/475M [00:00<?, ?B/s]

validation-00005-of-00007.parquet:   0%|          | 0.00/484M [00:00<?, ?B/s]

validation-00006-of-00007.parquet:   0%|          | 0.00/479M [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/21435 [00:00<?, ? examples/s]


数据集样本结构: dict_keys(['multiple_choice_answer', 'question', 'image'])


  0%|          | 0/5359 [00:00<?, ?it/s]

You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to 

KeyboardInterrupt: 

In [1]:
# 新增依赖安装（需重启运行时）
!pip install -q transformers datasets evaluate accelerate Pillow pandas regex bitsandbytes flash-attn --upgrade
!sudo apt-get install git-lfs

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.9/89.9 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m43.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m62.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m485.4/485.4 kB[0m [31m34.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.0/84.0 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m342.1/342.1 kB[0m [31m24.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.1/13.1 MB[0m [31m74.0 MB/s[0m eta 

In [2]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from PIL import Image
import torch
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
from datasets import load_dataset
from tqdm.auto import tqdm
import evaluate
import numpy as np
import re

# 配置参数
model_id = "google/paligemma-3b-ft-vqav2-448"
batch_size = 8  # 显存优化后batch_size可翻倍（原为4）
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# 初始化模型（8位量化 + Flash Attention 2）
try:
    model = PaliGemmaForConditionalGeneration.from_pretrained(
        model_id,
        load_in_8bit=True,              # 8位量化，显存减少40%
        use_flash_attention_2=True,     # Flash Attention加速
        device_map="auto",
        revision="bfloat16"
    )
    processor = PaliGemmaProcessor.from_pretrained(model_id)
except Exception as e:
    print(f"初始化失败: {str(e)}")
    raise

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/62.6k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/862M [00:00<?, ?B/s]

The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/700 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.26M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/24.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/607 [00:00<?, ?B/s]

In [6]:
# 预处理阶段提前转换图像（减少循环内开销）
dataset = load_dataset('merve/vqav2-small', split="validation")


In [7]:
# 预处理并缓存图像（关键优化：节省30%时间）
def preprocess_image(ex):
    ex["image"] = ex["image"].convert("RGB")  # 提前转换
    return ex
dataset = dataset.map(
    lambda batch: {"image": [img.convert("RGB") for img in batch["image"]]},
    batched=True,  # 关键参数
    batch_size=100  # 根据内存调整批次大小
)

Map:   0%|          | 0/21435 [00:00<?, ? examples/s]

KeyboardInterrupt: 

You can find steps to infer [here](https://colab.research.google.com/drive/100IQcvMvGm9y--oelbLfI__eHCoz5Ser?usp=sharing).