In [7]:
import argparse
import datetime
import gc
import glob
import json
import logging
import math
import os
import random
import re
import shutil
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from safetensors.torch import load_file
import datasets
import imagehash
import mlflow
import numpy as np
import requests
import torch
import transformers
import yaml
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.logging import MultiProcessAdapter
from accelerate.utils import (
    DistributedDataParallelKwargs,
    FullyShardedDataParallelPlugin,
    GradientAccumulationPlugin,
    gather,
    gather_object,
)
from datasets import DatasetDict, concatenate_datasets, load_from_disk
from nltk.tokenize import wordpunct_tokenize
from peft import (
    LoraConfig,
    PeftModel,
    TaskType,
    get_peft_model,
    get_peft_model_state_dict,
)
from peft.tuners import lora
from peft.utils import AuxiliaryTrainingWrapper
from PIL import Image
from scipy.ndimage import zoom
from scorers.scores import compute_scores
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset, SequentialSampler
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoImageProcessor,
    AutoTokenizer,
    Dinov2Config,
    Dinov2Model,
    LlamaConfig,
    PretrainedConfig,
    PreTrainedModel,
    VisionEncoderDecoderModel,
    get_cosine_schedule_with_warmup,
)
from transformers.generation import GenerationConfig
from transformers.generation import utils as tf_generation_utils
from transformers.modeling_outputs import BaseModelOutput, ModelOutput
from transformers.models.dinov2.modeling_dinov2 import Dinov2Embeddings
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding

In [2]:
@dataclass
class Vision2LanguageOutputWithPast(ModelOutput):
    """Output container returned by ``Vision2LanguageModel.forward``.

    All fields default to ``None``; ``forward`` fills them from the decoder outputs,
    plus an optional loss and the projected image features.
    """

    # Cross-entropy loss; only set when ``forward`` is called with ``output_loss=True``.
    loss: Optional[torch.FloatTensor] = None
    # Logits taken from the decoder outputs.
    logits: torch.FloatTensor = None
    # Cache object forwarded unchanged from the decoder outputs.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Hidden states forwarded unchanged from the decoder outputs.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Attention weights forwarded unchanged from the decoder outputs.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Projected image features; ``forward`` only sets this when ``pixel_values`` was passed.
    image_hidden_states: Optional[torch.FloatTensor] = None


class VisionLanguageProjector(nn.Module):
    """Two-layer MLP that maps vision-encoder features to the decoder's hidden width.

    The layout is ``Linear(encoder_hidden_size -> decoder_hidden_size)`` -> ``SiLU`` ->
    ``Linear(decoder_hidden_size -> decoder_hidden_size)``; both linear layers carry a bias.

    Args:
        config: any object exposing ``encoder_hidden_size`` and ``decoder_hidden_size``.
    """

    def __init__(self, config):
        super().__init__()

        enc_dim = config.encoder_hidden_size
        dec_dim = config.decoder_hidden_size

        # Submodule names and creation order are kept as-is: they determine the
        # state_dict keys and the order in which the RNG is consumed at init time.
        self.linear_1 = nn.Linear(enc_dim, dec_dim, bias=True)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(dec_dim, dec_dim, bias=True)

    def forward(self, image_features):
        """Project ``image_features`` (last dim = encoder width) to the decoder width."""
        return self.linear_2(self.act(self.linear_1(image_features)))


class Vision2LanguageModel(VisionEncoderDecoderModel):
    def __init__(self, config=None, encoder=None, decoder=None):
        """Build the encoder/decoder pair and attach the vision-to-language projector.

        ``config``, ``encoder`` and ``decoder`` are forwarded unchanged to the parent
        constructor. Afterwards the encoder and decoder hidden sizes are recorded on
        ``self.config`` (as ``encoder_hidden_size`` / ``decoder_hidden_size``, the two
        attributes ``VisionLanguageProjector`` reads), ``self.v2l_projector`` is created,
        and the ``enc_to_dec_proj`` attribute is removed if it exists.
        """
        super().__init__(config=config, encoder=encoder, decoder=decoder)

        model_config = self.config
        model_config.encoder_hidden_size = self.encoder.config.hidden_size
        model_config.decoder_hidden_size = self.decoder.config.hidden_size

        # VisionLanguageProjector takes over the role of enc_to_dec_proj.
        self.v2l_projector = VisionLanguageProjector(model_config)

        # Remove the projection layer this model no longer uses, when present.
        if hasattr(self, "enc_to_dec_proj"):
            delattr(self, "enc_to_dec_proj")

    def _inject_image_features(self, input_ids, decoder_input_ids, image_features):
        # image_indices_map 是一个嵌套list，每个样本对应一个list，list中的元素是图像在 last_hidden_state 中的索引
        # e.g. [[0], [1], [2, 3], ...]

        # replace img features with the <|image_token|> placeholder token in the input text
        special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
        special_image_mask = special_image_mask.expand_as(decoder_input_ids).to(decoder_input_ids.device)

        # 保证所有 image_features 都能够被复制到 decoder_input_ids 中
        assert special_image_mask.sum() == image_features.numel(), f"special_image_mask.sum()={special_image_mask.sum()}, image_features.numel()={image_features.numel()}, should be equal to guarantee that all image features are copied to decoder_input_ids"

        image_features = image_features.to(decoder_input_ids.device, decoder_input_ids.dtype)
        decoder_input_ids = decoder_input_ids.masked_scatter(special_image_mask, image_features)

        return decoder_input_ids

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        decoder_assistant_masks: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        output_loss: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, Vision2LanguageOutputWithPast]:
        """Run the decoder on text embeddings that have image features injected.

        Additional args:
        `decoder_inputs_embeds`: should represent the text embeddings with image features
            already injected. It is mutually exclusive with `decoder_input_ids` and with
            `pixel_values` / `encoder_outputs`.
        `encoder_outputs`: in the inference stage, `pixel_values` are encoded outside this
            method and passed in as `encoder_outputs`, because `pixel_values` and
            `decoder_input_ids` can have different batch sizes, which causes errors in
            generate(). It is mutually exclusive with `pixel_values`.
        `output_loss`: when True, a cross-entropy loss is computed. `labels` defaults to
            `decoder_input_ids`. Only positions where the shifted `decoder_assistant_masks`
            (or, if that is None, the shifted `decoder_attention_mask`) is non-zero
            contribute to the loss.
        `decoder_assistant_masks`: provided by `tokenizer.apply_chat_template`. It has the
            same shape as `decoder_input_ids` with values 0 or 1. 0: system/user tokens,
            1: assistant tokens, i.e. the tokens that need to be generated.
        `return_dict`: accepted for signature compatibility but not consulted; a
            `Vision2LanguageOutputWithPast` is always returned.
        `**kwargs`: forwarded unchanged to the decoder.

        Returns:
            `Vision2LanguageOutputWithPast` with the decoder logits / cache / hidden states /
            attentions, `loss` (None unless `output_loss`), and `image_hidden_states` (the
            projected image features, only when `pixel_values` was given and the features
            were computed in this call; otherwise None).

        Raises:
            ValueError: on conflicting inputs (see the checks below), or when `output_loss`
                is requested without `decoder_assistant_masks` or `decoder_attention_mask`.
        """
        LOGGER.debug("rank[%s], kwargs %s", ACCELERATOR.process_index, kwargs)
        LOGGER.debug("rank[%s], pixel_values: %s", ACCELERATOR.process_index, pixel_values.shape if pixel_values is not None else None)
        LOGGER.debug("rank[%s], decoder_input_ids: %s", ACCELERATOR.process_index, decoder_input_ids)
        # decoder_attention_mask is optional; guard the attribute access like the sibling log lines do.
        LOGGER.debug("rank[%s], decoder_attention_mask: %s", ACCELERATOR.process_index, decoder_attention_mask.shape if decoder_attention_mask is not None else None)
        LOGGER.debug("rank[%s], encoder_outputs.last_hidden_state: %s", ACCELERATOR.process_index, encoder_outputs.last_hidden_state.shape if encoder_outputs is not None else None)
        LOGGER.debug("rank[%s], past_key_values: %s", ACCELERATOR.process_index, past_key_values)
        LOGGER.debug("rank[%s], decoder_inputs_embeds: %s", ACCELERATOR.process_index, decoder_inputs_embeds)
        LOGGER.debug("rank[%s], position_ids: %s", ACCELERATOR.process_index, position_ids)
        LOGGER.debug("rank[%s], logits_to_keep: %s", ACCELERATOR.process_index, logits_to_keep)

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        # Training: pixel_values is given, encoder_outputs is not.
        # Inference: pixel_values is None and encoder_outputs is given; encoder_outputs is only
        # needed in the first round and must be ignored afterwards.
        if (pixel_values is not None) and (encoder_outputs is not None):
            raise ValueError("You must not specify both pixel_values and encoder_outputs.")

        # decoder_inputs_embeds has not been used by callers so far.
        # The XOR is true when both are given or both are missing.
        if (decoder_input_ids is None) ^ (decoder_inputs_embeds is not None):
            raise ValueError("You must specify exactly one of decoder_input_ids or decoder_inputs_embeds")

        if (pixel_values is not None or encoder_outputs is not None) and decoder_inputs_embeds is not None:
            raise ValueError("You cannot specify both `pixel_values`/`encoder_outputs` and `decoder_inputs_embeds` at the same time, and must specify either one")

        if decoder_inputs_embeds is None:
            # get text embeddings
            decoder_inputs_embeds = self.decoder.get_input_embeddings()(decoder_input_ids)

        # If encoder_outputs is already available there is no need to encode pixel_values again.
        if (pixel_values is not None) and (encoder_outputs is None):
            # get img features
            encoder_outputs = self.encoder(pixel_values=pixel_values, return_dict=True)

        # Bound up-front so the return statement below never hits an unbound local when the
        # injection branch is skipped (e.g. pixel_values passed together with past_key_values).
        image_features = None

        # Needed for the training forward pass and for the first inference round.
        # Training provides pixel_values.
        # Every inference round provides encoder_outputs with pixel_values=None; in the first
        # round past_key_values is None, afterwards it is a DynamicCache().
        if encoder_outputs is not None and past_key_values is None:
            image_features = encoder_outputs.last_hidden_state  # torch.Size([4, 1370, enc_dim])
            # project image features
            LOGGER.debug("rank[%s], v2lmodel forward image_features shape: %s", ACCELERATOR.process_index, image_features.shape)
            image_features = self.v2l_projector(image_features)
            # inject image features into text embeddings
            decoder_inputs_embeds = self._inject_image_features(decoder_input_ids, decoder_inputs_embeds, image_features)

        # Text generation. decoder_inputs_embeds is used in place of decoder_input_ids on the decoder in all cases.
        # In the training stage, decoder_input_ids is encoded into decoder_inputs_embeds and then merged with image features.
        # In the inference stage, encoder_outputs is passed from generate() in place of pixel_values.
        decoder_outputs = self.decoder(
            attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            logits_to_keep=logits_to_keep,
            **kwargs,
        )

        logits = decoder_outputs.logits

        # text loss
        loss = None
        if output_loss:
            labels = labels if labels is not None else decoder_input_ids

            # Shift so that tokens < n predict n
            if decoder_assistant_masks is not None:
                shift_label_mask = decoder_assistant_masks[:, 1:]  # torch.Size([bsz, seq_len - 1])
            elif decoder_attention_mask is not None:
                shift_label_mask = decoder_attention_mask[:, 1:]
            else:
                raise ValueError("decoder_assistant_masks or decoder_attention_mask should be provided")

            shift_logits = logits[:, :-1, :]  # torch.Size([bsz, seq_len - 1, vocab_size])
            shift_labels = labels[:, 1:]  # torch.Size([bsz, seq_len - 1])

            # Keep only the positions selected by the mask.
            active_positions = shift_label_mask != 0
            active_shift_logits = shift_logits[active_positions].contiguous()  # torch.Size([num_active_labels, vocab_size])
            active_shift_labels = shift_labels[active_positions].contiguous()  # torch.Size([num_active_labels])

            ce_loss_fct = nn.CrossEntropyLoss()
            loss = ce_loss_fct(active_shift_logits, active_shift_labels)

        return Vision2LanguageOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=decoder_outputs.past_key_values,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            # Only exposed on the pixel_values path, as before; None when nothing was injected.
            image_hidden_states=image_features if pixel_values is not None else None,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs,
        generation_config=None,
        logits_processor=None,
        stopping_criteria=None,
        prefix_allowed_tokens_fn=None,
        synced_gpus=None,
        assistant_model=None,
        streamer=None,
        negative_prompt_ids=None,
        negative_prompt_attention_mask=None,
        **kwargs,  # If the model is an encoder-decoder model, encoder specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with decoder_.
    ):
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        self._validate_model_class()
        tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
        assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
        LOGGER.debug("rank[%s], step1", ACCELERATOR.process_index)
        LOGGER.debug("rank[%s], tokenizer %s", ACCELERATOR.process_index, tokenizer)
        LOGGER.debug("rank[%s], assistant_tokenizer: %s", ACCELERATOR.process_index, assistant_tokenizer)

        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
        self._validate_model_kwargs(model_kwargs.copy())
        self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
        LOGGER.debug("rank[%s], generation_config: %s", ACCELERATOR.process_index, generation_config)
        LOGGER.debug("rank[%s], model_kwargs step1: %s", ACCELERATOR.process_index, model_kwargs)
        LOGGER.debug("rank[%s], decoder_input_ids: %s", ACCELERATOR.process_index, model_kwargs["decoder_input_ids"].shape)
        LOGGER.debug("rank[%s], decoder_attention_mask: %s", ACCELERATOR.process_index, model_kwargs["decoder_attention_mask"].shape)

        # 2. Set generation parameters if not already defined
        if synced_gpus is None:
            synced_gpus = (tf_generation_utils.is_deepspeed_zero3_enabled() or tf_generation_utils.is_fsdp_managed_module(self)) and tf_generation_utils.dist.get_world_size() > 1
        LOGGER.debug("rank[%s], step2", ACCELERATOR.process_index)
        LOGGER.debug("rank[%s], synced_gpus: %s (should be True)", ACCELERATOR.process_index, synced_gpus, main_process_only=False)

        logits_processor = logits_processor if logits_processor is not None else tf_generation_utils.LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else tf_generation_utils.StoppingCriteriaList()
        LOGGER.debug("rank[%s], logits_processor: %s", ACCELERATOR.process_index, logits_processor)
        LOGGER.debug("rank[%s], stopping_criteria: %s", ACCELERATOR.process_index, stopping_criteria)

        accepts_attention_mask = "attention_mask" in set(tf_generation_utils.inspect.signature(self.forward).parameters.keys())
        requires_attention_mask = "encoder_outputs" not in model_kwargs
        kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
        LOGGER.debug("rank[%s], accepts_attention_mask: %s", ACCELERATOR.process_index, accepts_attention_mask)
        LOGGER.debug("rank[%s], requires_attention_mask: %s", ACCELERATOR.process_index, requires_attention_mask)
        LOGGER.debug("rank[%s], kwargs_has_attention_mask: %s", ACCELERATOR.process_index, kwargs_has_attention_mask)

        # 3. Define model inputs
        inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(inputs, generation_config.bos_token_id, model_kwargs)
        # batch_size = inputs_tensor.shape[0]
        # encoder和decoder的bsz可能不一样，我们以decoder的bsz为准
        batch_size = model_kwargs["decoder_input_ids"].shape[0]
        LOGGER.debug("rank[%s], step3", ACCELERATOR.process_index)
        LOGGER.debug("rank[%s], inputs_tensor: %s", ACCELERATOR.process_index, inputs_tensor.shape)
        LOGGER.debug("rank[%s], model_input_name: %s", ACCELERATOR.process_index, model_input_name)
        LOGGER.debug("rank[%s], model_kwargs step3: %s", ACCELERATOR.process_index, model_kwargs)
        LOGGER.debug("rank[%s], batch_size: %s", ACCELERATOR.process_index, batch_size)

        device = inputs_tensor.device
        self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)

        # decoder-only models must use left-padding for batched generation.
        LOGGER.debug("rank[%s], self.config.is_encoder_decoder %s", ACCELERATOR.process_index, self.config.is_encoder_decoder)
        if not self.config.is_encoder_decoder and not tf_generation_utils.is_torchdynamo_compiling():
            # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
            # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.
            LOGGER.warning("Should not see this warning!!! A decoder-only architecture is detected, while we are using encoder-decoder model.")
            if generation_config._pad_token_tensor is not None and batch_size > 1 and len(inputs_tensor.shape) == 2 and torch.sum(inputs_tensor[:, -1] == generation_config._pad_token_tensor) > 0:
                LOGGER.warning("A decoder-only architecture is being used, but right-padding was detected! For correct " "generation results, please set `padding_side='left'` when initializing the tokenizer.")

        # 4. Define other model kwargs
        # decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
        # generating the first new token or not, and we only want to use the embeddings for the first new token)
        LOGGER.debug("rank[%s], step4", ACCELERATOR.process_index)
        LOGGER.debug("rank[%s], Conv2D weight shape: %s", ACCELERATOR.process_index, self.encoder.embeddings.patch_embeddings.projection.weight.shape, main_process_only=False)
        if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
            generation_config.use_cache = True

        if not kwargs_has_attention_mask and requires_attention_mask and accepts_attention_mask:
            model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(inputs_tensor, generation_config, model_kwargs)
            LOGGER.debug("rank[%s], model_kwargs['attention_mask']: %s", ACCELERATOR.process_index, model_kwargs["attention_mask"].shape)
        elif kwargs_has_attention_mask:
            # TODO (joao): generalize this check with other types of inputs
            if model_input_name == "input_ids" and len(model_kwargs["attention_mask"].shape) > 2:
                raise ValueError("`attention_mask` passed to `generate` must be 2D.")

        if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
            # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(inputs_tensor, model_kwargs, model_input_name, generation_config)
            LOGGER.debug("rank[%s], model_kwargs step4: %s", ACCELERATOR.process_index, model_kwargs)
            LOGGER.debug("rank[%s], model_kwargs['encoder_outputs'].last_hidden_state: %s", ACCELERATOR.process_index, model_kwargs["encoder_outputs"].last_hidden_state.shape)
            LOGGER.debug("rank[%s], model_kwargs['encoder_outputs'].pooler_output: %s", ACCELERATOR.process_index, model_kwargs["encoder_outputs"].pooler_output.shape)

        # 5. Prepare `input_ids` which will be used for auto-regressive generation
        LOGGER.debug("rank[%s], step5", ACCELERATOR.process_index)
        if self.config.is_encoder_decoder:
            LOGGER.debug("rank[%s], model_input_name: %s", ACCELERATOR.process_index, model_input_name)
            LOGGER.debug("rank[%s], before decoder_start_token_id: %s", ACCELERATOR.process_index, generation_config._decoder_start_token_tensor)
            # 原始方法，当input_ids不是以decoder_start_token_id开头时，添加decoder_start_token_id
            # 更新后的方法，当input_ids不是以decoder_start_token_id 或 pad_token_id 开头时，添加decoder_start_token_id
            # 因为我们在collect_fn中，会将input_ids以8的倍数填充left padding，然后紧跟着decoder_start_token_id和正文
            input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
                batch_size=batch_size,
                model_input_name=model_input_name,
                model_kwargs=model_kwargs,
                decoder_start_token_id=generation_config._decoder_start_token_tensor,
                pad_token_id=torch.tensor(generation_config.pad_token_id, device=inputs_tensor.device),
                device=inputs_tensor.device,
            )
            LOGGER.debug("rank[%s], input_ids: %s", ACCELERATOR.process_index, input_ids.shape)
            LOGGER.debug("rank[%s], input_ids: %s", ACCELERATOR.process_index, input_ids.tolist())
            LOGGER.debug("rank[%s], model_kwargs step5: %s", ACCELERATOR.process_index, model_kwargs)
        else:
            input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

        if generation_config.token_healing:
            input_ids = self.heal_tokens(input_ids, tokenizer)

        if streamer is not None:
            streamer.put(input_ids.cpu())

        # 6. Prepare `max_length` depending on other stopping criteria.
        LOGGER.debug("rank[%s], step6", ACCELERATOR.process_index)
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            has_default_min_length=has_default_min_length,
            model_input_name=model_input_name,
            inputs_tensor=inputs_tensor,
            input_ids_length=input_ids_length,
        )
        LOGGER.debug("rank[%s], input_ids_length: %s", ACCELERATOR.process_index, input_ids_length)
        LOGGER.debug("rank[%s], has_default_max_length: %s", ACCELERATOR.process_index, has_default_max_length)
        LOGGER.debug("rank[%s], has_default_min_length: %s", ACCELERATOR.process_index, has_default_min_length)
        LOGGER.debug("rank[%s], generation_config: %s", ACCELERATOR.process_index, type(generation_config))

        # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
        # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
        # dynamically overrides this value as it can need more than the last token logits
        if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
            model_kwargs["logits_to_keep"] = 1
            LOGGER.debug("rank[%s], model_kwargs step6: %s", ACCELERATOR.process_index, model_kwargs)

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 7. Prepare the cache.
        # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
        # - different models have a different cache name expected by the model (default = "past_key_values")
        # - `max_length`, prepared above, is used to determine the maximum cache length
        max_cache_length = generation_config.max_length - 1
        if inputs_tensor.shape[1] != input_ids_length and model_input_name == "inputs_embeds" and not self.config.is_encoder_decoder:
            max_cache_length += inputs_tensor.shape[1]
        self._prepare_cache_for_generation(generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device)

        # 8. determine generation mode
        LOGGER.debug("rank[%s], step8", ACCELERATOR.process_index)
        generation_mode = generation_config.get_generation_mode(assistant_model)
        LOGGER.debug("rank[%s], generation_mode %s", ACCELERATOR.process_index, generation_mode)

        if streamer is not None and (generation_config.num_beams > 1):
            raise ValueError("`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1.")

        if not tf_generation_utils.is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            tf_generation_utils.warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different" f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model" f" is on {self.device.type}. You may experience unexpected behaviors or slower generation." " Please make sure that you have put `input_ids` to the" f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before" " running `.generate()`.",
                UserWarning,
            )

        # 9. prepare logits processors and stopping criteria
        prepared_logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_length,
            encoder_input_ids=inputs_tensor,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
            device=inputs_tensor.device,
            model_kwargs=model_kwargs,
            negative_prompt_ids=negative_prompt_ids,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
        )
        prepared_stopping_criteria = self._get_stopping_criteria(generation_config=generation_config, stopping_criteria=stopping_criteria, tokenizer=tokenizer, **kwargs)

        # Set model_kwargs `use_cache` so we can use it later in forward runs
        model_kwargs["use_cache"] = generation_config.use_cache
        LOGGER.debug("rank[%s], model_kwargs step9: %s", ACCELERATOR.process_index, model_kwargs)

        # 10. go into different generation modes
        result = None
        if generation_mode == tf_generation_utils.GenerationMode.ASSISTED_GENERATION:
            if generation_config.num_return_sequences > 1:
                raise ValueError("num_return_sequences has to be 1 when doing assisted generate, " f"but is {generation_config.num_return_sequences}.")
            if batch_size > 1:
                raise ValueError("assisted generate is only supported for batch_size = 1")
            if not model_kwargs["use_cache"]:
                raise ValueError("assisted generate requires `use_cache=True`")
            if generation_config.cache_implementation in ["static", "hybrid", "sliding_window"]:
                raise ValueError("assisted generate is not supported with Static cache classes`")
            if self._is_stateful:
                # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
                # which is not possible with stateful models (they can't reset to a previous subset of generated text)
                raise ValueError(f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}")

            # 11. Get the candidate generator, given the parameterization
            candidate_generator = self._get_candidate_generator(
                generation_config=generation_config,
                input_ids=input_ids,
                inputs_tensor=inputs_tensor,
                assistant_model=assistant_model,
                logits_processor=logits_processor,
                target_tokenizer=tokenizer,
                assistant_tokenizer=assistant_tokenizer,
                model_kwargs=model_kwargs,
            )

            # 12. run assisted generate
            result = self._assisted_decoding(
                input_ids,
                candidate_generator=candidate_generator,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )
        elif generation_mode == tf_generation_utils.GenerationMode.DOLA_GENERATION:
            if self._is_stateful:
                # DoLa decoding was not designed for stateful models, and would require some changes
                raise ValueError(f"dola decoding is not supported with stateful models, such as {self.__class__.__name__}")
            result = self._dola_decoding(
                input_ids,
                dola_layers=generation_config.dola_layers,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode == tf_generation_utils.GenerationMode.CONTRASTIVE_SEARCH:
            if not model_kwargs["use_cache"]:
                raise ValueError("Contrastive search requires `use_cache=True`")
            if self._is_stateful:
                # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
                raise ValueError(f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}")

            result = self._contrastive_search(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode in (tf_generation_utils.GenerationMode.SAMPLE, tf_generation_utils.GenerationMode.GREEDY_SEARCH):
            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
            result = self._sample(
                input_ids,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                streamer=streamer,
                **model_kwargs,
            )

        elif generation_mode in (tf_generation_utils.GenerationMode.BEAM_SAMPLE, tf_generation_utils.GenerationMode.BEAM_SEARCH):
            # 11. prepare beam search scorer
            beam_scorer = tf_generation_utils.BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )

            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )

            # 13. run beam sample
            result = self._beam_search(
                input_ids,
                beam_scorer,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_mode == tf_generation_utils.GenerationMode.GROUP_BEAM_SEARCH:
            # 11. prepare beam search scorer
            beam_scorer = tf_generation_utils.BeamSearchScorer(
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                num_beam_groups=generation_config.num_beam_groups,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            result = self._group_beam_search(
                input_ids,
                beam_scorer,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        elif generation_mode == tf_generation_utils.GenerationMode.CONSTRAINED_BEAM_SEARCH:
            final_constraints = []
            if generation_config.constraints is not None:
                final_constraints = generation_config.constraints

            if generation_config.force_words_ids is not None:

                def typeerror():
                    raise ValueError("`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]` " f"of positive integers, but is {generation_config.force_words_ids}.")

                if not isinstance(generation_config.force_words_ids, list) or len(generation_config.force_words_ids) == 0:
                    typeerror()

                for word_ids in generation_config.force_words_ids:
                    if isinstance(word_ids[0], list):
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any(not isinstance(token_ids, list) for token_ids in word_ids):
                            typeerror()
                        if any(any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids) for token_ids in word_ids):
                            typeerror()

                        constraint = tf_generation_utils.DisjunctiveConstraint(word_ids)
                    else:
                        if not isinstance(word_ids, list) or len(word_ids) == 0:
                            typeerror()
                        if any((not isinstance(token_id, int) or token_id < 0) for token_id in word_ids):
                            typeerror()

                        constraint = tf_generation_utils.PhrasalConstraint(word_ids)
                    final_constraints.append(constraint)

            # 11. prepare beam search scorer
            constrained_beam_scorer = tf_generation_utils.ConstrainedBeamSearchScorer(
                constraints=final_constraints,
                batch_size=batch_size,
                num_beams=generation_config.num_beams,
                device=inputs_tensor.device,
                length_penalty=generation_config.length_penalty,
                do_early_stopping=generation_config.early_stopping,
                num_beam_hyps_to_keep=generation_config.num_return_sequences,
                max_length=generation_config.max_length,
            )
            # 12. interleave input_ids with `num_beams` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
                expand_size=generation_config.num_beams,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
            # 13. run beam search
            result = self._constrained_beam_search(
                input_ids,
                constrained_beam_scorer=constrained_beam_scorer,
                logits_processor=prepared_logits_processor,
                stopping_criteria=prepared_stopping_criteria,
                generation_config=generation_config,
                synced_gpus=synced_gpus,
                **model_kwargs,
            )

        # Convert to legacy cache format if requested
        if generation_config.return_legacy_cache is True and not tf_generation_utils.is_torchdynamo_compiling() and hasattr(result, "past_key_values") and getattr(result.past_key_values, "to_legacy_cache") is not None:
            result.past_key_values = result.past_key_values.to_legacy_cache()
        return result

    def _prepare_decoder_input_ids_for_generation(
        self,
        batch_size: int,
        model_input_name: str,
        model_kwargs: Dict[str, torch.Tensor],
        decoder_start_token_id: torch.Tensor,
        pad_token_id: torch.Tensor,
        device: torch.device = None,
    ) -> Tuple[torch.LongTensor, Dict[str, torch.Tensor]]:
        """Prepares `decoder_input_ids` for generation with encoder-decoder models
        Update: if the first token is not decoder_start_token_id or pad_token_id, we need to prepend decoder_start_token_id. Because our input_ids are left padded to multiple of 8, and then followed by decoder_start_token_id and the real input_ids. It is done in the collate_fn.
        """
        # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
        # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
        if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
            decoder_input_ids = model_kwargs.pop("decoder_input_ids")
        elif "input_ids" in model_kwargs and model_input_name != "input_ids":
            decoder_input_ids = model_kwargs.pop("input_ids")
        else:
            decoder_input_ids = None

        # 2. `decoder_start_token_id` must have shape (batch_size, 1)
        if device is None:
            device = self.device
        if decoder_start_token_id.ndim == 1:
            if decoder_start_token_id.shape[0] != batch_size:
                raise ValueError(f"`decoder_start_token_id` expected to have length {batch_size} but got {decoder_start_token_id.shape[0]}")
            decoder_start_token_id = decoder_start_token_id.view(-1, 1)
        else:
            decoder_start_token_id = torch.ones((batch_size, 1), dtype=torch.long, device=device) * decoder_start_token_id

        # 3. Encoder-decoder models expect the `decoder_input_ids` to start with a special token. Let's ensure that.
        # no user input -> use decoder_start_token_id as decoder_input_ids
        if decoder_input_ids is None:
            decoder_input_ids = decoder_start_token_id
        # exception: Donut checkpoints have task-specific decoder starts and don't expect a BOS token. Note that the
        # original checkpoints can't be detected through `self.__class__.__name__.lower()`, needing custom logic.
        # See: https://github.com/huggingface/transformers/pull/31470
        elif "donut" in self.__class__.__name__.lower() or (self.config.model_type == "vision-encoder-decoder" and "donut" in self.config.encoder.model_type.lower()):
            pass
        elif self.config.model_type in ["whisper"]:
            pass
        # user input but doesn't start with decoder_start_token_id -> prepend decoder_start_token_id (and adjust
        # decoder_attention_mask if provided)
        #######################################
        # !!! Update: if the first token is not decoder_start_token_id or pad_token_id, we need to prepend decoder_start_token_id
        #######################################
        elif ((decoder_input_ids[:, 0] != decoder_start_token_id[:, 0]) & (decoder_input_ids[:, 0] != pad_token_id)).all().item():
            decoder_input_ids = torch.cat([decoder_start_token_id, decoder_input_ids], dim=-1)
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                decoder_attention_mask = torch.cat(
                    (torch.ones_like(decoder_attention_mask)[:, :1], decoder_attention_mask),
                    dim=-1,
                )
                model_kwargs["decoder_attention_mask"] = decoder_attention_mask

        return decoder_input_ids, model_kwargs


class ImageTextDataset(Dataset):
    """Map-style dataset that hands out the rows of a wrapped dataset split unchanged.

    The image processor and tokenizer are only stored on the instance; nothing in this class applies them.
    """

    def __init__(self, hf_dataset, img_processor, tokenizer, split):
        # Columns of the wrapped dataset:
        # ['source', 'images_path', 'images', 'section_text', 'doc_key', 'split_sents', 'split_sent_toks', 'sent_idx_split_idx', 'radlex', 'cxrgraph_ent', 'cxrgraph_attr', 'cxrgraph_rel']
        self.split = split

        # Directory holding the first cache file, or "" when the dataset reports no cache files.
        if hf_dataset.cache_files:
            first_cache_file = hf_dataset.cache_files[0]["filename"]
            self.src_path = os.path.dirname(first_cache_file)
        else:
            self.src_path = ""

        self.img_processor = img_processor
        self.tokenizer = tokenizer
        self.samples = hf_dataset

    def __len__(self):
        """Number of rows in the wrapped dataset."""
        return len(self.samples)

    def __getitem__(self, index):
        """Return the data and label stored at ``index``."""
        row = self.samples[index]
        return row

In [3]:
# Load a Vision2LanguageModel from the saved checkpoint directory at `model_path`.
# NOTE(review): `pre_treained_model` is misspelled ("pretrained"); a later cell references the same name, so any
# rename has to be applied to both cells together.
# NOTE(review): absolute, user-specific path; consider deriving it from a configurable base directory.
model_path = "/scratch/c.c21051562/workspace/arrg_img2text/outputs/models/4_1_test_pretrain_save_with_FULL_STATE_DICT"
pre_treained_model = Vision2LanguageModel.from_pretrained(model_path)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:11<00:00,  5.87s/it]


In [4]:
# Build a Vision2LanguageModel from two separate checkpoints: a vision model and a language model.
# Presumably this is the freshly initialised combination used as a reference against the checkpoint-loaded
# model; confirm.
# NOTE(review): absolute, user-specific paths; consider deriving them from a configurable base directory.
vision_model_path = "/scratch/c.c21051562/resources/downloaded_models/rad-dino-maira-2"
language_model_path = "/scratch/c.c21051562/resources/downloaded_models/Llama-3.2-1B"
init_model = Vision2LanguageModel.from_encoder_decoder_pretrained(vision_model_path, language_model_path)

In [5]:
# Display the module tree of the model loaded with `from_pretrained`.
pre_treained_model

Vision2LanguageModel(
  (encoder): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )


In [6]:
# Display the module tree of the model built with `from_encoder_decoder_pretrained`.
init_model

Vision2LanguageModel(
  (encoder): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2SdpaAttention(
            (attention): Dinov2SdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
      

In [8]:
# Load the tensors stored in the fine-tuning run's `adapter_model.safetensors` file into `sd`.
sd = load_file("/scratch/c.c21051562/workspace/arrg_img2text/outputs/models/4_1_vlgen_effu_fsdp_peft_test_finetune/adapter_model.safetensors")

In [10]:
# Print every tensor name in `sd` together with its shape.
# NOTE(review): in the recorded output, `lm_head.lora_A.weight` and `lm_head.lora_B.weight` are 1-D
# (32768 = 16 * 2048 and 2052224 = 128264 * 16), while the other LoRA weights are 2-D. Presumably these two were
# saved in flattened form; confirm how the adapter was written out.
for n, p in sd.items():
    print(n, p.shape)

base_model.model.decoder.lm_head.base_layer.weight torch.Size([128264, 2048])
base_model.model.decoder.lm_head.lora_A.weight torch.Size([32768])
base_model.model.decoder.lm_head.lora_B.weight torch.Size([2052224])
base_model.model.decoder.model.embed_tokens.base_layer.weight torch.Size([128264, 2048])
base_model.model.decoder.model.embed_tokens.lora_embedding_A torch.Size([16, 128264])
base_model.model.decoder.model.embed_tokens.lora_embedding_B torch.Size([2048, 16])
base_model.model.decoder.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([16, 2048])
base_model.model.decoder.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([2048, 16])
base_model.model.decoder.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([16, 2048])
base_model.model.decoder.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([512, 16])
base_model.model.decoder.model.layers.1.self_attn.q_proj.lora_A.weight torch.Size([16, 2048])
base_model.model.decoder.model.layers.1.self_attn.q_proj.l

In [11]:
# Load the tensors stored in the `model.safetensors` file of the "test" output directory into `sd2`.
sd2 = load_file("/scratch/c.c21051562/workspace/arrg_img2text/outputs/models/test/model.safetensors")

In [12]:
# Print every tensor name in `sd2` together with its shape.
for n, p in sd2.items():
    print(n, p.shape)

decoder.lm_head.weight torch.Size([128264, 2048])
decoder.model.embed_tokens.weight torch.Size([128264, 2048])
decoder.model.layers.0.input_layernorm.weight torch.Size([2048])
decoder.model.layers.0.mlp.down_proj.weight torch.Size([2048, 8192])
decoder.model.layers.0.mlp.gate_proj.weight torch.Size([8192, 2048])
decoder.model.layers.0.mlp.up_proj.weight torch.Size([8192, 2048])
decoder.model.layers.0.post_attention_layernorm.weight torch.Size([2048])
decoder.model.layers.0.self_attn.k_proj.weight torch.Size([512, 2048])
decoder.model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048])
decoder.model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048])
decoder.model.layers.0.self_attn.v_proj.weight torch.Size([512, 2048])
decoder.model.layers.1.input_layernorm.weight torch.Size([2048])
decoder.model.layers.1.mlp.down_proj.weight torch.Size([2048, 8192])
decoder.model.layers.1.mlp.gate_proj.weight torch.Size([8192, 2048])
decoder.model.layers.1.mlp.up_proj.weight torch.Size([81