#### Training to align the adapter using VQAv2 dataset

In [1]:
from datasets import load_dataset, DatasetDict
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
import evaluate
from transformers import Blip2Config, Blip2VisionConfig, Blip2ForConditionalGeneration, Blip2QFormerConfig
from transformers import AutoModel, AutoTokenizer, SwinModel, SwinConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import Blip2Processor, AutoImageProcessor, BlipImageProcessor
from transformers import AddedToken
import os
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
import torch
import torch.nn as nn
from collections import Counter
import wandb

In [2]:
# Configure HuggingFace cache directories (change the base path if needed)
HF_CACHE_BASE = os.environ.get("HF_CACHE_BASE", r"D:/cache/huggingface")
HF_CACHE_BASE = os.path.abspath(HF_CACHE_BASE)
HF_DATASETS_CACHE = os.path.join(HF_CACHE_BASE, "datasets")
HF_MODELS_CACHE = os.path.join(HF_CACHE_BASE, "models")

for path in (HF_CACHE_BASE, HF_DATASETS_CACHE, HF_MODELS_CACHE):
    os.makedirs(path, exist_ok=True)

os.environ["HF_HOME"] = HF_CACHE_BASE
os.environ["HF_DATASETS_CACHE"] = HF_DATASETS_CACHE
os.environ["TRANSFORMERS_CACHE"] = HF_MODELS_CACHE
os.environ["HF_HUB_CACHE"] = HF_MODELS_CACHE


#### Dataset preparation

In [3]:
class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def _select_answer(self, answers, fallback=""):
        candidates = []
        if isinstance(answers, dict):
            if isinstance(answers.get("text"), list):
                candidates = answers["text"]
            elif isinstance(answers.get("answer"), list):
                candidates = answers["answer"]
            elif isinstance(answers.get("answers"), list):
                candidates = answers["answers"]
        elif isinstance(answers, list):
            if answers and isinstance(answers[0], dict):
                candidates = [a.get("text") or a.get("answer") for a in answers if a.get("text") or a.get("answer")]
            else:
                candidates = answers
        elif isinstance(answers, str):
            candidates = [answers]

        candidates = [c for c in candidates if isinstance(c, str) and c]
        if candidates:
            # make sure the answer contains at most one string
            return Counter(candidates).most_common(1)[0][0]
        return fallback

    def __getitem__(self, idx):
        item = self.dataset[idx]
        question = item["question"]
        answer = self._select_answer(item.get("answers"), item.get("multiple_choice_answer", ""))
        image = item["image"]

        #encoding和decoding的max_length必须一样,这里都是64.因为language_model的logits的seq_len就是max_length,同时计算loss会logits = logits[:, -labels.size(1) :, :]，这里max_length不一样截取后loss计算就会错
        encoding = self.processor(
            images=image,
            text=question,
            padding="max_length",
            max_length=64,
            truncation=True,
            return_tensors="pt",
        )
        labels = self.processor.tokenizer(
            answer,
            max_length=64,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        #only keep the first eos in labels, other eos would be replaced by ignored_index of loss function, so that the loss won't count them
        label_ids = labels["input_ids"].squeeze(0)
        pad_token_id = self.processor.tokenizer.pad_token_id
        eos_token_id = self.processor.tokenizer.eos_token_id
        if eos_token_id is None:
            eos_token_id = pad_token_id

        if pad_token_id is not None and pad_token_id != eos_token_id:
            label_ids[label_ids == pad_token_id] = -100

        if eos_token_id is not None:
            eos_positions = (label_ids == eos_token_id).nonzero(as_tuple=False)
            if eos_positions.numel() > 0:
                label_ids[eos_positions.flatten()] = -100
                first_eos_idx = eos_positions[0].item()
                label_ids[first_eos_idx] = eos_token_id

        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["labels"] = label_ids
        return encoding

In [4]:
from datasets.download.download_config import DownloadConfig

##################################      Creating Dataset and Dataloader     ##################################
raw_dataset = load_dataset(
    "HuggingFaceM4/VQAv2",
    cache_dir=HF_DATASETS_CACHE,
    trust_remote_code=True,
     download_config= DownloadConfig(storage_options={"timeout": 3600})
)

dataset = DatasetDict({
    "train": raw_dataset["train"],
    "test": raw_dataset["test"]
})


Repo card metadata block was not found. Setting CardData to empty.


In [5]:
print(dataset["train"][0])

{'question_type': 'what is this', 'multiple_choice_answer': 'net', 'answers': [{'answer': 'net', 'answer_confidence': 'maybe', 'answer_id': 1}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 2}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 3}, {'answer': 'netting', 'answer_confidence': 'yes', 'answer_id': 4}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 5}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 6}, {'answer': 'mesh', 'answer_confidence': 'maybe', 'answer_id': 7}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 8}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 9}, {'answer': 'net', 'answer_confidence': 'yes', 'answer_id': 10}], 'image_id': 458752, 'answer_type': 'other', 'question_id': 458752000, 'question': 'What is this photo taken looking through?', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=640x480 at 0x2D58FB06620>}


#### Model creation

In [6]:
##################################  Model Creation  ##################################
llm_id = "meta-llama/Llama-3.2-3B"
base_model_id = "Salesforce/blip2-opt-2.7b"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#dtype = torch.float16 if device.type == "cuda" else torch.float32
"""
Q-Former: bfloat16
Vision encoder: bfloat16
LLM (LLaMA / Flan-T5): bfloat16
Loss: fp32
发现 q-former若使用float16,其outputs直接nan，不知为啥
"""
dtype = torch.float32
vision_dtype  = torch.bfloat16  if device.type == "cuda" else torch.float32
qformer_dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
llm_dtype     = torch.bfloat16 if device.type == "cuda" else torch.float32

processor = Blip2Processor.from_pretrained(base_model_id, cache_dir=HF_MODELS_CACHE)
model = Blip2ForConditionalGeneration.from_pretrained(base_model_id, torch_dtype=dtype, cache_dir=HF_MODELS_CACHE)

llama_tokenizer = AutoTokenizer.from_pretrained(llm_id, use_fast=False, cache_dir=HF_MODELS_CACHE)
if llama_tokenizer.pad_token is None:
    llama_tokenizer.pad_token = llama_tokenizer.eos_token

image_token = AddedToken("<image>", normalized=False, special=True)
llama_tokenizer.add_tokens([image_token], special_tokens=True)
processor.tokenizer = llama_tokenizer
processor.tokenizer.pad_token = processor.tokenizer.eos_token
processor.num_query_tokens = model.config.num_query_tokens

llama_model = AutoModelForCausalLM.from_pretrained(llm_id, torch_dtype=llm_dtype)
llama_model = llama_model.to(device, dtype=llm_dtype)

hidden_in = model.language_projection.in_features
hidden_out = llama_model.config.hidden_size
if model.language_projection.out_features != hidden_out:
    projection = nn.Linear(hidden_in, hidden_out, bias=False)
    nn.init.xavier_uniform_(projection.weight)
    model.language_projection = projection

# 这里手动设置各子模块精度
model.vision_model = model.vision_model.to(device=device, dtype=vision_dtype)   # ViT: fp16
model.qformer = model.qformer.to(device=device, dtype=qformer_dtype)            # Q-Former: bf16
model.language_projection = model.language_projection.to(device=device, dtype=qformer_dtype)

model.language_model = llama_model
model.config.text_config = llama_model.config
model.config.pad_token_id = llama_tokenizer.pad_token_id
model.config.bos_token_id = llama_tokenizer.bos_token_id
model.config.eos_token_id = llama_tokenizer.eos_token_id
model.language_model.resize_token_embeddings(len(processor.tokenizer))
model.config.vocab_size = len(processor.tokenizer)
model.config.image_token_index = processor.tokenizer.convert_tokens_to_ids("<image>")
model = model.to(device)

for param in model.vision_model.parameters():
    param.requires_grad = False

for param in model.language_model.parameters():
    param.requires_grad = False

for param in model.qformer.parameters():
    param.requires_grad = True

for param in model.language_projection.parameters():
    param.requires_grad = True

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [7]:
print("Model mixed-precision layout:")

def describe_module(module, name):
    param_dtypes = {p.dtype for p in module.parameters()}
    param_devices = {p.device for p in module.parameters()}
    buffers = list(module.buffers())
    buffer_dtypes = {b.dtype for b in buffers} if buffers else {"<none>"}
    buffer_devices = {b.device for b in buffers} if buffers else {"<none>"}
    print(f"- {name}: params {param_dtypes} on {param_devices}; buffers {buffer_dtypes} on {buffer_devices}")

describe_module(model.vision_model, "Vision Encoder")
describe_module(model.qformer, "Q-Former")
describe_module(model.language_projection, "Language Projection")
describe_module(model.language_model, "LLM")


Model mixed-precision layout:
- Vision Encoder: params {torch.bfloat16} on {device(type='cuda', index=0)}; buffers {'<none>'} on {'<none>'}
- Q-Former: params {torch.bfloat16} on {device(type='cuda', index=0)}; buffers {'<none>'} on {'<none>'}
- Language Projection: params {torch.bfloat16} on {device(type='cuda', index=0)}; buffers {'<none>'} on {'<none>'}
- LLM: params {torch.bfloat16} on {device(type='cuda', index=0)}; buffers {torch.bfloat16} on {device(type='cuda', index=0)}


In [14]:
# 检查哪些参数 require_grad=True
trainable = [n for n,p in model.named_parameters() if p.requires_grad]
print("trainable param count:", len(trainable))
for n in trainable:
    print(n)

trainable param count: 256
query_tokens
qformer.layernorm.weight
qformer.layernorm.bias
qformer.encoder.layer.0.attention.attention.query.weight
qformer.encoder.layer.0.attention.attention.query.bias
qformer.encoder.layer.0.attention.attention.key.weight
qformer.encoder.layer.0.attention.attention.key.bias
qformer.encoder.layer.0.attention.attention.value.weight
qformer.encoder.layer.0.attention.attention.value.bias
qformer.encoder.layer.0.attention.output.dense.weight
qformer.encoder.layer.0.attention.output.dense.bias
qformer.encoder.layer.0.attention.output.LayerNorm.weight
qformer.encoder.layer.0.attention.output.LayerNorm.bias
qformer.encoder.layer.0.crossattention.attention.query.weight
qformer.encoder.layer.0.crossattention.attention.query.bias
qformer.encoder.layer.0.crossattention.attention.key.weight
qformer.encoder.layer.0.crossattention.attention.key.bias
qformer.encoder.layer.0.crossattention.attention.value.weight
qformer.encoder.layer.0.crossattention.attention.value.bia

In [15]:
rouge = evaluate.load("rouge")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
     # 预测：转到 CPU
    logits = logits.detach().cpu().numpy()

    # 标签：把 -100 换回 eos_token_id 才能 decode
    labels = labels.clone()
    labels[labels == -100] = processor.tokenizer.eos_token_id
    labels = labels.detach().cpu().numpy()

    decoded_preds = processor.tokenizer.batch_decode(logits, skip_special_tokens=True)
    decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Compute ROUGE scores
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels)

    # We can return ROUGE-1, ROUGE-2, and ROUGE-L as needed
    return {
        "rouge1": result["rouge1"],
        "rouge2": result["rouge2"],
        "rougeL": result["rougeL"],
    }



train_dataset = VQADataset(dataset=dataset["train"],
                          processor=processor)
valid_dataset = VQADataset(dataset=dataset["test"],
                          processor=processor)

batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, pin_memory=True) #can't do validation, answer is None in this split

##################################      Training Arguements     ##################################
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9, last_epoch=-1)

num_epochs = 10
patience = 10
min_eval_acc = float("-inf")
early_stopping_hook = 0
tracking_information = []

wandb_run = wandb.init(
    project="blip2-vqa",
    config={
        "learning_rate": optimizer.param_groups[0]["lr"],
        "batch_size": batch_size,
        "max_length": 64,
        "num_epochs": num_epochs,
        "base_model": base_model_id,
        "llm": llm_id,
    },
)
wandb.watch(model, log="gradients", log_freq=200, log_graph=False)
global_step = 0

##################################      Model Training and Evaluation      ##################################
for epoch in range(num_epochs):
    epoch_loss = 0
    model.train()
    train_acc = 0
    train_tqdm = tqdm(range(len(train_dataloader)), desc=f'Epoch {epoch+1} - Training loss: 0.000 - Train Acc: 0.000', position=0)
    for idx, batch in zip(train_tqdm, train_dataloader):
        input_ids = batch.pop('input_ids').to(device)
        pixel_values = batch.pop('pixel_values').to(device)
        attention_masked = batch.pop('attention_mask').to(device)
        labels = batch.pop('labels').to(device)

        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_masked,
            labels=labels,
        )

        """
        模型各子模块用手动 dtype（ViT fp16、Q-Former & projection bf16、LLM bf16）后，Forward 得到的 logits 会跟最后一层（LLaMA）同 dtype，也就是 torch.bfloat16。
        Hugging Face 的 Blip2ForConditionalGeneration 内部在计算 loss时，会把 logits 自动提到 torch.float32，再和 labels（int64）做 softmax + log prob。所以你看到的 outputs.loss 实际上是 fp32（因此 .item() 也是正常的 float）。
        反向传播时，梯度的 dtype 跟对应参数一致：
        Q-Former、language projection：bf16 梯度
        Vision encoder（虽然冻结）/LLM：fp16 或 bf16，但它们 requires_grad=False 就不会计算梯度
        Optimizer 更新的就是当前仍在训练的 bf16 参数，因此可以继续搭配 gradient clipping。
        """
        loss = outputs.loss #这里会忽视labels中token_id=-100的
        epoch_loss += loss.item()
        logits = outputs.logits.argmax(dim=-1)

        train_acc += compute_metrics((logits, labels))["rougeL"]  # You can use rouge1, rouge2, or rougeL

        optimizer.zero_grad()
        loss.backward()
        #grad_norm是裁剪前的l2范数，是所有require_grad=True的参数的l2范数
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        optimizer.step()
        current_lr = optimizer.param_groups[0]["lr"]
        wandb.log(
            {
                "train/loss": loss.item(),
                "train/grad_norm": grad_norm.item() if hasattr(grad_norm, "item") else grad_norm,
                "train/lr": current_lr,
                "train/epoch": epoch + 1,
            },
            step=global_step,
        )
        global_step += 1
        train_tqdm.set_description(f'Epoch {epoch+1} - Training loss: {epoch_loss/(idx+1):.4f} - Train Acc: {train_acc/(idx+1):.4f}')
        # Clear cache to avoid OOM
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    #per epoch evaluation
    #model.eval()
    #eval_loss = 0
    #eval_acc = 0
    #val_tqdm = tqdm(range(len(valid_dataloader)), desc=f'Epoch {epoch+1} - Eval loss: 0.000 - Eval Acc: 0.000')
    #for idx, batch in zip(val_tqdm, valid_dataloader):
    #    input_ids = batch.pop('input_ids').to(device)
    #    pixel_values = batch.pop('pixel_values').to(device)
    #    attention_masked = batch.pop('attention_mask').to(device)
    #    labels = batch.pop('labels').to(device)
#
    #    with torch.no_grad():
    #        outputs = model(
    #            input_ids=input_ids,
    #            pixel_values=pixel_values,
    #            attention_mask=attention_masked,
    #            labels=labels,
    #        )
#
    #    loss = outputs.loss
    #    eval_loss += loss.item()
#
    #    logits = outputs.logits.argmax(dim=-1)
    #    eval_acc += compute_metrics((logits, labels))["rougeL"]  # You can use rouge1, rouge2, or rougeL
#
    #    val_tqdm.set_description(f'Epoch {epoch+1} - Eval loss: {eval_loss/(idx+1):.4f} - Eval Acc: {eval_acc/(idx+1):.4f}')
    #    # Clear cache to avoid OOM
    #    torch.cuda.empty_cache()
    #    torch.cuda.synchronize()

    #tracking_information.append((epoch_loss/len(train_dataloader), eval_loss/len(valid_dataloader), optimizer.param_groups[0]["lr"]))
   # print("Epoch: {} - Training loss: {} - Eval Loss: {} - LR: {}".format(epoch+1, epoch_loss/len(train_dataloader), eval_loss/len(valid_dataloader), optimizer.param_groups[0]["lr"]))
    avg_loss = epoch_loss / len(train_dataloader)
    avg_acc = train_acc / len(train_dataloader)
    print("Epoch: {} - Training loss: {}  - LR: {}".format(epoch+1, avg_loss, optimizer.param_groups[0]["lr"]))
    wandb.log(
        {
            "train/epoch_loss": avg_loss,
            "train/epoch_accuracy": avg_acc,
            "epoch": epoch + 1,
        },
        step=global_step,
    )
    scheduler.step()
    eval_acc = avg_acc
    if eval_acc > min_eval_acc:
        model.save_pretrained("./models/blip2", from_pt=True)
        processor.save_pretrained("./models/blip2_processor")
        print("Saved model and processor")
        min_eval_acc = eval_acc
        early_stopping_hook = 0
    else:
        early_stopping_hook += 1
        if early_stopping_hook > patience:
            break

pickle.dump(tracking_information, open("tracking_information.pkl", "wb"))
wandb_run.finish()
print("The finetuning process has done!")

0,1
train/epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/grad_norm,█▁▁▁▁▁▁▁▁▁▁▁▁▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/loss,█▆▁▆▆▃▄▄▃▆▃▄▂▄▄▃▁▂▄▃▃▂▃▃▃▇▃▄▅▃▁▂▄▁▃▃▃▃▂▂
train/lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train/epoch,1.0
train/grad_norm,1.39106
train/loss,4.0
train/lr,0.0005


Epoch 1 - Training loss: 3.8918 - Train Acc: 0.0063:   8%|▊         | 1057/13868 [11:58<2:25:08,  1.47it/s]


KeyboardInterrupt: 