
Add Phi-3.5-vision #36036


Open · Dahlbomii wants to merge 47 commits into main from the add_phi3-5_vision branch
Conversation

@Dahlbomii (Contributor) commented Feb 4, 2025

Draft PR for now, still need to add tests and convert to modular

cc @Rocketknight1

Fixes #36071

@Rocketknight1 (Member) commented Feb 5, 2025

Looks good so far! The TODO list from here is:

  • Make the transformers imports relative so that we stop getting circular-import errors in the CI
  • Add the phi3_5.md doc (you can copy the layout from a similar model and the text from the model card)
  • Run pip install transformers[quality], then make fixup or make style to get the repo-consistency checks green

Once everything is green, then:

  • Add tests (you can copy them from another similar VLM and just change model names etc.)

Once the new tests are green as well, then:

  • Convert the modeling and configuration files to modular, then regenerate the modeling/config files and confirm everything passes

At that point, the PR should be finished!

@Rocketknight1 mentioned this pull request on Feb 13, 2025
@ArthurZucker (Collaborator) commented:
Fixes #36166!

@Dahlbomii (Contributor, Author) commented:
I've run into an issue when trying to run the model. From modeling_phi3_v.py I'm getting "AttributeError: 'DynamicCache' object has no attribute 'get_max_length'". This might be due to a change to the DynamicCache class, but I really can't tell!

@zucchini-nlp (Member) commented:
@Dahlbomii hey! Yes, we deprecated get_max_length on the cache classes a few versions ago; now it has to be Cache.get_max_cache_shape(layer_idx).
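For reference, a minimal sketch of the migration described above, assuming a recent transformers release (the exact signature may differ slightly between versions):

```python
from transformers import DynamicCache

cache = DynamicCache()

# Old (removed): cache.get_max_length()
# New API; DynamicCache has no fixed maximum length, so this returns None:
max_cache_shape = cache.get_max_cache_shape()
print(max_cache_shape)
```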

@Rocketknight1 (Member) commented:
Woah - this means the remote-code version of Phi-3.5-vision-instruct is broken on current main and requires pinning an older version of transformers. Definitely makes this PR more urgent/important!

@Rocketknight1 (Member) commented:
cc @zucchini-nlp @gante I can't figure this one out either! The test_generate_continue_from_inputs_embeds test is failing for Phi-3V, but I don't really understand how generate() handles inputs_embeds: when I added some breakpoints, it seemed to generate an input_ids tensor with shape (batch_size, 0), which then causes errors in Phi.

@zucchini-nlp (Member) commented Feb 28, 2025

Seems like Phi-3.5V was not updated for the latest changes we made in generation; it probably has its own custom generation loop. For this test, we can call super().prepare_inputs_for_generation(**kwargs) to get all the inputs processed correctly. Image-related kwargs can then be set to None if needed. You can take a look at how we did it in other VLMs; I hope that solves the issue.
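A hedged sketch of that pattern, loosely following how other VLMs in the library handle it (the pixel_values/cache_position handling below is an assumption for illustration, not the exact Phi-3.5V code):

```python
# Inside the model class (which inherits from GenerationMixin):
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    inputs_embeds=None,
    pixel_values=None,
    cache_position=None,
    **kwargs,
):
    # Let the parent build the standard inputs (input_ids/inputs_embeds slicing, cache handling, ...)
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        **kwargs,
    )
    # Only pass image inputs on the first forward pass; after that the image features
    # are already merged into the cache, so the vision kwargs can be left out (set to None).
    if cache_position is not None and cache_position[0] == 0:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs
```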

@Rocketknight1 (Member) commented:
@zucchini-nlp yes, that was it, sorry! They don't override generate(), but they do override prepare_inputs_for_generation() and discard inputs_embeds after the first step.

@Dahlbomii (Contributor, Author) commented:
@Rocketknight1 for reasons that escape me, when the test processor tries to call self.get_component for the tokenizer, it breaks. Any insight as to why?

@Dahlbomii (Contributor, Author) commented:
@Rocketknight1 ALRIGHT I got most of the tests into the green, but I have no idea what I'm doing with the chat template stuff!

@Rocketknight1 (Member) commented:
test_apply_chat_template() isn't really necessary - I think it was just copied over from smolvlm! You can remove it.

@Dahlbomii force-pushed the add_phi3-5_vision branch from 404d2a7 to f4684f7 on April 2, 2025, 21:48
@Dahlbomii (Contributor, Author) commented:
@Rocketknight1 With that I think it might be ready for review

@Rocketknight1 (Member) left a comment:

Congrats on the PR! This looks almost ready, and the modular bit looks like it was really annoying. There are a lot of classes that look like they should be inheritable from somewhere else in the codebase, but they're just different enough that you can't.

cc @zucchini-nlp I made some comments, but you're more familiar with VLMs than me; can you review and see if anything else should be changed before we ping a core maintainer?

Comment on lines +327 to +341
class Phi3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Phi3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
Member:

I think this can just be inherited from a class like T5LayerNorm!
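For the modular conversion, that kind of reuse is usually just an inheritance stub. A hedged sketch of what it could look like, assuming the norm ends up inherited from Phi3 (the Phi3V class name here is an assumption):

```python
# modular_phi3_v.py (sketch)
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm


class Phi3VRMSNorm(Phi3RMSNorm):
    pass
```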

Comment on lines +356 to +383
class Phi3RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.register_buffer("inv_freq", None, persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if self.inv_freq is None:
            self.inv_freq = 1.0 / (
                self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
            )
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 since bfloat16 loses precision on long contexts
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
Member:

Although the code is a little different, I think this could be inherited from another RotaryEmbedding class in the library, with the same outputs. Maybe open_llama?

Member:

I think it is pretty much the same as phi3 or phi3-MoE. We also support rope scaling and dynamic rope with a unified API; adding a decorator will do the trick.

Comment on lines +468 to +484
class Phi3MLP(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

        self.activation_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        up_states = self.gate_up_proj(hidden_states)

        gate, up_states = up_states.chunk(2, dim=-1)
        up_states = up_states * self.activation_fn(gate)

        return self.down_proj(up_states)
Member:

I think this could be inherited from Phi3!

        return self.down_proj(up_states)


class Phi3Attention(nn.Module):
Member:

Although the code is a little different, I think you could maybe inherit the layer from Phi3 and not need this code here! (But you'd have to test to be sure)

Member:

+1, seems pretty much the same. The qkv can be split in the conversion script if needed, and we can inherit from Phi3, which also brings the new attention interface (easier TGI and vLLM integrations).
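A minimal sketch of the qkv split mentioned above, as it might appear in a conversion script (the function name and shapes are assumptions; the real checkpoint layout would need to be verified):

```python
import torch


def split_fused_qkv(qkv_weight: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int):
    """Split a fused [q; k; v] projection weight into separate q/k/v projection weights."""
    q_size = num_heads * head_dim
    kv_size = num_kv_heads * head_dim
    q, k, v = torch.split(qkv_weight, [q_size, kv_size, kv_size], dim=0)
    return q, k, v


# Example with Phi-3-mini-like dimensions (hidden 3072, 32 heads, 32 KV heads, head_dim 96):
qkv = torch.randn(3 * 3072, 3072)
q_w, k_w, v_w = split_fused_qkv(qkv, num_heads=32, num_kv_heads=32, head_dim=96)
```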

}


class Phi3DecoderLayer(nn.Module):
Member:

Although the code is a little different, I think this could inherit from the equivalent class in Phi3 without output changes (maybe!)

    _supports_sdpa = False
    _supports_cache_class = True

    _version = "0.0.5"
Member:

Suggested change: remove the line _version = "0.0.5"

Probably unnecessary!

@zucchini-nlp (Member) left a comment:

@Dahlbomii thanks a lot for working on this!

The PR seems to be adapted mostly from the custom code on the Hub, and that custom code is outdated and doesn't follow transformers standards. It would be nice to do a little more cleanup before merging, by unifying the text and vision backbones as AutoModel (I guess they are identical to CLIP and Phi3) and leaving only the Base/ConditionalLM multimodal class in modular.

I left a few comments below on how we can do that. LMK if that makes sense
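A hedged sketch of the structure this points at: only the multimodal wrapper is custom, while the vision and text backbones are built from their sub-configs via AutoModel (all names below are illustrative assumptions, not the PR's actual classes):

```python
import torch.nn as nn

from transformers import AutoModel


class Phi3VMultimodalSketch(nn.Module):
    """Illustration only; the real model would subclass PreTrainedModel and the generation mixin."""

    def __init__(self, config):
        super().__init__()
        # e.g. a CLIP-like vision_config and a Phi3 text_config stored on the composite config
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModel.from_config(config.text_config)
```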

@@ -0,0 +1,59 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Member:

nit: 2025

Comment on lines +18 to +20



Member:

Let's add the abstract and finalize the docs before pinging a core maintainer.

@@ -574,6 +575,7 @@
("persimmon", "PersimmonForCausalLM"),
("phi", "PhiForCausalLM"),
("phi3", "Phi3ForCausalLM"),
("phi3_v", "Phi3VForCausalLM"),
Member:

It has to be in the ImageTextToText mapping since the model works with image+text; AutoCausalLM is currently reserved for the text-only modality.
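For context, once this PR lands with the model registered in that mapping and a converted checkpoint is available, the hedged expectation is that users load it through the matching auto class (the final class name may still change):

```python
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "microsoft/Phi-3.5-vision-instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForImageTextToText.from_pretrained(model_id)
```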


Comment on lines +524 to +538
    def _init_rope(self):
        if self.rope_scaling is None:
            self.rotary_emb = Phi3RotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            if scaling_type == "su":
                self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
            elif scaling_type == "yarn":
                self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
Member:

We init RoPE once per model in the base class; this is the old way and has to be removed.
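A hedged sketch of the current pattern in recent transformers versions: the rotary embedding is built once from the config at the model level and its cos/sin outputs are shared by every layer (shown with the real Phi3 classes; how this maps onto Phi3V is an assumption):

```python
import torch

from transformers import Phi3Config
from transformers.models.phi3.modeling_phi3 import Phi3RotaryEmbedding

config = Phi3Config(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4
)

# Created once per model (instead of each attention layer calling _init_rope):
rotary_emb = Phi3RotaryEmbedding(config=config)

hidden_states = torch.zeros(1, 6, config.hidden_size)
position_ids = torch.arange(6).unsqueeze(0)
cos, sin = rotary_emb(hidden_states, position_ids)  # passed down to all decoder layers
```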

Comment on lines +65 to +71
        images: ImageInput = None,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length=None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        add_special_tokens: bool = True,
    ) -> BatchFeature:
Member:

Let's add kwargs with the standard processing API. For example, in llava:

class LlavaNextProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {
            "do_pad": True,
        },
    }
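Following that example, a hedged sketch of what the equivalent class for this processor could look like (the class name and defaults are assumptions to be tuned to Phi-3.5V's needs):

```python
from transformers.processing_utils import ProcessingKwargs


class Phi3VProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {},
    }
```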

Comment on lines +132 to +150
    def calc_num_image_tokens(self, images: ImageInput):
        """Calculate the number of image tokens for each image.
        Args:
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
        """
        return self.image_processor.calc_num_image_tokens(images)

    def calc_num_image_tokens_from_image_size(self, width, height):
        """Calculate the number of image token for an image with given width and height.
        Args:
            width (`int`):
                Width of the image.
            height (`int`):
                Height of the image.
        """
        return self.image_processor.calc_num_image_tokens_from_image_size(width, height)

Member:

Can we move all of these from the image processor to the processor?

Comment on lines +212 to +213
        input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
        attention_mask = (input_ids > -1000000).to(torch.long)
Member:

Inputs are converted to tensors only if the user asks for it, so we have to put them in a BatchFeature and let it handle all the type conversion.
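A minimal sketch of that change, assuming the processor already has the id/mask/pixel lists in hand (the helper name and exact keys are made up for illustration):

```python
from transformers import BatchFeature


def pack_processor_outputs(input_ids, attention_mask, pixel_values, image_sizes, return_tensors=None):
    # No manual torch.tensor(...) calls: BatchFeature converts to the requested framework
    # (or leaves plain Python lists) based on `return_tensors`.
    data = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": pixel_values,
        "image_sizes": image_sizes,
    }
    return BatchFeature(data=data, tensor_type=return_tensors)
```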

Comment on lines +171 to +197

pattern = r"<\|image_\d+\|>"
prompt_chunks = [self.tokenizer(chunk).input_ids for chunk in re.split(pattern, texts)]

if "num_img_tokens" in images:
num_img_tokens = images["num_img_tokens"]
else:
assert "num_crops" in images, "num_crops must be provided in images if num_img_tokens is not provided"
num_crops = images["num_crops"]
num_img_tokens = [_num_crops * self.num_img_tokens for _num_crops in num_crops]

images, image_sizes = images["pixel_values"], images["image_sizes"]

# image_tags needs to start from 1 to n
image_tags = re.findall(pattern, texts)
# image_ids = [int(s.split("|")[1].split("_")[-1]) * -1 for s in image_tags]
# image_ids_pad = [[iid]*num_img_tokens[i] for i, iid in enumerate(image_ids)]
image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
unique_image_ids = sorted(set(image_ids))
# image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be [1, 4, 5]
# check the condition
assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
)
# total images must be the same as the number of image tags
assert len(unique_image_ids) == len(images), (
f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(images)} images"
Member:

This looks too complicated! From what I see, adding a special number to the image tokens is not needed, since what happens here is

text = "<|image|> What is this?"
text_after_separator = "<placeholder> <placeholder> [<placeholder> ...] What is this?"

This is the same as in many other VLMs, and the simplest case is LLaVA. Let's clean up a bit and make the necessary changes to the processor config when needed.
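A hedged sketch of that LLaVA-style expansion (the placeholder string and per-image token counts are assumptions; the real processor would take them from its config and image processor):

```python
import re


def expand_image_tokens(text: str, num_img_tokens: list[int], placeholder: str = "<|placeholder|>") -> str:
    """Replace each <|image_k|> tag with the right number of copies of a single placeholder token."""
    chunks = re.split(r"<\|image_\d+\|>", text)
    expanded = [chunks[0]]
    for count, chunk in zip(num_img_tokens, chunks[1:]):
        expanded.append(placeholder * count)
        expanded.append(chunk)
    return "".join(expanded)


# e.g. two images that map to 4 and 6 image tokens respectively
print(expand_image_tokens("<|image_1|> and <|image_2|> What is this?", [4, 6]))
```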

token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
Member:

pixel_values has to be in the inputs, so that we test the same way the model will be used by users.

Successfully merging this pull request may close these issues.

modeling_phi3 errors with AttributeError: 'DynamicCache' object has no attribute 'get_max_length'
4 participants