Add MLCD model #36182
base: main
Conversation
Hi Pavel @qubvel, hope this message finds you well. It appears that our PR failed on the final CI step. From the error messages, the issue seems to stem from the RT-DETR tests rather than from problems in our code. Could you please guide us on how to disable or skip the tests for other models so we can get CI to pass for our PR? Thanks!
Hi Pavel and Arthur, @qubvel @Rocketknight1: Could you please take a moment to review the pull request? Your insights would be immensely appreciated and would greatly contribute to ensuring the quality of the changes. We're truly grateful for your help! Thank you so much!
Gentle ping @qubvel, but let me know if you want me to take any part of the review!
Hey @tanhuajie! Sorry for the delay, and thanks a lot for working on the model addition to transformers. Great work, it already looks super clean!
I see the model is built on pretty standard modules, so it would be incredibly helpful if you could reuse library modules with inheritance using our new modular tool:
- https://huggingface.co/docs/transformers/modular_transformers
- see other models such as gemma/ijepa/siglip2 (modular_*.py file)
See other comments below!
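For illustration only, here is a minimal sketch of what a modular_mlcd.py could look like with this approach (the CLIP base classes and the "only override what differs" split are assumptions for the sketch, not a prescription; the modular converter then regenerates modeling_mlcd.py from this file):

# modular_mlcd.py (sketch)
from transformers.models.clip.modeling_clip import CLIPEncoderLayer, CLIPVisionModel

from .configuration_mlcd import MLCDVisionConfig


class MLCDEncoderLayer(CLIPEncoderLayer):
    # Inherit the CLIP layer as-is; only the components that actually differ
    # (e.g. the attention module with 2D RoPE) would be overridden here.
    pass


class MLCDVisionModel(CLIPVisionModel):
    config_class = MLCDVisionConfig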
@@ -0,0 +1,383 @@
# coding=utf-8
File to remove?
Yes, this file is redundant and has been removed.
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Applies Rotary Position Embedding to the given tensor for vision-related tasks.
    Args:
        tensor (torch.Tensor): The input tensor with shape (..., num_channels).
        freqs (torch.Tensor): The frequency matrix computed from rotary embeddings,
            typically obtained from MLCDRotaryEmbedding.
    Returns:
        torch.Tensor: The transformed tensor after applying rotary positional embeddings.
    """
    orig_dtype = tensor.dtype
    tensor = tensor.float()
    cos = freqs.cos()
    sin = freqs.sin()
    cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    output = (tensor * cos) + (rotate_half(tensor) * sin)
    output = output.to(orig_dtype)
    return output
let's reuse the same function as in other modeling files
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
Thanks for this suggestion. The file modeling_mlcd.py has been refactored and regenerated using the modular tool (nice tool!!!), and this function is now reused.
batch_size, seq_length, hidden_size = hidden_states.size()
# Each of shape: [batch_size, seq_length, num_heads, head_dim]
q = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
k = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
v = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
q = q.permute(0, 2, 1, 3).contiguous()
k = k.permute(0, 2, 1, 3).contiguous()
v = v.permute(0, 2, 1, 3).contiguous()
# q (batch_size, num_heads, seq_length, head_dim)
# k (batch_size, num_heads, seq_length, head_dim)
# v (batch_size, num_heads, seq_length, head_dim)
attn_output = F.scaled_dot_product_attention(q, k, v, None, dropout_p=0.0)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
attn_output = attn_output.view(seq_length, batch_size, -1)  # [seq_length, batch_size, embedding_dim]
attn_output = self.out_proj(attn_output)
attn_output = attn_output.permute(1, 0, 2).contiguous()  # [batch_size, seq_length, embedding_dim]
return attn_output, None
let's reuse library patterns as much as possible, e.g. see LlamaAttention
Yep, this pattern has also been reused now.
    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(h, 1, w, 1)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(h, 1, w, 1)
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.vision_rotary_embedding(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb
- Should we put it into the vision_rotary_embedding module?
- Can we avoid the loop? Is there any way to vectorize it? Otherwise we can use compile_compatible_method_lru_cache (see rt_detr) to cache the position_ids creation.
- You are right. I have put it into the MLCDRotaryEmbedding module.
- Completed. The process is now vectorized to avoid the loop (see the sketch below)!
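For reference, one possible vectorized version of the position-id construction (a sketch only, assuming a single grid with t == 1; the names mirror the snippet above):

# Build the (h*w, 2) grid of (row, col) indices without a Python loop.
hpos_ids, wpos_ids = torch.meshgrid(
    torch.arange(num_patches_height), torch.arange(num_patches_width), indexing="ij"
)
pos_ids = torch.stack([hpos_ids.flatten(), wpos_ids.flatten()], dim=-1)

# Index the precomputed frequency table exactly as before.
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)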
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        # output_attentions: Optional[bool] = None,
no commented code please, we should support this argument
Got it. This argument has been supported in the newest commit.
@@ -0,0 +1,520 @@
# coding=utf-8
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
Everywhere
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
-->
# Copyright 2025 Mistral and the HuggingFace Inc. team. All rights reserved.
Fixed.
@require_torch
class MLCDVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
class MLCDVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
-->
class MLCDVisionModelTest(ModelTesterMixin, unittest.TestCase):
Thanks. Fixed now.
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-->
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
Fixed.
Please add Integration tests as well
Got it. Integration tests for MLCD have been added in test_modeling_mlcd.py.
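For context, a minimal sketch of what such an integration test might check (the checkpoint name and the dummy input below are placeholders, not taken from this PR):

import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, MLCDVisionModel

checkpoint = "DeepGlint-AI/mlcd-vit-bigG-patch14-448"  # placeholder hub id

processor = AutoImageProcessor.from_pretrained(checkpoint)
model = MLCDVisionModel.from_pretrained(checkpoint)

# A dummy all-black image; a real test would use a fixed reference image and expected values.
image = Image.fromarray(np.zeros((448, 448, 3), dtype=np.uint8))
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# Shape-level sanity check only in this sketch.
assert outputs.last_hidden_state.shape[0] == 1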
from .original_vit_rope2d import RoPE2d_ViT_bigG_14_1024


def copy_attn_layer(hf_attn_layer, pt_attn_layer):
    # self.in_proj = nn.Linear(dim, dim * 3, bias=True)
    # self.out_proj = nn.Linear(dim, dim)

    q_proj, k_proj, v_proj = pt_attn_layer.in_proj.weight.data.chunk(3, dim=0)
    q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj.bias.data.chunk(3, dim=0)

    out_proj_weights = pt_attn_layer.out_proj.weight
    out_proj_bias = pt_attn_layer.out_proj.bias

    hf_attn_layer.q_proj.weight.data = q_proj
Please see other conversion scripts to structure your code (I would recommend mllama, because it follows our new standards for conversion).
The general idea is that we do not use the 3rd-party model to convert weights; instead, we should:
- load the state dict
- rename/split/transpose weights
- load the new state dict into the HF model to check that everything matches (see the sketch after this list)
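A rough skeleton of that flow, purely as a sketch (the key-rename rule and the default config below are placeholders, not the actual mapping used in convert_mlcd_weights_to_hf.py):

import torch

from transformers import MLCDVisionConfig, MLCDVisionModel


def convert_checkpoint(original_checkpoint_path: str, output_dir: str):
    # 1. Load the original state dict directly, without instantiating the 3rd-party model.
    original_state_dict = torch.load(original_checkpoint_path, map_location="cpu")

    # 2. Rename / split / transpose weights into the HF naming scheme (placeholder rule).
    new_state_dict = {}
    for old_key, value in original_state_dict.items():
        new_key = old_key.replace("visual.", "vision_model.")
        new_state_dict[new_key] = value

    # 3. Load into the HF model with strict=True so any mismatch is caught immediately.
    model = MLCDVisionModel(MLCDVisionConfig())
    model.load_state_dict(new_state_dict, strict=True)
    model.save_pretrained(output_dir)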
Thanks for the reminder. convert_mlcd_weights_to_hf.py has been refactored with reference to both mllama and siglip2.
Hi @qubvel! Thanks for your suggestions. I've made the necessary modifications based on your advice and think it's ready for your review again. Here's a short summary of what we've done so far:
Looking forward to your feedback on these updates. Many thanks!
Hi @tanhuajie, thanks for addressing comments and using modular! I left a few more comments on how to use inheritance a bit more. Other than that looks pretty clean, great work 🤗
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.dropout if self.training else 0.0,
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
Please see the Llama model for the new attention interface. In short, we should have one module and be able to switch between different attention implementations (sdpa/eager/fa2).
OK. Attention interface has been refactored in the latest commit.
Hey @tanhuajie, it looks like the attention interface was not refactored; we should use:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once(
            "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
            'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,
)
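(For context: eager_attention_forward here is the plain matmul-plus-softmax implementation defined alongside the attention module, and ALL_ATTENTION_FUNCTIONS maps the configured attn_implementation string to backends such as sdpa or flash_attention_2.)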
class MLCDEncoderLayer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = MLCDSdpaAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLCDMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
Init is similar to CLIP/Siglip encoder layer, right? Please re-use in modular.
class MLCDEncoderLayer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = MLCDSdpaAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLCDMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
-->
class MLCDEncoderLayer(CLIPEncoderLayer):
Got it. But I noticed this modification would cause the config parameter in modeling_mlcd.py to be automatically generated as MLCDConfig by the modular tool, rather than the desired MLCDVisionConfig. Therefore, I ultimately made the following change instead:
class MLCDEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)

output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Change the order:
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-->
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
Completed!
class MLCDVisionModel(MLCDPreTrainedModel):
    config_class = MLCDVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["MLCDEncoderLayer"]

    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.vision_model = MLCDVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding
class MLCDVisionModel(MLCDPreTrainedModel):
    config_class = MLCDVisionConfig
    main_input_name = "pixel_values"
    _no_split_modules = ["MLCDEncoderLayer"]

    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.vision_model = MLCDVisionTransformer(config)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.vision_model.embeddings.patch_embedding
-->
class MLCDVisionModel(CLIPVisionModel):
Completed.
class MLCDVisionTransformer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = MLCDVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = MLCDEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
        self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
class MLCDVisionTransformer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = MLCDVisionEmbeddings(config)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = MLCDEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
        self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
-->
class MLCDVisionTransformer(CLIPVisionTransformer):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
        self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
Completed.
class MLCDEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`MLCDEncoderLayer`].
    Args:
        config: MLCDVisionConfig
    """

    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
class MLCDEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`MLCDEncoderLayer`].
    Args:
        config: MLCDVisionConfig
    """

    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False
-->
class MLCDEncoder(CLIPEncoder):
class MLCDVisionEmbeddings(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
class MLCDVisionEmbeddings(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
-->
class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
        del self.position_embedding
Completed.
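(Side note, as I understand the modular tool: the del self.position_embedding in the subclass __init__ tells the converter to drop that attribute from the generated MLCD code, since MLCD uses rotary position embeddings instead of the learned position embedding inherited from CLIP.)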
Hi @qubvel! Thank you for your further suggestions. Based on your latest comments, I have made the additional modifications.
Hi @qubvel, just a gentle reminder to review the latest commits. Your feedback is incredibly valuable, and I'd appreciate it if you could take a look when you have a moment. Thanks! 🤗
Thanks for addressing the comments! I think we are close to getting it done; I added a few more comments, please have a look!
attn_output = F.scaled_dot_product_attention(
    query_states,
    key_states,
    value_states,
    attn_mask=attention_mask,
    dropout_p=self.dropout if self.training else 0.0,
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous()  # [seq_length, batch_size, num_heads, head_dim]
Hey @tanhuajie, it looks like the attention interface was not refactored; we should use:
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once(
            "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
            'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,
)
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)
You are right. I misunderstood it earlier. The latest commit has refactored the attention interface based on your suggestions!
    def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""

        t, h, w = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
Do we need t?
As far as I see it's always 1, isn't it?
Yes. I have adjusted them.
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
Would it be better to make it as follows, considering t is always 1?
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
-->
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)
    def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""
    def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""
-->
    def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
        """Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""
    test_pruning = False
    test_head_masking = False
    test_torchscript = False
    test_resize_embeddings = False
    test_resize_embeddings = False
-->
    test_resize_embeddings = False
    test_torch_exportable = True
MLCD is a vision-only model and doesn't have a vocab_size that needs to be adjusted, so I think it's better to keep test_resize_embeddings = False.
Hey @qubvel, thanks for your comments! I have refactored the attention interface and adjusted the code based on your suggestions. I think it's ready for your review again. Please take a look when you have time! Thanks 🤗
Hi @qubvel, thanks again for your valuable feedback earlier! We've refactored the attention interface and made adjustments to the code based on your suggestions. It should be ready for your review now. Whenever you have some time, we'd really appreciate it if you could take a look. Thank you so much!
Hi @qubvel, hope you're doing well! Just a quick follow-up on the updated interface based on your feedback. We'd love to hear your thoughts when you have time. Thanks again!
What does this PR do?
This PR adds the MLCD model from the DeepGlint-AI team.
Fixes #36181
Who can review?
@amyeroberts @qubvel @ArthurZucker
Quick Test