diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py
index c39e7b65d..b5c64c460 100644
--- a/QEfficient/transformers/models/internvl/modeling_internvl.py
+++ b/QEfficient/transformers/models/internvl/modeling_internvl.py
@@ -29,6 +29,7 @@ def __init__(self, model):
         super().__init__()
         self.model = model
         self.config = self.model.language_model.config
+        self.language_model = self.model.language_model
 
     def forward(self, input_ids, vit_embeds, position_ids, past_key_values):
         # TODO: Check if Hardcoding this is okay, i.e. check if this value is common for all intern models
diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py
index 93d6f4c3b..4ce9f087e 100644
--- a/QEfficient/transformers/models/llava/modeling_llava.py
+++ b/QEfficient/transformers/models/llava/modeling_llava.py
@@ -25,6 +25,7 @@ class QEFFLlavaEncoderWrapper(nn.Module):
     def __init__(self, model):
         super().__init__()
         self.model = model
+        self.model.vision_model = self.model.vision_tower
 
     def forward(self, pixel_values):
         # Image features
@@ -47,6 +48,7 @@ def __init__(self, model):
         super().__init__()
         self.model = model
         self.config = self.model.config
+        self.language_model = self.model.language_model
 
     def forward(self, input_ids, image_features, position_ids, past_key_values):
         inputs_embeds = self.model.get_input_embeddings()(input_ids)