In [1]:

import os

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
# Disable HF tokenizers' internal parallelism before any tokenizer is created
# (presumably to silence its fork/parallelism warning inside the notebook).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Hugging Face Hub id of the checkpoint used throughout this notebook.
model_path = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

tokenizer_config.json:   0%|          | 0.00/141k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.55G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [5]:
# Model Wrapper Class
class MistralModelForCausalLM(nn.Module):
    """Thin inference wrapper that returns only the logits of a causal-LM model.

    The wrapped model must return an object exposing a ``.logits`` attribute,
    as Hugging Face causal-LM outputs do.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model without autograd and return its logits.

        Args:
            input_ids: token ids, e.g. as produced by the tokenizer.
            attention_mask: mask forwarded unchanged to the wrapped model
                (this notebook passes both 2-D tokenizer masks and 4-D float masks).

        Returns:
            The ``logits`` tensor of the wrapped model's output.
        """
        # Bind the mask by keyword: the positional order of a model's forward()
        # arguments is not a stable contract, a keyword binds unambiguously.
        return self.model(input_ids, attention_mask=attention_mask).logits


In [7]:
# Smoke-test the tokenizer: encode a short prompt into PyTorch tensors.
test_input = tokenizer("Hello, how are you?", return_tensors="pt")

# Bare last expression -> Jupyter rich display instead of print().
test_input

{'input_ids': tensor([[    1, 23325, 29493,  1678,  1228,  1136, 29572]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}


In [8]:
# Wrap the HF model so calls return raw logits. The guard keeps this cell
# idempotent: re-running it must not wrap the wrapper a second time, because
# the inner call would then return a tensor and `.logits` would fail.
if not isinstance(model, MistralModelForCausalLM):
    model = MistralModelForCausalLM(model)
model = model.eval()

# Call the module (not `.forward`) so any registered hooks run.
output = model(**test_input)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [12]:
# Shape of the 2-D attention mask produced by the tokenizer for the test prompt.
test_input.attention_mask.shape

torch.Size([1, 7])

In [14]:
# Dummy 2-token batch with an explicit all-zeros 4-D float mask, presumably laid
# out as (batch, 1, query_len, kv_len) = (1, 1, 2, 5).
# NOTE(review): kv_len (5) exceeds the 2 input tokens — presumably mimicking a
# cache-backed call; confirm how the wrapped model treats the extra columns.
input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)

# Call the module (not `.forward`) so any registered hooks run.
output = model(input_ids, causal_mask)

In [17]:
# Display the logits returned by the dummy 4-D-mask call.
output

tensor([[[-9.3285, -8.7385,  1.8744,  ..., -7.1242, -4.3918, -8.2920],
         [-9.3285, -8.7385,  1.8744,  ..., -7.1242, -4.3918, -8.2920]]],
       grad_fn=<UnsafeViewBackward0>)

In [18]:
# Same dummy tokens, but with a standard 2-D attention mask. In that format 1
# marks a token to attend to (the tokenizer emits all ones for an unpadded
# prompt), so "attend to everything" is all ones; an all-zeros mask would mask
# out every token. ones_like keeps the original int32 dtype and (1, 2) shape.
output2 = model(input_ids, torch.ones_like(input_ids))

In [1]:
from transformers.models.mistral.modeling_mistral import (
    MISTRAL_ATTENTION_CLASSES,
    MistralAttention,
    MistralConfig,
    MistralForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv
    )
from transformers.cache_utils import Cache

from typing import Tuple, List, Optional

import torch

In [12]:
class SliceUpdateKeyValueCache(Cache):
    """Preallocated key/value cache that is written in place with slice assignments.

    ``k_cache``/``v_cache`` are zero-initialised tensors of ``shape``. The
    indexing in ``update`` fixes the layout to 5 dims: dim 0 is the layer index,
    dim 3 is the sequence position. Presumably the full layout is
    (layers, batch, kv_heads, max_seq_len, head_dim), matching how the export
    wrapper builds the shape.

    ``past_seen_tokens`` is never advanced by this class. The owner must set it
    before each forward pass, and it is exactly what ``get_seq_length`` reports.
    """

    def __init__(
        self,
        shape: Tuple[int, ...],
        device: str = "cpu",
        dtype=torch.float32
    ) -> None:
        super().__init__()
        # Number of tokens already stored; maintained by the caller.
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: Tuple[int, int]
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write new key/value states into positions ``[begin, end)`` of one layer.

        Args:
            k_state: new keys; ``k_state.shape[1]`` selects how many entries of
                cache dim 2 are written.
            v_state: new values, same layout as ``k_state``.
            layer_idx: cache layer (dim 0) to update.
            slice_indices: ``(begin, end)`` sequence positions to overwrite.

        Returns:
            ``(k_cache, v_cache)`` views of this layer covering positions ``[0, end)``.

        Raises:
            ValueError: if ``slice_indices`` does not have exactly two entries.
        """
        if len(slice_indices) != 2:
            raise ValueError(f"slice_indices must be of length 2, got {len(slice_indices)}")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin: end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin: end, :] = v_state
        # Return everything cached so far (prefix up to `end`), not just the new slice.
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return the number of cached tokens (identical for every layer).

        The parameter keeps the base ``Cache.get_seq_length(layer_idx)`` name so
        keyword callers work; its value is ignored.
        """
        return self.past_seen_tokens
        

In [13]:
class SliceUpdateMistralAttention(MistralAttention):
    """Mistral attention that stores K/V through a slice-updating cache.

    Meant to replace the stock attention class. It writes the new keys/values
    into a preallocated cache with ``cache.update(..., slice_indices=(begin, end))``
    and then runs SDPA over everything cached so far. The write window is derived
    from the mask: the last ``q_len`` positions before ``attention_mask.shape[-1]``.
    The constructor is inherited unchanged from ``MistralAttention``.
    """

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Cache] = None,
        **kwargs
    ) -> Tuple[torch.Tensor | None, ...]:
        """Attend ``q_len`` new tokens over the cached prefix plus themselves.

        Args:
            hidden_states: 3-D activations unpacked as ``(bsz, q_len, _)``.
            attention_mask: float mask passed straight to SDPA as ``attn_mask``.
                Its last dimension is taken as the total (cached + new) length.
            position_ids: positions fed to the rotary embedding.
            past_key_value: cache exposing ``update(k, v, layer_idx, slice_indices=...)``,
                e.g. ``SliceUpdateKeyValueCache``. Required despite the ``None`` default,
                which is kept for signature compatibility.
            **kwargs: accepted for compatibility with the caller and ignored.

        Returns:
            ``(attn_output, None, None)``. Attention weights and the cache are
            not returned.

        Raises:
            ValueError: if ``past_key_value`` is None.
        """
        if past_key_value is None:
            # Without a cache there is nowhere to write K/V. Fail with a clear
            # message instead of an AttributeError on ``None.update``.
            raise ValueError("SliceUpdateMistralAttention requires a past_key_value cache")

        bsz, q_len, _ = hidden_states.shape

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # The new tokens occupy the last q_len slots of the window the mask covers.
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step)
        )

        # Expand grouped KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask = attention_mask
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None
        
            
            

In [2]:
class StatefulMistralModelForCausalLM(torch.nn.Module):
    """Wraps ``MistralForCausalLM`` for tracing/CoreML export; ``forward`` returns logits.

    With ``stateful=False`` (the default) no cache object is passed to the model.
    This is the path used for the "no-state" export.

    With ``stateful=True`` the following happens:
      * the "sdpa" attention class is replaced by ``SliceUpdateMistralAttention``;
      * a preallocated ``SliceUpdateKeyValueCache`` is created;
      * its tensors are registered as the buffers ``keyCache``/``valueCache``
        (the names used for the CoreML state specs).

    Args:
        model_path: model id or path passed to ``MistralForCausalLM.from_pretrained``.
        max_context_size: sequence capacity of the KV cache (stateful mode only).
        batch_size: batch dimension of the KV cache (stateful mode only).
        stateful: enable the preallocated KV-cache path described above.
    """

    def __init__(
        self,
        model_path: str,
        max_context_size: int = 2048,
        batch_size: int = 1,
        stateful: bool = False,
    ):
        super().__init__()
        self.max_context_size = max_context_size
        self.batch_size = batch_size
        # Stays None in stateless mode; forward() uses it to choose the code path.
        self.kv_cache = None

        if stateful:
            # Registered before loading so the attention layers are built with it.
            # NOTE(review): this mutates a module-level transformers registry, which
            # affects every Mistral model created later in this process. It also
            # presumably assumes the model loads with the "sdpa" implementation;
            # confirm for the installed transformers version.
            MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention

        self.model = MistralForCausalLM.from_pretrained(model_path)

        if stateful:
            config: MistralConfig = self.model.config
            # (layers, batch, kv_heads, max_context, head_dim)
            self.kv_cache_shape: Tuple[int, ...] = (
                config.num_hidden_layers,
                batch_size,
                config.num_key_value_heads,
                max_context_size,
                config.hidden_size // config.num_attention_heads,
            )
            self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
            # The buffers are the cache tensors themselves, so in-place cache
            # writes are visible through them.
            self.register_buffer("keyCache", self.kv_cache.k_cache)
            self.register_buffer("valueCache", self.kv_cache.v_cache)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return the logits for ``input_ids`` under ``causal_mask``.

        In stateful mode, the number of already-cached tokens is inferred as
        ``causal_mask.shape[-1] - input_ids.shape[-1]`` and the cache is passed
        to the model as ``past_key_values``.
        """
        model_kwargs = {}
        if self.kv_cache is not None:
            self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
            model_kwargs["past_key_values"] = self.kv_cache
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            **model_kwargs
        ).logits
        
    
    

        

In [3]:
# Build the export wrapper and trace it with small example inputs. The example
# shapes only drive tracing; the CoreML input specs declare flexible ranges.
max_context_size: int = 2048
model_id: str = "mistralai/Mistral-7B-Instruct-v0.3"
torch_model = StatefulMistralModelForCausalLM(
    model_id, max_context_size=max_context_size
).eval()

input_ids: torch.Tensor = torch.zeros(1, 2, dtype=torch.int32)
causal_mask: torch.Tensor = torch.zeros(1, 1, 2, 5, dtype=torch.float32)
traced_model = torch.jit.trace(torch_model, (input_ids, causal_mask))


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  if attention_mask.max() != 0:


In [5]:
import coremltools as ct
import numpy as np
# convert traced TorchScript to CoreML format

# Flexible CoreML input shapes: the number of new tokens (query_length) and the
# total attended length (end_step_dim) may each range from 1 to max_context_size.
# query_length is shared so inputIds and the mask's query axis always agree.
query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)

inputs: List[ct.TensorType] = [
    ct.TensorType(name="inputIds", shape=(1, query_length), dtype=np.int32),
    ct.TensorType(name="causal_mask", shape=(1, 1, query_length, end_step_dim), dtype=np.float16),
]
outputs: List[ct.TensorType] = [ct.TensorType(name="logits", dtype=np.float16)]

# State specs for a stateful export (unused in this stateless conversion):
# states: List[ct.StateType] = [
#     ct.StateType(
#         wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
#         name="keyCache"
#     ),
#     ct.StateType(
#         wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
#         name="valueCache"
#     )
# ]

In [7]:
# Convert the traced TorchScript module to an ML Program targeting iOS 18.
# skip_model_load=True builds the package without loading it in this process.
# No states are passed: this is the stateless export.
mlmodel_fp16 = ct.convert(
    traced_model,
    minimum_deployment_target=ct.target.iOS18,
    inputs=inputs,
    outputs=outputs,
    skip_model_load=True,
    # states=states,
)

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 4740/4741 [00:02<00:00, 2264.50 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00,  7.38 passes/s]
Running MIL default pipeline: 100%|██████████| 86/86 [02:54<00:00,  2.02s/ passes]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 21.16 passes/s]


In [8]:
# Persist the fp16 (stateless) ML package to disk.
mlmodel_fp16.save("mlmodel-no-state-fp16.mlpackage")

In [9]:
# Weight-only 4-bit symmetric linear quantization with per-block scales
# (blocks of 32), applied to every op through the global config.
op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
    mode="linear_symmetric",
    dtype="int4",
    granularity="per_block",
    block_size=32,
)
config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)

mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)

  quantized_data = np.round(weight / scale)
  quantized_data = np.clip(quantized_data, q_val_min, q_val_max).astype(dtype)
Running compression pass linear_quantize_weights: 100%|██████████| 296/296 [05:20<00:00,  1.08s/ ops]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|██████████| 84/84 [00:14<00:00,  5.71 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 17.64 passes/s]


In [12]:
# Persist the int4-quantized (stateless) ML package to disk.
mlmodel_int4.save("mlmodel-no-state-int4.mlpackage")

In [10]:
# On-disk size of the fp16 package.
!du -hs ./mlmodel-no-state-fp16.mlpackage/

14G	./mlmodel-no-state-fp16.mlpackage/


In [13]:
# On-disk size of the int4 package, for comparison with fp16.
!du -hs ./mlmodel-no-state-int4.mlpackage/

3.8G	./mlmodel-no-state-int4.mlpackage/
