In [1]:

import os

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
from transformers.models.mistral.modeling_mistral import (
    MISTRAL_ATTENTION_CLASSES,
    MistralAttention,
    MistralConfig,
    MistralForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv
    )
from transformers.cache_utils import Cache

from typing import Tuple, List, Optional

import torch

In [3]:
class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device: str = "cpu",
        dtype=torch.float32
    ) -> None:
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
        
        if len(slice_indices) != 2:
            raise ValueError(f"slice_indices must be of length 2, got {len(slice_indices)}")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin: end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin: end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache
    
    def get_seq_length(self, _: int | None = 0) -> int:
        return self.past_seen_tokens
        

In [4]:
class SliceUpdateMistralAttention(MistralAttention):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None) -> None:
        super().__init__(config, layer_idx)
        
    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[SliceUpdateKeyValueCache] = None,
        **kwargs
    ) -> Tuple[torch.Tensor | None, ...]:
        bsz, q_len, _ = hidden_states.shape
        
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step)
        )
        
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask = attention_mask
        )
        
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None
        
            
            

In [3]:
class StatefulMistralModelForCausalLM(torch.nn.Module):
    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1):
        super().__init__()
        # MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
        self.model = MistralForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
        
        # config: MistralConfig = self.model.config
        # self.kv_cache_shape: Tuple[int, ...] = (
        #     config.num_hidden_layers,
        #     batch_size,
        #     config.num_key_value_heads,
        #     max_context_size,
        #     config.hidden_size // config.num_attention_heads
        # )
        # self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        # self.register_buffer("keyCache", self.kv_cache.k_cache)
        # self.register_buffer("valueCache", self.kv_cache.v_cache)
        
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor
    ) -> torch.Tensor:
        # self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            # past_key_values=self.kv_cache
        ).logits


In [4]:
max_context_size: int = 2048
model_id: str = "mistralai/Mistral-7B-Instruct-v0.3"
torch_model = StatefulMistralModelForCausalLM(model_id, max_context_size=max_context_size)
torch_model.eval()
input_ids: torch.Tensor = torch.zeros((1,2), dtype=torch.int32)
causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float16)
traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  if attention_mask.max() != 0:


In [7]:
# save traced model

traced_model.save("mistral-7B-fp16-traced.pt") # type: ignore

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [6]:
test_input = tokenizer("Hello, how are you?", return_tensors="pt")
print(test_input)

causal_mask = torch.triu(torch.full((1, 1, test_input.input_ids.shape[-1], test_input.input_ids.shape[-1]), 0, dtype=torch.float16))

{'input_ids': tensor([[    1, 23325, 29493,  1678,  1228,  1136, 29572]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}


In [7]:
output = traced_model(test_input.input_ids, causal_mask) # type: ignore

In [11]:
def sample(logits: torch.Tensor, k: int = 3) -> torch.Tensor:
    # Get the last token's logits
    last_token_logits = logits[0][-1]
    # Get the top k values and indices
    top_k_values, top_k_indices = torch.topk(last_token_logits, k)
    # Sample from the top k
    probs = torch.softmax(top_k_values, dim=-1)
    chosen_idx = torch.multinomial(probs, num_samples=1, generator=torch.Generator(device="cuda").manual_seed(42))
    return top_k_indices[chosen_idx]

from typing import Generator

def generate(model: torch.nn.Module, input_ids: torch.Tensor, eos_id: int, max_length: int, device: str = "cpu") -> Generator[torch.Tensor, None, None]:
    model.to(device)
    def inference(model: torch.nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
        input_ids = input_ids.to(device)
        causal_mask = torch.triu(torch.full((1, 1, input_ids.shape[-1], input_ids.shape[-1]), 0, dtype=torch.float16), diagonal=1).to(device)
        return model(input_ids, causal_mask)
    input_ids = input_ids.to(device)
    logits = inference(model, input_ids)
    token = sample(logits)
    n_tokens = 0
    while True:
        yield token
        # print(token.shape)
        input_ids = torch.cat([input_ids, token.unsqueeze(0)], dim=-1)
        logits = inference(model, input_ids)
        token = sample(logits)
        if token == eos_id or n_tokens >= max_length:
            break
        n_tokens += 1
    

In [15]:
token_output = []
for token in generate(torch_model, test_input.input_ids, tokenizer.eos_token_id, 100, device="cuda"):
    token_output.append(token)
    print(tokenizer.decode(token), end=" ")

I ' m doing well , thank you . How can I help you today ? I ' m looking for a specific book . Great ! I ' d be happy to help you find it . Could you please tell me the title and author of the book you ' re looking for ? Sure , the book is called " To Kill a Mock ing bird " and it was written by Harper Lee . I ' ll see if I can find it for you . 
 
 I ' m sorry , but I don ' 

{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

In [23]:
sampled_token = sample(output)
print(sampled_token)

tokenizer.decode(sampled_token)
print(sampled_token.shape)
print(output.shape)
print(test_input.input_ids.shape)


tensor(1083)
torch.Size([])
torch.Size([1, 7, 32768])
torch.Size([1, 7])


In [7]:
import coremltools as ct
import numpy as np
# convert traced TorchScript to CoreML format

query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)

inputs: List[ct.TensorType] = [
    ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
    ct.TensorType(
        shape=(1, 1, query_length, end_step_dim),
        dtype=np.float16,
        name="causalMask"
    )
]

outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
# states: List[ct.StateType] = [
#     ct.StateType(
#         wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
#         name="keyCache"
#     ),
#     ct.StateType(
#         wrapped_type=ct.TensorType(shape=torch_model.kv_cache_shape, dtype=np.float16),
#         name="valueCache"
#     )
# ]

Failed to load _MLModelProxy: No module named 'coremltools.libcoremlpython'


In [9]:
mlmodel_fp16 = ct.convert(
    traced_model,
    inputs=inputs,
    outputs=outputs,
    # states=states,
    minimum_deployment_target=ct.target.iOS18,
    skip_model_load=True,
    compute_precision=ct.precision.FLOAT16
)

Converting PyTorch Frontend ==> MIL Ops: 100%|█████████▉| 4740/4741 [00:02<00:00, 1851.24 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00,  8.00 passes/s]
Running MIL default pipeline: 100%|██████████| 84/84 [00:20<00:00,  4.16 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 24.88 passes/s]


In [10]:
mlmodel_fp16.save("mlmodel-no-state-fp32.mlpackage")

In [11]:
op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
    mode="linear_symmetric",
    dtype="int4",
    granularity="per_block",
    block_size=32    
)

config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)

Running compression pass linear_quantize_weights: 100%|██████████| 296/296 [00:44<00:00,  6.72 ops/s]
Running MIL frontend_milinternal pipeline: 0 passes [00:00, ? passes/s]
Running MIL default pipeline: 100%|██████████| 84/84 [00:12<00:00,  6.92 passes/s]
Running MIL backend_mlprogram pipeline: 100%|██████████| 12/12 [00:00<00:00, 20.78 passes/s]


In [12]:
mlmodel_int4.save("mlmodel-no-state-fp32-to-int4.mlpackage")

In [18]:
!du -hs ./mlmodel-no-state-fp16.mlpackage/

 14G	./mlmodel-no-state-fp16.mlpackage/


In [19]:
!du -hs ./mlmodel-no-state-int4.mlpackage/

3.8G	./mlmodel-no-state-int4.mlpackage/
