diff --git a/torchchat/model.py b/torchchat/model.py
index b052d112b..a576d5036 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -12,7 +12,10 @@ from enum import Enum
 from pathlib import Path
 
+import torchvision
+
 from typing import Any, Callable, Dict, Optional, Union
+from collections.abc import Hashable
 
 import torch
 import torch.nn as nn
 
@@ -31,22 +34,136 @@ from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
 from torchtune.models.llama3_1._component_builders import llama3_1 as llama3_1_builder
 from torchtune.modules.model_fusion import DeepFusionModel
+from torchtune.models.clip import clip_vision_encoder
 
 from torchchat.utils.build_utils import find_multiple, get_precision
 
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
 
+class QuickGELUActivation(nn.Module):
+    """
+    Applies a GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
+    """
+
+    def forward(self, input):
+        return input * torch.sigmoid(1.702 * input)
+
+
 def identity(**kwargs):
     if len(kwargs) != 1:
         raise ValueError("Only one argument is expected")
     return list(kwargs.values())[0]
 
+
+class MultiModalProjector(nn.Module):
+    def __init__(self, in_channels: int, out_channels: int, act: nn.Module):
+        super().__init__()
+
+        self.linear_1 = nn.Linear(in_channels, out_channels, bias=True)
+        self.act = act
+        self.linear_2 = nn.Linear(out_channels, out_channels, bias=True)
+
+    def forward(self, image_features):
+        hidden_states = self.linear_1(image_features)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
+
+
+class ConcateFusion(nn.Module):
+    def __init__(
+        self,
+        encoder: nn.Module,
+        decoder: nn.Module,
+        token_embedding_name="tok_embeddings",
+        mm_proj_in_channels=1024,
+        mm_proj_out_channels=4096,
+        mm_proj_activation=nn.GELU(),
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+
+        # Hoist the embedding layer out of the decoder: the llava model needs to fuse
+        # the text and image embeddings together before passing them to the decoder.
+        self.tok_embeddings = getattr(self.decoder, token_embedding_name)
+
+        # Set the decoder's own embedding layer to None so the decoder skips it;
+        # this module embeds the tokens itself before fusing them with image features.
+        setattr(self.decoder, token_embedding_name, None)
+
+        self.mm_projector = MultiModalProjector(
+            in_channels=mm_proj_in_channels,
+            out_channels=mm_proj_out_channels,
+            act=mm_proj_activation,
+        )
+
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        post_tokens: Optional[Tensor] = None,
+        encoder_input: Optional[Tensor] = None,
+        encoder_mask: Optional[torch.Tensor] = None,
+        input_pos: Optional[torch.Tensor] = None,
+    ) -> Tensor:
+        if encoder_input is not None:
+            encoder_input = encoder_input.view(1, 1, *encoder_input.shape)
+            encoder_output = self.encoder(encoder_input)
+            encoder_output = self._encoder_feature_select(encoder_output)
+        else:
+            encoder_output = None
+
+        decoder_input = self._get_decoder_input(
+            tokens, encoder_output=encoder_output, post_tokens=post_tokens
+        )
+
+        if input_pos is None:
+            input_pos = torch.arange(
+                decoder_input.shape[1],
+                device=decoder_input.device,
+                dtype=torch.int,
+            )
+
+        return self.decoder(decoder_input, input_pos=input_pos)
+
+    def setup_caches(self, batch_size, max_seq_len) -> None:
+        self.decoder.setup_caches(batch_size, max_seq_len)
+
+    def _encoder_feature_select(self, encoder_output) -> Tensor:
+        selected_image_feature = encoder_output[1][0].view(
+            *encoder_output[1][0].shape[2:]
+        )
+
+        # Drop the CLS token; keep only the patch features.
+        selected_image_feature = selected_image_feature[:, 1:]
+        return selected_image_feature
+
+    def _get_decoder_input(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_output: Optional[Tensor],
+        post_tokens: Optional[Tensor],
+    ) -> Tensor:
+        if encoder_output is None:
+            assert post_tokens is None
+            return self.tok_embeddings(tokens)
+        else:
+            pre_img_embed = self.tok_embeddings(tokens)
+            image_embeds = self.mm_projector(encoder_output)
+            if post_tokens is None:
+                return torch.cat((pre_img_embed, image_embeds), dim=1)
+
+            post_img_embed = self.tok_embeddings(post_tokens)
+            return torch.cat((pre_img_embed, image_embeds, post_img_embed), dim=1)
+
+
 class ModelType(Enum):
     TextOnly = "text_only"
     Llama3_1 = "llama3_1"
     Flamingo = "flamingo"
+    Llava = "llava"
 
 
 # Type for objects that can generate nn.Module instance
@@ -100,16 +217,30 @@ def _flamingo(cls):
             fusion_class=DeepFusionModel,
         )
 
+    @classmethod
+    def _llava(cls):
+        return cls(
+            model_type=ModelType.Llava,
+            modules={
+                "encoder": clip_vision_encoder,
+                "decoder": Transformer,
+            },
+            fusion_class=ConcateFusion,
+        )
+
     @classmethod
     def get_recipe(cls, model_type):
-        if model_type == ModelType.TextOnly:
-            return cls._text_only()
-        elif model_type == ModelType.Flamingo:
-            return cls._flamingo()
-        elif model_type == ModelType.Llama3_1:
-            return cls._llama3_1()
-        else:
-            raise ValueError(f"Can not find the model recipe for {model_type}")
+        match model_type:
+            case ModelType.TextOnly:
+                return cls._text_only()
+            case ModelType.Flamingo:
+                return cls._flamingo()
+            case ModelType.Llama3_1:
+                return cls._llama3_1()
+            case ModelType.Llava:
+                return cls._llava()
+            case _:
+                raise ValueError(f"Cannot find the model recipe for {model_type}")
 
 
 @dataclass
@@ -329,7 +460,14 @@ def build_model(self) -> nn.Module:
             modules[name] = module_class(**config_args)
 
         return recipe.fusion_class(**modules)
-    
+
+    def _replace_known_params(self, params):
+        # Map string placeholders in the JSON config (e.g. "QuickGELUActivation()")
+        # to the module instances they name.
+        patterns = {"QuickGELUActivation()": QuickGELUActivation()}
+        for key, value in params.items():
+            if isinstance(value, Hashable) and value in patterns:
+                params[key] = patterns[value]
+        return params
+
     @abstractmethod
     def forward(self, *args, **kwargs):
         raise NotImplementedError("forward method is not implemented")
 
@@ -414,11 +552,26 @@ def reset_caches(self):
         self.model.reset_caches()
 
 
+class LlavaModel(Model):
+    def forward(
+        self,
+        tokens: Tensor,
+        *,
+        encoder_input: Optional[Dict[str, Tensor]] = None,
+        post_tokens: Optional[Tensor] = None,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        return self.model(
+            tokens,
+            encoder_input=encoder_input,
+            post_tokens=post_tokens,
+            input_pos=input_pos,
+        )
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        self.model.setup_caches(max_batch_size, max_seq_length)
+
+
 MODEL_TYPE_TO_CLASS = {
     ModelType.TextOnly: TextOnlyModel,
     ModelType.Flamingo: FlamingoModel,
     ModelType.Llama3_1: Llama31Model,
+    ModelType.Llava: LlavaModel,
 }
diff --git a/torchchat/model_params/llava-1.5.json b/torchchat/model_params/llava-1.5.json
new file mode 100644
index 000000000..992cc2c69
--- /dev/null
+++ b/torchchat/model_params/llava-1.5.json
@@ -0,0 +1,25 @@
+{
+    "model_type": "llava",
+    "use_tiktoken": true,
+    "encoder": {
+        "tile_size": 336,
+        "patch_size": 14,
+        "embed_dim": 1024,
+        "num_layers": 24,
+        "num_heads": 16,
+        "out_indices": [
+            23
+        ],
+        "output_cls_projection": false,
+        "max_num_tiles": 1,
+        "in_channels": 3,
+        "intermediate_act": "QuickGELUActivation()"
+    },
+    "decoder": {
+        "n_layers": 32,
+        "n_heads": 32,
+        "dim": 4096,
+        "vocab_size": 32064,
+        "max_seq_length": 768
+    }
+}
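
As a reading aid (not part of the patch), here is a minimal sketch of the fusion that ConcateFusion._get_decoder_input performs before the decoder runs. The shapes and the inline nn.Sequential projector are illustrative stand-ins, assuming the llava-1.5.json config above (embed_dim=1024 for the CLIP encoder, dim=4096 for the decoder, and (336 / 14) ** 2 = 576 patch features once _encoder_feature_select drops the CLS token):

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
vocab_size, clip_dim, text_dim, n_patches = 32064, 1024, 4096, 576
n_pre, n_post = 5, 7  # prompt tokens before and after the image slot

tok_embeddings = nn.Embedding(vocab_size, text_dim)   # stand-in for the decoder's tok_embeddings
mm_projector = nn.Sequential(                         # same shape as MultiModalProjector above
    nn.Linear(clip_dim, text_dim), nn.GELU(), nn.Linear(text_dim, text_dim)
)

tokens = torch.randint(0, vocab_size, (1, n_pre))         # text before the image
post_tokens = torch.randint(0, vocab_size, (1, n_post))   # text after the image
image_features = torch.randn(1, n_patches, clip_dim)      # selected CLIP patch features

# The concatenation performed by ConcateFusion._get_decoder_input:
decoder_input = torch.cat(
    (tok_embeddings(tokens), mm_projector(image_features), tok_embeddings(post_tokens)),
    dim=1,
)
print(decoder_input.shape)  # torch.Size([1, 588, 4096]) -> 5 + 576 + 7 positions

The decoder then consumes these pre-fused embedding positions directly, which is why ConcateFusion.__init__ detaches the decoder's own tok_embeddings and holds it itself.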