Add MLCD model #36182

Open · wants to merge 42 commits into main

Conversation

@tanhuajie commented Feb 13, 2025

What does this PR do?

This PR adds the MLCD model from the DeepGlint-AI team.

Fixes #36181

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts @qubvel @ArthurZucker

Quick Test

from transformers import AutoProcessor, MLCDVisionModel
from PIL import Image
import requests

# Load model and processor
model = MLCDVisionModel.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-336")
processor = AutoProcessor.from_pretrained("DeepGlint-AI/mlcd-vit-bigG-patch14-336")

# Process single image
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

# Get visual features
outputs = model(**inputs)
features = outputs.last_hidden_state

print(f"Extracted features shape: {features.shape}")

tanhuajie and others added 15 commits February 14, 2025 02:17
@tanhuajie (Author)

[screenshot: CI failure log]

Hi Pavel @qubvel, hope this message finds you well. Our PR failed at the final CI step, and from the error messages the failure appears to stem from the RT-DETR tests rather than from our code. Could you please guide us on how to disable or skip the tests for other models so our PR can pass CI? Thanks!

@anxiangsir commented Feb 14, 2025

Hi Pavel and Arthur, @qubvel @Rocketknight1:

Could you please take a moment to review the pull request? Your insights would be immensely appreciated and would greatly contribute to ensuring the quality of the changes. We're truly grateful for your help!

Thank you so much!

@Rocketknight1 (Member)

Gentle ping @qubvel, but let me know if you want me to take any part of the review!

@qubvel self-requested a review February 26, 2025 11:04
@qubvel (Member) left a comment

Hey @tanhuajie! Sorry for the delay, and thanks a lot for working on the model addition to transformers. Great work, it already looks super clean!

I see the model is built on pretty standard modules, so it would be incredibly helpful if you could reuse library modules with inheritance using our new modular tool.

See other comments below!

@@ -0,0 +1,383 @@
# coding=utf-8
Member:

File to remove?

Author:

Yes, this file is redundant and has been removed.

Comment on lines 72 to 90
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
"""Applies Rotary Position Embedding to the given tensor for vision-related tasks.
Args:
tensor (torch.Tensor): The input tensor with shape (..., num_channels).
freqs (torch.Tensor): The frequency matrix computed from rotary embeddings,
typically obtained from MLCDRotaryEmbedding.
Returns:
torch.Tensor: The transformed tensor after applying rotary positional embeddings.
"""
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
Member:

let's reuse the same function as in other modeling files

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

Author:

Thanks for this suggestion. The file modeling_mlcd.py has been refactored and regenerated using the modular tool (nice tool!!!), and this function is now reused.
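
For readers unfamiliar with the modular tool: a modular_mlcd.py file declares classes by inheriting from existing models, and modeling_mlcd.py is auto-generated from it. A minimal illustrative sketch (not the actual PR file; reusing CLIP's MLP for MLCDMLP is an assumption for illustration):

from transformers.models.clip.modeling_clip import CLIPMLP

# Inheriting with an empty body tells the modular converter to copy CLIPMLP
# into the generated modeling_mlcd.py under the MLCD name.
class MLCDMLP(CLIPMLP):
    pass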

Comment on lines 174 to 192
batch_size, seq_length, hidden_size = hidden_states.size()
# Each of shape: [batch_size, seq_length, num_heads, head_dim]
q = self.q_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
k = self.k_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
v = self.v_proj(hidden_states).reshape((batch_size, seq_length, self.num_heads, self.head_dim))
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
q = q.permute(0, 2, 1, 3).contiguous()
k = k.permute(0, 2, 1, 3).contiguous()
v = v.permute(0, 2, 1, 3).contiguous()
# q (batch_size, num_heads, seq_length, head_dim)
# k (batch_size, num_heads, seq_length, head_dim)
# v (batch_size, num_heads, seq_length, head_dim)
attn_output = F.scaled_dot_product_attention(q, k, v, None, dropout_p=0.0)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
attn_output = attn_output.view(seq_length, batch_size, -1) # [seq_length, batch_size, embedding_dim]
attn_output = self.out_proj(attn_output)
attn_output = attn_output.permute(1, 0, 2).contiguous() # [batch_size, seq_length, embedding_dim]
return attn_output, None
Member:

let's reuse library patterns as much as possible, e.g. see LlamaAttention

Author:

Yep, this pattern has also been reused now.

Comment on lines 330 to 347
def rot_pos_emb(self, grid_thw):
pos_ids = []
for t, h, w in grid_thw:
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(h, 1, w, 1)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()

wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(h, 1, w, 1)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
rotary_pos_emb_full = self.vision_rotary_embedding(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
Member:

  1. Should we put it into the vision_rotary_embedding module?
  2. Can we avoid the loop? Is there any way to vectorize it? Otherwise, we can use compile_compatible_method_lru_cache (see rt_detr) to cache the position_ids creation.

Author:

  1. You are right. I have moved it into the MLCDRotaryEmbedding module.
  2. Completed. The process is now vectorized to avoid the loop (see the sketch below).
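
For illustration, a minimal sketch (not the exact PR code) of building the 2D position ids for a single h × w patch grid (t = 1) without a Python loop:

import torch

def build_2d_pos_ids(num_patches_height: int, num_patches_width: int) -> torch.Tensor:
    # meshgrid yields the row/column index of every patch in one shot
    hpos, wpos = torch.meshgrid(
        torch.arange(num_patches_height),
        torch.arange(num_patches_width),
        indexing="ij",
    )
    # Shape (h * w, 2): one (row, col) pair per patch, in raster order
    return torch.stack([hpos.flatten(), wpos.flatten()], dim=-1)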

def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
# output_attentions: Optional[bool] = None,
Member:

No commented-out code please; we should support this argument.

Author:

Got it. This argument is supported as of the latest commit.

@@ -0,0 +1,520 @@
# coding=utf-8
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
Member:

Everywhere

Suggested change
# Copyright 2024 Mistral and the HuggingFace Inc. team. All rights reserved.
# Copyright 2025 Mistral and the HuggingFace Inc. team. All rights reserved.

Author:

Fixed.



@require_torch
class MLCDVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
Member:

Suggested change
class MLCDVisionModelModelTest(ModelTesterMixin, unittest.TestCase):
class MLCDVisionModelTest(ModelTesterMixin, unittest.TestCase):

Author:

Thanks. Fixed now.

@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
Member:

Suggested change
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.

Author:

Fixed.

Member:

Please add Integration tests as well

Author:

Got it. Integration tests for MLCD have been added in test_modeling_mlcd.py.
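
A minimal sketch of the kind of integration test meant here (class name and checked values are illustrative, not the actual test file):

import unittest

import requests
import torch
from PIL import Image

from transformers import AutoProcessor, MLCDVisionModel
from transformers.testing_utils import require_torch, require_vision, slow


@require_torch
@require_vision
class MLCDVisionModelIntegrationTest(unittest.TestCase):
    @slow
    def test_inference(self):
        checkpoint = "DeepGlint-AI/mlcd-vit-bigG-patch14-336"
        model = MLCDVisionModel.from_pretrained(checkpoint)
        processor = AutoProcessor.from_pretrained(checkpoint)

        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        image = Image.open(requests.get(url, stream=True).raw)
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)

        # One image in, one feature map out
        self.assertEqual(outputs.last_hidden_state.shape[0], 1)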

Comment on lines 39 to 52
from .original_vit_rope2d import RoPE2d_ViT_bigG_14_1024


def copy_attn_layer(hf_attn_layer, pt_attn_layer):
# self.in_proj = nn.Linear(dim, dim * 3, bias=True)
# self.out_proj = nn.Linear(dim, dim)

q_proj, k_proj, v_proj = pt_attn_layer.in_proj.weight.data.chunk(3, dim=0)
q_proj_bias, k_proj_bias, v_proj_bias = pt_attn_layer.in_proj.bias.data.chunk(3, dim=0)

out_proj_weights = pt_attn_layer.out_proj.weight
out_proj_bias = pt_attn_layer.out_proj.bias

hf_attn_layer.q_proj.weight.data = q_proj
Member:

Please see other conversion scripts to structure your code (I would recommend mllama, because it follows our new standards for conversion).

The general idea is that we do not use the 3rd-party model to convert weights; instead, we should (see the sketch below):

  1. load the state dict
  2. rename/split/transpose weights
  3. load the new state dict into the HF model to check that everything matches
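
A minimal sketch of that load/rename/verify pattern (key names and the mapping table here are hypothetical; the real ones live in the conversion script):

import torch

# Hypothetical old-key -> HF-key mapping; the real script defines the full table
ORIGINAL_TO_HF = {"conv1.weight": "vision_model.embeddings.patch_embedding.weight"}

def convert_state_dict(path: str) -> dict:
    state_dict = torch.load(path, map_location="cpu")
    new_state_dict = {}
    for old_key, value in state_dict.items():
        if old_key.endswith("in_proj.weight"):
            # Split the fused qkv projection into separate q/k/v weights
            prefix = old_key[: -len("in_proj.weight")]
            q, k, v = value.chunk(3, dim=0)
            new_state_dict[prefix + "q_proj.weight"] = q
            new_state_dict[prefix + "k_proj.weight"] = k
            new_state_dict[prefix + "v_proj.weight"] = v
        else:
            new_state_dict[ORIGINAL_TO_HF.get(old_key, old_key)] = value
    return new_state_dict

# Loading with strict=True makes any missing/unexpected key fail loudly:
# model.load_state_dict(convert_state_dict("mlcd_original.bin"), strict=True)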

Author:

Thanks for the reminder. convert_mlcd_weights_to_hf.py has been refactored with reference to both mllama and siglip2.

tanhuajie and others added 7 commits March 13, 2025 03:14
@tanhuajie (Author)

Hi @qubvel! Thanks for your suggestions. I've made the necessary modifications based on your advice and think it's ready for your review again. Here's a short summary of what we've done so far:

  1. The modeling_mlcd.py file has been refactored and regenerated using the new modular tool, and we've reused as many existing modules and functions from other modeling files as possible.
  2. The conversion script convert_mlcd_weights_to_hf.py has been refactored, following your recommendation to refer to both mllama and siglip2.
  3. The output_attentions argument in MLCDVisionModel is now supported in the latest commit, and some loops have been vectorized to improve GPU parallelism.
  4. Additional integration tests for MLCD have been added to test_modeling_mlcd.py.
  5. Fixed some typos.

Looking forward to your feedback on these updates. Much thanks!!!

@tanhuajie requested a review from @qubvel March 12, 2025 23:16
@qubvel (Member) left a comment

Hi @tanhuajie, thanks for addressing the comments and using modular! I left a few more comments on how to use inheritance a bit more. Other than that it looks pretty clean, great work 🤗

Comment on lines 356 to 363
attn_output = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous() # [seq_length, batch_size, num_heads, head_dim]
Member:

Please see the Llama model for the new attention interface. In short, we should have one module and be able to switch between different attention implementations (sdpa/eager/fa2).

Author:

OK. Attention interface has been refactored in the latest commit.

Member:

Hey @tanhuajie, it looks like the attention interface was not refactored; we should use:

attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
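
For context, eager_attention_forward is the plain softmax-attention fallback defined alongside this pattern (roughly as in modeling_llama.py; mask handling simplified here):

import torch

def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain attention: scaled scores -> (masked) softmax -> dropout -> weighted sum of values
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights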

Comment on lines 370 to 379
class MLCDEncoderLayer(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = MLCDSdpaAttention(config)
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = MLCDMLP(config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

def forward(
Member:

Init is similar to CLIP/Siglip encoder layer, right? Please re-use in modular.

class MLCDEncoderLayer(nn.Module):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = MLCDSdpaAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = MLCDMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

-->

class MLCDEncoderLayer(CLIPEncoderLayer):

Author:

Got it. But I notice this modification would cause the config parameter in modeling_mlcd.py to be automatically generated as MLCDConfig by the modular tool, rather than the desired MLCDVisionConfig. Therefore, I ultimately made the following changes instead:

class MLCDEncoderLayer(CLIPEncoderLayer):
    def __init__(self, config: MLCDVisionConfig):
        super().__init__(config)

Comment on lines 566 to 574
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
Member:

Change the order:

Suggested change
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
rotary_pos_emb = torch.cat([self.class_pos_emb, rotary_pos_emb], dim=0)

Author:

Completed!

Comment on lines 666 to 678
class MLCDVisionModel(MLCDPreTrainedModel):
config_class = MLCDVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["MLCDEncoderLayer"]

def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.vision_model = MLCDVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
Member:

Suggested change
class MLCDVisionModel(MLCDPreTrainedModel):
config_class = MLCDVisionConfig
main_input_name = "pixel_values"
_no_split_modules = ["MLCDEncoderLayer"]
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.vision_model = MLCDVisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
class MLCDVisionModel(CLIPVisionModel):

Author:

Completed.

Comment on lines 544 to 556
class MLCDVisionTransformer(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size

self.embeddings = MLCDVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = MLCDEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
Member:

Suggested change
class MLCDVisionTransformer(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = MLCDVisionEmbeddings(config)
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.encoder = MLCDEncoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))
class MLCDVisionTransformer(CLIPVisionTransformer):
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
self.vision_rotary_embedding = MLCDRotaryEmbedding(config.hidden_size // config.num_attention_heads // 2)
self.class_pos_emb = nn.Parameter(torch.randn(1, config.hidden_size // config.num_attention_heads // 2))

Author:

Completed.

Comment on lines 429 to 441
class MLCDEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MLCDEncoderLayer`].
Args:
config: MLCDVisionConfig
"""

def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
Member:

Suggested change
class MLCDEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MLCDEncoderLayer`].
Args:
config: MLCDVisionConfig
"""
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.layers = nn.ModuleList([MLCDEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
class MLCDEncoder(CLIPEncoder):

Comment on lines 168 to 187
class MLCDVisionEmbeddings(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size

self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
Member:

Suggested change
class MLCDVisionEmbeddings(nn.Module):
def __init__(self, config: MLCDVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
class MLCDVisionEmbeddings(CLIPVisionEmbeddings):
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
del self.position_embedding

Author:

Completed.

tanhuajie and others added 2 commits March 14, 2025 08:03
@tanhuajie (Author) commented Mar 14, 2025

Hi @qubvel! Thank you for your further suggestions. Based on your latest comments, I have made additional modifications to modular_mlcd.py. These changes include improving the reusability of the modular components through class inheritance and refactoring the attention interface to support different attention implementations. 🤗

@tanhuajie requested a review from @qubvel March 14, 2025 00:49
@tanhuajie (Author)

Hi @qubvel, just a gentle reminder to review the latest commits. Your feedback is incredibly valuable, and I’d appreciate it if you could take a look when you have a moment. Thanks! 🤗

@qubvel (Member) left a comment

Thanks for addressing the comments! I think we are close to getting it done. I've added a few more comments; please have a look!

Comment on lines 501 to 502
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)
Member:

Suggested change
def __init__(self, config: MLCDVisionConfig):
super().__init__(config)

Author:

You are right. I misunderstood it earlier. The latest commit has refactored the attention interface based on your suggestions!

def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
"""Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""

t, h, w = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
Member:

Do we need t?

Member:

As far as I see it's always 1, isn't it?

Author:

Yes. I have adjusted them.

Comment on lines 628 to 629
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
Member:

Would it be better to make it as follows, considering t is always 1?

Suggested change
twh = (1, pixel_values.size(3) // self.config.patch_size, pixel_values.size(2) // self.config.patch_size)
rotary_pos_emb = self.vision_rotary_embedding(torch.tensor([twh], device=pixel_values.device))
num_patches_height = pixel_values.shape[-2] // self.config.patch_size
num_patches_width = pixel_values.shape[-1] // self.config.patch_size
rotary_pos_emb = self.vision_rotary_embedding(num_patches_height, num_patches_width)

Comment on lines 152 to 153
def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
"""Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""
Member:

Suggested change
def forward(self, grid_thw: torch.Tensor) -> torch.Tensor:
"""Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""
def forward(self, num_patches_height: int, num_patches_width: int) -> torch.Tensor:
"""Calculate sequence length from grid, and then get the RoPE for MLCDVisionModel"""

test_pruning = False
test_head_masking = False
test_torchscript = False
test_resize_embeddings = False
Member:

Suggested change
test_resize_embeddings = False
test_resize_embeddings = False
test_torch_exportable = True

Author:

MLCD is a vision-only model and doesn't have a vocab_size that needs to be adjusted, so I think it's better to keep test_resize_embeddings = False.

tanhuajie and others added 8 commits March 20, 2025 02:40
@tanhuajie (Author)

Hey @qubvel, thanks for your comments! I have refactored the attention interface and adjusted the code based on your suggestions. I think it's ready for your review again. Please take a look when you have time! Thanks 🤗

@tanhuajie requested a review from @qubvel March 19, 2025 20:12
@anxiangsir

Hi @qubvel,

Thanks again for your valuable feedback earlier! We’ve refactored the attention interface and made adjustments to the code based on your suggestions. It should be ready for your review now. Whenever you have some time, we’d really appreciate it if you could take a look. Thank you so much!

@anxiangsir

Hi @qubvel, hope you're doing well! Just a quick follow-up on the updated interface based on your feedback; we'd love to hear your thoughts when you have time.

Thanks again!
