In [1]:
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from transformers.generation.utils import GenerationMixin
from transformers.cache_utils import StaticCache,DynamicCache
from qwen_vl_utils import process_vision_info
import re
import torch
import torch.nn as nn

import torch
from fvcore.nn import FlopCountAnalysis


class CustomGenerationMixin(GenerationMixin):
    def _update_model_kwargs_for_generation(
        self,
        outputs,
        model_kwargs,
        is_encoder_decoder = False,
        num_new_tokens = 1,
    ):
        # update past_key_values keeping its naming used in model code
        cache_name, cache = self._extract_past_from_model_output(outputs)
        model_kwargs[cache_name] = cache
        if getattr(outputs, "state", None) is not None:
            model_kwargs["state"] = outputs.state

        # update token_type_ids with last value
        if "token_type_ids" in model_kwargs:
            token_type_ids = model_kwargs["token_type_ids"]
            model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

        if not is_encoder_decoder:
            # update attention mask
            if "attention_mask" in model_kwargs:
                attention_mask = model_kwargs["attention_mask"]
                model_kwargs["attention_mask"] = torch.cat(
                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], num_new_tokens))], dim=-1
                )
        else:
            # update decoder attention mask
            if "decoder_attention_mask" in model_kwargs:
                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
                model_kwargs["decoder_attention_mask"] = torch.cat(
                    [decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
                    dim=-1,
                )

        if model_kwargs.get("use_cache", True):
            model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
        else:
            past_positions = model_kwargs.pop("cache_position")
            new_positions = torch.arange(
                past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
            ).to(past_positions.device)
            model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
        return model_kwargs


    def _sample(
        self,
        input_ids,
        logits_processor,
        stopping_criteria,
        generation_config,
        synced_gpus,
        streamer,
        **model_kwargs,
    ):
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed to avoid deadlocking with
                `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        model_forward = self.__call__
        if isinstance(model_kwargs.get("past_key_values"), StaticCache):
            if self.device.type == "cuda":
                logger.warning_once("Using `torch.compile`.")
                os.environ["TOKENIZERS_PARALLELISM"] = "0"
                model_forward = self.get_compiled_call(generation_config.compile_config)

        is_prefill = True
        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length
        ):
            # prepare model inputs
            model_inputs, origin_input_ids, origin_pixel_values, origin_image_grid_thw= self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            if is_prefill:
                outputs = self(**model_inputs, return_dict=True)
                is_prefill = False
            else:
                outputs = model_forward(**model_inputs, return_dict=True)

            num_new_tokens = 1 if origin_input_ids.shape[1]==input_ids.shape[1] else (origin_input_ids.shape[1]-input_ids.shape[1]+1)
            # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
                num_new_tokens=num_new_tokens,
            )
            if synced_gpus and this_peer_finished:
                continue

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits[:, -1, :].clone().float()
            next_token_logits = next_token_logits.to(input_ids.device)

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([origin_input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs["pixel_values"] = origin_pixel_values
            model_kwargs["image_grid_thw"] = origin_image_grid_thw

            if streamer is not None:
                streamer.put(next_tokens.cpu())

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids


# Subclassing the model, if necessary
class CustomQwen2VLForConditionalGeneration(Qwen2VLForConditionalGeneration, CustomGenerationMixin):
    def __init__(self, config):
        # Call the parent class's constructor to initialize all basic components
        super().__init__(config)
        # Custom initialization logic
        self.image_inputs = None  # To hold image inputs
        self.processor = None 
        self.image_processor = None 
        self.image_token = 151655  # Assuming this is a defined constant
        self.vision_start_token = 151652
        self.vision_end_token = 151653
    
    
    def set_image_inputs(self, image_inputs):
        """
        Method to set image inputs outside of the generation call.
        This can be called independently before calling generate.
        """
        self.image_inputs = image_inputs


    def set_processor(self, processor):
        """
        Method to set image inputs outside of the generation call.
        This can be called independently before calling generate.
        """
        self.processor = processor
        self.image_processor = processor.image_processor

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        image_grid_thw=None,
        video_grid_thw=None,
        **kwargs,
    ):
        assert self.image_inputs is not None, "Image inputs must be provided."
        origin_input_ids = input_ids
        origin_pixel_values = pixel_values
        origin_image_grid_thw = image_grid_thw

        # Custom handling to check for the special token </IMG>
        # Assuming self.tokenizer is available and configured.
        if input_ids[:,-1]==151658:
            coor_start_ind = torch.where(input_ids==151648)[1].tolist()[-1]
            decoded_texts = self.processor.batch_decode(input_ids[:, coor_start_ind:], skip_special_tokens=False)
            # Check if any decoded text contains the special token </IMG>
            for decoded_text in decoded_texts:
                # Retrieve coordinates and index
                match = re.search(r"<\|box_start\|\>\((\d+),(\d+)\),\((\d+),(\d+)\)<\|box_end\|\><IMG>(\d+)</IMG>", decoded_text)
                if match:
                    x1, y1, x2, y2, img_index = map(int, match.groups())

                    # Validate image index
                    if 0 <= img_index < len(self.image_inputs):
                        original_image = self.image_inputs[img_index]
                        try:
                            # Validate and normalize coordinates
                            if all(0 <= coord <= 1000 for coord in (x1, y1, x2, y2)) and x1 < x2 and y1 < y2:
                                width, height = original_image.size
                                x1, y1, x2, y2 = [
                                    int(coord / 1000 * (width if i % 2 == 0 else height))
                                    for i, coord in enumerate((x1, y1, x2, y2))
                                ]
                                # Crop the image
                                cropped_image = original_image.crop((x1, y1, x2, y2))
                            else:
                                raise ValueError(f"Invalid coordinates ({x1}, {y1}, {x2}, {y2})")
                        except Exception as e:
                            print(f"Error in processing image with index {img_index}: {e}")
                            cropped_image = original_image  # Fallback to original image

                        # Process cropped image
                        new_image_inputs = self.image_processor(images=cropped_image, return_tensors='pt')
                        new_pixel_values = torch.tensor(new_image_inputs["pixel_values"], device=input_ids.device)
                        new_image_grid_thw = torch.tensor(new_image_inputs["image_grid_thw"], device=input_ids.device)
                    else:
                        print(f"Image index {img_index} out of range for {len(self.image_inputs)} images.")

                    # Append vision tokens
                    num_image_tokens = new_image_grid_thw.prod() // 4
                    vision_token_ids = [self.vision_start_token] + [self.image_token] * num_image_tokens + [self.vision_end_token]
                    vision_token_ids_tensor = torch.tensor(vision_token_ids, device=input_ids.device)

                    input_ids = torch.cat([input_ids, vision_token_ids_tensor.unsqueeze(0)], dim=1)  # Assuming batch size of 1
                    origin_input_ids = input_ids

                    cache_position = torch.arange(origin_input_ids.shape[1], device=origin_input_ids.device)

                    # Extend the attention mask
                    if attention_mask is not None:
                        # attention_mask = torch.ones((1, origin_input_ids.shape[1]), device=origin_input_ids.device)
                        attention_mask = attention_mask.new_ones((1, origin_input_ids.shape[1]))

                    # # Reset past_key_values
                    past_key_values = DynamicCache()

                    # Combine with original pixel values if needed
                    origin_pixel_values = torch.cat([origin_pixel_values, new_pixel_values], dim=0)
                    origin_image_grid_thw = torch.cat([origin_image_grid_thw, new_image_grid_thw], dim=0)


        # The rest of your prepare_inputs_for_generation logic
        model_inputs = super().prepare_inputs_for_generation(
            origin_input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            pixel_values=origin_pixel_values,  # use the new pixel values
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=origin_image_grid_thw,  # use the new grid thw
            video_grid_thw=video_grid_thw,
            **kwargs,
        )
        return model_inputs, origin_input_ids, origin_pixel_values, origin_image_grid_thw

model_path = "/local/path/to/model"
model = CustomQwen2VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa", device_map="auto")
processor = AutoProcessor.from_pretrained(model_path)

def count_model_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Count and print the number of parameters
num_parameters = count_model_parameters(model)
print(f"Number of model parameters: {num_parameters}")


  from .autonotebook import tqdm as notebook_tqdm
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46
Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.76s/it]


Number of model parameters: 8291375616


In [2]:

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "/mnt/workspace/xiaoxi/code/data/stone/xh_40-img0.png",
            },
            {
                "type": "image",
                "image": "/mnt/workspace/xiaoxi/code/data/stone/xh_40-img1.png",
            },
            {
                "type": "image",
                "image": "/mnt/workspace/xiaoxi/code/data/stone/xh_40-img2.png",
            },
            {
                "type": "text",
                "text": "Answer the following multiple-choice question: if there is wind or some external force implemented on the object, which one is most likely to collapse? Options:(A) image 1 (B) image 2 (C) image 3",
            },
        ],
    }
]


In [3]:


# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)

# Set image inputs separately
model.set_image_inputs(image_inputs)
model.set_processor(processor)


inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Inference: Generation of the output
generated_ids = model.generate(**inputs, 
                                max_new_tokens=8192,
                                temperature=0.8,  # Increased from 0.7
                                top_k=50,         # Added top_k sampling
                                top_p=0.95,       # Added nucleus sampling
                                do_sample=True ,   # Enable sampling
                                repetition_penalty=1.05)

# generated_ids = model.generate(**inputs, max_new_tokens=8192)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(output_text)
output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False)
print(output_text)


  new_pixel_values = torch.tensor(new_image_inputs["pixel_values"], device=input_ids.device)
  new_image_grid_thw = torch.tensor(new_image_inputs["image_grid_thw"], device=input_ids.device)


['In image 1, the balanced stones (408,135),(672,792)0 are placed on water, which adds an element of instability due to potential movement or disturbance by wind. Image 2 shows stones (198,334),(765,765)1 stacked firmly on a stable surface. Image 3 has a large stone (180,90),(730,530)2 balanced on a wooden stump (210,420),(850,990)2. The wind or external force would more easily disrupt the water and thus the stones in image 1, making it more likely to collapse. Therefore, the answer is A.']
['In image 1, the balanced stones <|box_start|>(408,135),(672,792)<|box_end|><IMG>0</IMG><|vision_start|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_pad|><|image_p