diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 260a32f1d..424e9fc4c 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,22 +1,160 @@
 ## Contributing to PROJECT
 Hi there!
-We’re thrilled that you’d like to contribute to this project.
+We're thrilled that you'd like to contribute to this project.
 Your help is essential for keeping this project great and for making it better.
-## Branching Strategy
-In general, contributors should develop on branches based off of `main` and pull requests should be made against `main`.
+## Submitting Your Contribution
-## Submitting a pull request
+
+Follow these steps to submit your contribution to the QEfficient repository:
 1. Please read our [code of conduct](CODE-OF-CONDUCT.md) and [license](LICENSE).
-1. Fork and clone the repository.
-1. Create a new branch based on `main`: `git checkout -b main`.
-1. Make your changes, add tests, and make sure the tests still pass.
-1. Commit your changes using the [DCO](http://developercertificate.org/). You can attest to the DCO by commiting with the **-s** or **--signoff** options or manually adding the "Signed-off-by".
-1. Push to your fork and submit a pull request from your branch to `main`.
-1. Pat yourself on the back and wait for your pull request to be reviewed.
+
+### 1. Fork and Clone the Repository
+
+First, fork the repository to your GitHub account, then clone your fork:
+
+```bash
+# Fork the repository on GitHub (click the "Fork" button)
+# Then clone your fork
+git clone git@github.com:YOUR_USERNAME/efficient-transformers.git
+cd efficient-transformers
+
+# Add upstream remote to keep your fork in sync
+git remote add upstream git@github.com:quic/efficient-transformers.git
+```
+
+### 2. Create a Feature Branch
+
+Create a descriptive branch for your changes:
+
+```bash
+# Update your main branch
+git checkout main
+git pull upstream main
+
+# Create a new branch
+git checkout -b <your-branch-name>
+```
+
+### 3. Make Your Changes
+
+When making changes to the codebase:
+
+- **Follow Existing Design Patterns**
+  - Review similar implementations before creating new code
+  - Maintain consistency with the project's architecture and coding style
+  - Reuse existing utilities and base classes where applicable
+
+- **Onboarding New Models**
+  - For adding new model support, refer to the comprehensive guide: `examples/onboarding_guide/causallm/`
+  - Follow the step-by-step process with code examples provided
+
+- **Testing is Mandatory**
+  - Add tests for all new features in the appropriate `tests/` subdirectory
+  - Run tests locally before pushing: `pytest tests/path/to/your/test.py -v`
+  - For model additions, verify all 4 pipeline stages (PyTorch HF → KV → ORT → AI 100) and make sure the generated tokens match the reference PyTorch HF output
+
+- **Documentation**
+  - **For New Features/Flags:**
+    - Document usage in `docs/source/` with feature description and usage examples
+    - Ensure documentation is clear enough for others to understand and use the feature
+  - **For New Models:**
+    - Test with basic inference scripts in the `examples/` folder
+    - If specific changes are needed, create a dedicated example file
+    - Update `docs/source/validate.md` with the model's HuggingFace card name and relevant details
+
+- **Code Quality Checks**
+  - Pre-commit hooks, DCO sign-off, and CI checks are covered in the following steps
+  - Ensure you complete steps 4-8 before finalizing your PR
+
+### 4. Run Pre-commit Checks
+
+Before committing, ensure your code passes all quality checks:
+
+```bash
+# Install pre-commit and ruff if not already installed
+pip install pre-commit
+pip install ruff
+
+# Run pre-commit on your changed files
+pre-commit run --files path/to/your/file1.py path/to/your/file2.py
+
+# Run Ruff check
+ruff check
+```
+
+**Important:** If pre-commit reports any failures:
+- Some issues will be auto-fixed (formatting, trailing whitespace, etc.)
+- For issues that aren't auto-fixed, manually correct them
+- Re-run `pre-commit run --files <changed-files>` or `ruff check` until all checks pass
+
+### 5. Commit with Sign-off (DCO)
+
+All commits must be signed off to comply with the Developer Certificate of Origin (DCO):
+
+```bash
+# Stage your changes
+git add examples/your_domain/your_example.py
+git add examples/your_domain/README.md
+
+# Commit with sign-off
+git commit -s --author "Your Name <your.email@example.com>" -m "Add [model-name] support
+
+- Implements inference for [model-name]
+- Includes documentation and usage examples
+- Tested with [specific configurations]"
+```
+
+**Commit Message Guidelines:**
+- Use a clear, descriptive title
+- Add a blank line, then a detailed description if needed
+- Always include the `-s` flag for DCO sign-off
+
+### 6. Push to Your Fork
+
+Push your branch to your forked repository:
+
+```bash
+git push origin <your-branch-name>
+```
+
+### 7. Create a Pull Request
+
+1. Go to your fork on GitHub
+2. Click "Compare & pull request" for your branch
+3. Fill out the PR template with:
+   - **Title:** Clear, descriptive title (e.g., "Add Llama-3.2-Vision Support" or "Fix memory leak in KV cache")
+   - **Description:**
+     - What changes were made and why
+     - What problem it solves or feature it adds
+     - Any special considerations or breaking changes
+     - Links to relevant documentation, issues, or model cards (if applicable)
+   - **Testing:** Describe how you tested your changes
+
+### 8. Ensure CI Checks Pass
+
+After creating the PR, verify that all automated checks pass:
+
+- ✅ **DCO Check:** Ensures all commits are signed off
+- ✅ **Lint Check:** Code style and formatting validation
+- ✅ **Tests:** Automated test suite (if applicable)
+
+If any checks fail:
+1. Review the error messages in the PR
+2. Make necessary fixes in your local branch
+3. Commit and push the fixes (with sign-off)
+4. The PR will automatically update and re-run checks
+
+### 9. Address Review Feedback
+
+Maintainers will review your PR and may request changes:
+- Make requested changes in your local branch
+- Commit with sign-off and push to update the PR
+- Respond to comments to facilitate discussion
+
 Here are a few things you can do that will increase the likelihood of your pull request to be accepted:
diff --git a/examples/onboarding_guide/causallm/Onboarding.png b/examples/onboarding_guide/causallm/Onboarding.png
new file mode 100644
index 000000000..8c83b0ac0
Binary files /dev/null and b/examples/onboarding_guide/causallm/Onboarding.png differ
diff --git a/examples/onboarding_guide/causallm/README.md b/examples/onboarding_guide/causallm/README.md
new file mode 100644
index 000000000..5a000eade
--- /dev/null
+++ b/examples/onboarding_guide/causallm/README.md
@@ -0,0 +1,196 @@
+# Onboarding a CausalLM Model
+
+## Prerequisites
+
+Install the `qefficient-transformers` library in editable mode:
+```sh
+git clone https://github.com/quic/efficient-transformers.git
+cd efficient-transformers
+pip install -e .
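+
+# Optional sanity check (a minimal sketch; assumes the package is importable
+# as "QEfficient" after the editable install):
+python -c "import QEfficient; print(QEfficient.__file__)"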
+``` + +## Introduction + +This guide walks you through onboarding a new CausalLM model to QEfficient-transformers. We use an example model named `Blueprint` to demonstrate the required changes. + +--- + +## Onboarding Process + +![Onboarding Flowchart](./Onboarding.png) + +--- + +## Step 1: Check Transformers Library + +1. **Locate the model** in the transformers library: + - Path: `/src/transformers/models//modeling_.py` + - Example: `/src/transformers/models/blueprint/modeling_blueprint.py` + +2. **Identify required classes**: + - Attention Layer + - Decoder Layer + - Model (main class) + - ForCausalLM (top-level) + - RMSNorm/LayerNorm + - RotaryEmbedding (if applicable) + +3. **Check existing implementations** in `QEfficient/transformers/models/`: + - If similar classes exist → Reuse patterns + - If not → Create custom implementations + +--- + +## Step 2: Create Custom Files & Mappings + +### 2.1 Create Custom Modeling File + +Create directory structure: +``` +QEfficient/transformers/models/blueprint/ +├── __init__.py +└── modeling_blueprint.py +``` + +**Key modifications in `modeling_blueprint.py`:** +- `QEffBlueprintRotaryEmbedding`: Precompute sin/cos for rotary embeddings +- `QEffBlueprintAttention`: Use `position_ids`, return `past_key_value`, implement `__qeff_init__` +- `QEffBlueprintDecoderLayer`: Return `past_key_value` from forward pass +- `QEffBlueprintModel`: Use `QEffDynamicCache` instead of standard cache +- `QEffBlueprintForCausalLM`: Entry point with additional parameters + +See `modeling_example.py` for detailed implementation examples. + +### 2.2 Add Mappings in pytorch_transforms.py + +**CustomOpsTransform** (RMSNorm mapping): +```python +class CustomOpsTransform(ModuleMappingTransform): + _module_mapping = { + BlueprintRMSNorm: CustomRMSNormAIC, + } +``` + +**KVCacheTransform** (all model classes): +```python +class KVCacheTransform(ModuleMappingTransform): + _module_mapping = { + BlueprintAttention: QEffBlueprintAttention, + BlueprintDecoderLayer: QEffBlueprintDecoderLayer, + BlueprintModel: QEffBlueprintModel, + BlueprintForCausalLM: QEffBlueprintForCausalLM, + } +``` + +See `example_pytorch_transforms.py` for complete example. + +--- + +## Step 3: Testing (4-Stage Pipeline) + +Your implementation is validated through four stages: + +| Stage | Description | Validation | +|-------|-------------|------------| +| **1. PyTorch HF** | Original transformers model | Baseline tokens | +| **2. PyTorch KV** | After QEff transforms | Tokens match Stage 1 | +| **3. ONNX/ORT** | After export to ONNX | Tokens match Stage 2 | +| **4. Cloud AI 100** | Hardware execution | Tokens match Stage 3 | + +**Test function:** `check_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100` in `tests/transformers/models/test_causal_lm_models.py` + +### Common Issues + +**Token mismatch (Stage 1→2):** +- Check all classes are mapped in `KVCacheTransform` +- Verify `__qeff_init__` methods exist +- Ensure `position_ids` are correctly passed + +**ONNX export failure (Stage 2→3):** +- Check for unsupported PyTorch operations +- Verify dynamic shapes are defined + +**Compilation failure (Stage 3→4):** +- Reduce `num_cores` or model size +- Check device availability: `get_available_device_id()` + +--- + +## Step 4: Add to Test Suite + +Edit `tests/transformers/models/test_causal_lm_models.py`: + +```python +test_models_causal = [ + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "gpt2", + # ... existing models ... 
+ "YourOrg/YourModel-7B", # Add your model here +] +``` + +**Run tests:** +```bash +# Test your specific model +pytest tests/transformers/models/test_causal_lm_models.py::test_custom_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100 -k "YourModel" -v + +# Run all regular tests +pytest tests/transformers/models/test_causal_lm_models.py -m regular +``` + +--- + +## Step 5: Validation Checklist + +Before submitting PR: + +**Implementation:** +- [ ] Created `QEfficient/transformers/models//` directory +- [ ] Implemented all required custom classes +- [ ] Added mappings in `CustomOpsTransform` and `KVCacheTransform` +- [ ] Added imports at top of `pytorch_transforms.py` + +**Testing:** +- [ ] Model added to `test_models_causal` list +- [ ] All 4 stages pass (PyTorch HF → KV → ORT → AI 100) +- [ ] Continuous batching tests pass +- [ ] `qconfig.json` generated successfully + +**Code Quality:** +- [ ] Code follows project style guidelines +- [ ] Commits use DCO sign-off (`git commit -s`) +- [ ] Branch created from `main` + +--- + +## Step 6: Submit Pull Request + +Follow guidelines in [CONTRIBUTING.md](../../../CONTRIBUTING.md): + +1. Create feature branch: `git checkout -b add-yourmodel-support main` +2. Commit with DCO: `git commit -s -m "Add support for YourModel"` +3. Push and create PR targeting `main` branch +4. Include test results in PR description + +--- + +## Troubleshooting Quick Reference + +| Issue | Solution | +|-------|----------| +| Token mismatch between stages | Check class mappings, verify `position_ids` handling | +| Shape errors | Verify KV cache dimensions, check `past_key_value` returns | +| ONNX export fails | Replace unsupported ops, define dynamic shapes | +| Compilation fails | Reduce `num_cores`, check device availability | +| Runtime errors | Verify input shapes match specializations | + +**Debug tip:** Start with `n_layer=1` and short prompts, then gradually increase complexity. + +--- + +## References + +- [Hugging Face Transformers](https://github.com/huggingface/transformers) +- [QEfficient Transformers](https://github.com/quic/efficient-transformers) +- [Contributing Guidelines](../../../CONTRIBUTING.md) +- [Test Suite](../../../tests/transformers/models/test_causal_lm_models.py) diff --git a/examples/onboarding_guide/causallm/example_pytorch_transforms.py b/examples/onboarding_guide/causallm/example_pytorch_transforms.py new file mode 100644 index 000000000..ff62588f9 --- /dev/null +++ b/examples/onboarding_guide/causallm/example_pytorch_transforms.py @@ -0,0 +1,291 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Example pytorch_transforms.py showing common model onboarding patterns. + +This file demonstrates three representative patterns: +1. Blueprint - Standard decoder-only model (example for onboarding) +2. Llama - Most common architecture pattern +3. 
Mixtral - Mixture of Experts (MoE) model + +For more examples and patterns, see: +- Production transforms: QEfficient/base/pytorch_transforms.py +- All model implementations: QEfficient/transformers/models/ +- Specific patterns: + * Gemma (custom RMSNorm): QEfficient/transformers/models/gemma/ + * Multimodal (Llama4, Mllama): QEfficient/transformers/models/llama4/ + * External models (Grok): QEfficient/transformers/models/grok_1/ + * Vision-Language models: QEfficient/transformers/models/mllama/ +""" + +import warnings +from types import MethodType +from typing import Callable, Optional, Tuple, Union + +from QEfficient.transformers.models.blueprint.modeling_blueprint import ( + QEffBlueprintAttention, + QEffBlueprintDecoderLayer, + QEffBlueprintForCausalLM, + QEffBlueprintModel, +) +from torch import nn + +# Example imports for three representative models +from transformers.models.blueprint.modeling_blueprint import ( + BlueprintAttention, + BlueprintDecoderLayer, + BlueprintForCausalLM, + BlueprintModel, + BlueprintRMSNorm, +) +from transformers.models.llama.modeling_llama import ( + LlamaAttention, + LlamaDecoderLayer, + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, +) +from transformers.models.mixtral.modeling_mixtral import ( + MixtralAttention, + MixtralDecoderLayer, + MixtralForCausalLM, + MixtralModel, + MixtralRMSNorm, + MixtralSparseMoeBlock, +) + +from QEfficient.base.pytorch_transforms import ExternalModuleMapperTransform, ModuleMappingTransform +from QEfficient.customop import CustomRMSNormAIC +from QEfficient.transformers.embeddings.embedding_utils import POOLING_MAP, PooledModel, validate_user_pooling_function +from QEfficient.transformers.models.llama.modeling_llama import ( + QEffLlamaAttention, + QEffLlamaDecoderLayer, + QEffLlamaForCausalLM, + QEffLlamaModel, +) +from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import ( + QEffMixtralAttention, + QeffMixtralDecoderLayer, + QEffMixtralForCausalLM, + QEffMixtralModel, + QEffMixtralSparseMoeBlock, +) +from QEfficient.transformers.post_processing import build_and_attach_mlp, model_type_registry +from QEfficient.transformers.sampler.sampler import sampler_forward +from QEfficient.transformers.spd.spd_transform_forward import tlm_forward + +SPD_TARGET = "target" + + +class CustomOpsTransform(ModuleMappingTransform): + """ + Maps RMSNorm classes to custom implementations optimized for Cloud AI 100. + + Most models use the standard CustomRMSNormAIC. For special cases (like Gemma), + you can create custom RMSNorm in QEfficient.customop. + """ + + _module_mapping = { + # Blueprint - Example model for onboarding + BlueprintRMSNorm: CustomRMSNormAIC, + # Llama - Most common pattern + LlamaRMSNorm: CustomRMSNormAIC, + # Mixtral - MoE model pattern + MixtralRMSNorm: CustomRMSNormAIC, + # TODO: Add your model's RMSNorm mapping here: + # YourModelRMSNorm: CustomRMSNormAIC, + } + + +class KVCacheTransform(ModuleMappingTransform): + """ + Maps model classes to their QEfficient counterparts with KV cache support. + + This is the most critical transform for enabling efficient inference. + All model classes (Attention, DecoderLayer, Model, ForCausalLM) must be mapped. 
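+
+    Leaving a class out of this mapping means it simply is not transformed; the
+    usual symptom is a token mismatch between the PyTorch HF and PyTorch KV
+    stages of the 4-stage test pipeline. A minimal usage sketch (in practice
+    the transform is applied for you by the QEfficient model loading path):
+
+        model, transformed = KVCacheTransform.apply(hf_model)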
+ """ + + _module_mapping = { + # Blueprint - Example model for onboarding + BlueprintAttention: QEffBlueprintAttention, + BlueprintDecoderLayer: QEffBlueprintDecoderLayer, + BlueprintModel: QEffBlueprintModel, + BlueprintForCausalLM: QEffBlueprintForCausalLM, + # Llama - Most common pattern (standard decoder-only) + LlamaAttention: QEffLlamaAttention, + LlamaDecoderLayer: QEffLlamaDecoderLayer, + LlamaModel: QEffLlamaModel, + LlamaForCausalLM: QEffLlamaForCausalLM, + # Mixtral - MoE model pattern (includes SparseMoeBlock) + MixtralAttention: QEffMixtralAttention, + MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock, + MixtralDecoderLayer: QeffMixtralDecoderLayer, + MixtralModel: QEffMixtralModel, + MixtralForCausalLM: QEffMixtralForCausalLM, + # TODO: Add your model's class mappings here: + # YourModelAttention: QEffYourModelAttention, + # YourModelDecoderLayer: QEffYourModelDecoderLayer, + # YourModelModel: QEffYourModelModel, + # YourModelForCausalLM: QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: + model, transformed = super().apply(model) + return model, transformed + + +class SpDTransform: + """ + Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. + This is only needed if user is exporting Target Language Model (TLM) for Speculative Decoding to validate output logits + against the speculated tokens from a smaller model. + Other than the computed logits, there should be no difference between the SpD Transformed model and its corresponding cunterpart. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + + # supported architectures + _module_mapping = { + QEffBlueprintForCausalLM, + # TODO: Add your model's ForCausalLM class here if using Speculative Decoding: + # QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + transformed = False + pretrained_model_name_or_path_temp = kwargs.pop("pretrained_model_name_or_path", None) + + if qaic_config is None or (speculative_model_type := qaic_config.get("speculative_model_type")) is None: + return model, transformed + + if speculative_model_type not in (supported_spd_model_types := [SPD_TARGET] + list(model_type_registry.keys())): + raise ValueError( + f"Speculative model type {speculative_model_type} is not supported. " + f"Currently only support {supported_spd_model_types}" + ) + + if (model_class := model.__class__) in cls._module_mapping: + model.forward = MethodType(tlm_forward, model) + if speculative_model_type != SPD_TARGET: + pretrained_model_name_or_path = qaic_config["pretrained_model_name_or_path"] + model = build_and_attach_mlp( + model, pretrained_model_name_or_path, speculative_model_type=speculative_model_type, **kwargs + ) + transformed = True + else: + raise NotImplementedError( + f"Model class {model_class} does not yet support returning multiple logits to keep." 
+ ) + + kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path_temp + return model, transformed + + +class SamplerTransform: + """ + Add nodes at the output of any generic QEffForCausalLM model to enable the + sampling of next tokens at the device (instead of the host) and return the + next tokens and/or probability distributions. + + Note: To achieve this, the generic QEffForCausalLM model must provide the + logits as output. + + ``Mandatory`` Args: + :model (nn.Module): PyTorch model. + + Returns: + :model (nn.Module): PyTorch model. + :transformed (bool): whether transformation was applied successfully. + """ + + # supported architectures + _module_mapping = { + # TODO: Add your model's ForCausalLM class here if using on-device sampling: + # QEffYourModelForCausalLM, + } + + @classmethod + def apply(cls, model: nn.Module, qaic_config: Optional[dict] = None, **kwargs) -> Tuple[nn.Module, bool]: + transformed = False + if qaic_config is None or not qaic_config.get("include_sampler", False): + return model, transformed + + if (model_class := model.__class__) in cls._module_mapping: + model.old_forward = model.forward + model.forward = MethodType(sampler_forward, model) + transformed = True + else: + raise NotImplementedError(f"Model class {model_class} does not support on device sampling.") + + return model, transformed + + +class VlmKVOffloadTransform(ModuleMappingTransform): + """ + Vision-Language Model transform with KV offloading (two QPC setup). + + Used for multimodal models where vision and text processing are separated. + See QEfficient/transformers/models/mllama/ for implementation examples. + """ + + _module_mapping = { + # TODO: Add VLM models with KV offloading here: + # YourVLMTextCrossAttention: QEffYourVLMTextCrossAttentionTwoQPC, + } + + +class VlmNoKVOffloadTransform(ModuleMappingTransform): + """ + Vision-Language Model transform without KV offloading (single QPC setup). + + Used for multimodal models in single QPC configuration. + See QEfficient/transformers/models/mllama/ for implementation examples. + """ + + _module_mapping = { + # TODO: Add VLM models without KV offloading here: + # YourVLMTextCrossAttention: QEffYourVLMTextCrossAttentionSingleQPC, + } + + +class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): + _match_string_replace_method = { + # TODO: Add external model mappings here (for models not in transformers library): + # "YourExternalModelClass": { + # "forward": QEffYourExternalModel.forward, + # "__qeff_init__": QEffYourExternalModel.__qeff_init__, + # }, + } + + _match_class_replace_method = {} + + +class PoolingTransform: + """ + Apply a pooling transformation to the model. This transformation appends a pooling layer to the model, allowing for the reduction of spatial dimensions in the output. + The pooling layer can be configured to use different pooling methods, such as max pooling or average pooling. 
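+
+    A minimal usage sketch (assuming "mean" is one of the keys defined in
+    POOLING_MAP; a user-supplied callable is validated via
+    validate_user_pooling_function instead):
+
+        model, transformed = PoolingTransform.apply(model, pooling="mean")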
+ """ + + @classmethod + def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Module, bool]: + transformed = False + pooling_method = ( + POOLING_MAP[pooling] + if isinstance(pooling, str) and pooling in POOLING_MAP + else validate_user_pooling_function(pooling) + ) + model = PooledModel(model, pooling_method) + warnings.warn("Pooling is applied to the model.") + return model, transformed diff --git a/examples/onboarding_guide/causallm/modeling_example.py b/examples/onboarding_guide/causallm/modeling_example.py new file mode 100644 index 000000000..195c9d7db --- /dev/null +++ b/examples/onboarding_guide/causallm/modeling_example.py @@ -0,0 +1,394 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +"""PyTorch Blueprint model.""" + +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.blueprint.modeling_blueprint import ( + BlueprintAttention, + BlueprintConfig, + BlueprintDecoderLayer, + BlueprintForCausalLM, + BlueprintModel, + BlueprintRotaryEmbedding, + rotate_half, +) + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class QEffBlueprintRotaryEmbedding(BlueprintRotaryEmbedding): + """ + Add the required Rotary Embedding functionality to the model based on the Class in the transformers modeling file. + The purpose of this class is to precompute sin and cos values for the rotary embedding and cache it for faster inference. + This class is more or less the same for all models that are onboarded. + """ + + def __init__(self, config: BlueprintConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors. + + We modify this method to enable the application of the rotary embedding based on position_ids + instead of seq_len. This is needed as our modified modelling accepts position_ids and not + the attention_mask as an input. 
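+
+    Here `cos` and `sin` are the full caches precomputed by
+    QEffBlueprintRotaryEmbedding, and indexing them with position_ids rotates
+    each token according to its absolute position rather than its offset within
+    the current window.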
+ """ + # + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + """ + Implements the forward pass of Eager Attention for the model. + We explicitly support Eager mode based attention on our device. + The method would mostly be generic so we don't expect it to have much changes. + MIN_MASKED_ATTENTION_VALUE is a special value which helps our compiler know what -inf should be represented by. + """ + pass + + +class QEffBlueprintAttention(BlueprintAttention): + """ + Here we'll setup the forward pass of the Attention module as implemented in the original model. + We initialize our own RotaryEmbedding module via __qeff_init__ method call. + + """ + + # < We load our own custom class for the rotary embedding to enable supporting position_ids> + # Since we map the custom classes to the original classes, __init__ method wouldn't work as expected, + # Hence we use __qeff_init__ method to initialize something while the mapping happens. + + def __qeff_init__(self): + self.rotary_emb = QEffBlueprintRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Most of the implementation remains the same as original forward method. + The parts where difference occurs are the way we apply the rotary embeddings. + Also, we return the past_key_values instead of storing it in the default transformers cache. + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states, **kwargs) + key_states = self.k_proj(hidden_states, **kwargs) + value_states = self.v_proj(hidden_states, **kwargs) + + query_states = query_states.view(hidden_shape).transpose(1, 2) + key_states = key_states.view(hidden_shape).transpose(1, 2) + value_states = value_states.view(hidden_shape).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + + # We build the rotary embeddings different from the transformers method. + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # Application of the rotary embeddings requires position_ids as well. 
+ query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # < We add all the required items for cache kwargs which would enable updating QEffDynamicCache > + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # < We override the attention_interface method with our own to enable Eager Attention> + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffBlueprintDecoderLayer(BlueprintDecoderLayer): + """ + Overrides the forward method of the original BlueprintDecoderLayer. + Only changes being that the past_key_value is returned and `self.self_attn` method + is now an object of QEffBlueprintAttention instead. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + The modified forward function also stores and returns the past_key_value. + Every other operation remains the same. + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # < Self attention would also have to return the past_key_value as well and we capture it here> + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffBlueprintModel(BlueprintModel): + """ + Replaces the original BlueprintModel with a modified version. + We initialize the custom `QEffDynamicCache` for past_key_values here instead of the DynamicCache class. + This custom Cache class has all the required custom ops to perform CtxScatter/CtxGather as well as other required operations. + This enables us to cache the past key values in the way we want for AIC. The component won't require any changes mostly. 
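+
+    Besides swapping in QEffDynamicCache, the forward pass below also derives
+    the causal mask from position_ids via _create_causal_mask instead of
+    expanding the incoming attention_mask, in line with the position_ids-driven
+    design used throughout these classes.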
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + # < We create the custom QEffDynamicCache here to be used during the AIC execution> + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask( + position_ids=position_ids, target_length=target_length, sliding_window=self.config.sliding_window + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class QEffBlueprintForCausalLM(BlueprintForCausalLM): + """ + No major changes are needed in the forward method of this class, it is the entry point for the model during inference. + We add the additionally required parameters and pass those down the line as well. 
+ """ + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # < We add the additional parameters that we use for our models here and pass them down the line > + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + logits = self.lm_head(hidden_states) + logits = logits.float() + + return CausalLMOutputWithPast( + loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + )