In [1]:
import numpy as np
from transformers import AutoModelForCausalLM
import openvino as ov
from nncf import compress_weights, CompressWeightsMode
import nncf
import torch

# Hugging Face Hub id of the checkpoint to convert, in "<vendor>/<name>" form.
# The older 'stabilityai/japanese-stablelm-base-alpha-7b' can be used instead, but
# its past_key_values layout differs (32 KV heads instead of 8), so example_input
# would have to be adjusted to match.
model_id = 'stabilityai/japanese-stablelm-base-gamma-7b'
(model_vendor, model_name) = model_id.split(sep='/')

INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, tensorflow, onnx, openvino


## Load (or download) the model

In [2]:
# Download the checkpoint (or reuse the local Hugging Face cache) and put the
# module in inference mode before tracing; the repr is shown as cell output.
# NOTE(review): trust_remote_code=True allows Python code shipped with the
# checkpoint to run. The gamma model loads as the built-in MistralForCausalLM,
# so the flag is presumably only required for the alpha model — confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
)
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm(

## Define `example_input` for model conversion

In [3]:
num_seq = 10  # length of the dummy prompt used for tracing

# Architecture constants, copied from the checkpoint's config.json
vocab_size = 32000
hidden_size = 4096
num_hidden_layers = 32
num_attention_heads = 32
num_key_value_heads = 8
# Per-head dimension. It must be derived from the attention-head count:
# `hidden_size // num_hidden_layers` only yields 128 by coincidence, because this
# model has as many layers as attention heads.
head_dim = hidden_size // num_attention_heads

# Empty KV cache (past sequence length 0) for every decoder layer, so the traced
# graph gets past_key_values inputs. Each tensor is laid out as
#   [batch, num_key_value_heads, past_seq_len, head_dim]
#   gamma: [1, 8, seq_len, 128]     alpha: [1, 32, seq_len, 128]
# Sharing one tensor object is safe here because it is empty.
past_kv = torch.zeros((1, num_key_value_heads, 0, head_dim), dtype=torch.float32)
past_key_values = tuple((past_kv, past_kv) for _ in range(num_hidden_layers))

# inputs_embeds, head_mask and labels are intentionally not supplied: the model is
# traced from input_ids only, for inference.
example_input = {
    'input_ids'     : torch.tensor([[123] * num_seq], dtype=torch.int),
    'attention_mask': torch.tensor([[1] * num_seq], dtype=torch.int),
    'position_ids'  : torch.tensor([list(range(num_seq))], dtype=torch.int),
    'past_key_values' : past_key_values,
    # NOTE(review): these flags appear to be frozen as constants during tracing,
    # so they presumably cannot be changed at inference time — confirm.
    'use_cache'           : torch.tensor( True, dtype=torch.bool),
    'output_attentions'   : torch.tensor(False, dtype=torch.bool),
    'output_hidden_states': torch.tensor(False, dtype=torch.bool),
    'return_dict'         : torch.tensor(False, dtype=torch.bool),
}

## Convert the model into OpenVINO IR

In [4]:
# Trace the PyTorch model with the example inputs and convert it into an
# in-memory OpenVINO model. The warnings printed during conversion point at
# Python-level branches (e.g. `if use_cache:`) that were evaluated once while
# tracing. The model's signature (inputs/outputs) is printed for inspection.
ov_model = ov.convert_model(
    model,
    example_input=example_input,
)
print(ov_model)



  if use_cache:
  elif self._attn_implementation == "sdpa" and not output_attentions:
  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
  if past_key_values_length > 0:
  if query_length > 1 and not is_tracing:
  all_hidden_states = () if output_hidden_states else None
  all_self_attns = () if output_attentions else None
  if output_hidden_states:
  if output_attentions:
  if seq_len > self.max_seq_len_cached:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
  if output_attentions:
  if use_cache:
  if use_cache:
  next_decoder_cache = layer_outputs[2 if output_attentions else 1]
  if output_attentions:
  if output_hidden_states:
  if use_cache:
  if not return_dict:
  if not return_dict:


<Model: 'Model0'
inputs[
<ConstOutput: names[input_ids] shape[?,?] type: i32>,
<ConstOutput: names[attention_mask] shape[?,?] type: i32>,
<ConstOutput: names[position_ids] shape[?,?] type: i32>,
<ConstOutput: names[42, key_states.1] shape[?,8,?,128] type: f32>,
<ConstOutput: names[43] shape[?,8,?,128] type: f32>,
<ConstOutput: names[44] shape[?,8,?,128] type: f32>,
<ConstOutput: names[45] shape[?,8,?,128] type: f32>,
<ConstOutput: names[46] shape[?,8,?,128] type: f32>,
<ConstOutput: names[47] shape[?,8,?,128] type: f32>,
<ConstOutput: names[48] shape[?,8,?,128] type: f32>,
<ConstOutput: names[49] shape[?,8,?,128] type: f32>,
<ConstOutput: names[50] shape[?,8,?,128] type: f32>,
<ConstOutput: names[51] shape[?,8,?,128] type: f32>,
<ConstOutput: names[52] shape[?,8,?,128] type: f32>,
<ConstOutput: names[53] shape[?,8,?,128] type: f32>,
<ConstOutput: names[54] shape[?,8,?,128] type: f32>,
<ConstOutput: names[55] shape[?,8,?,128] type: f32>,
<ConstOutput: names[56] shape[?,8,?,128] type: f3

In [9]:
# 8-bit weight-only compression, saved as OpenVINO IR (.xml + .bin).
# Compression runs on a clone so the uncompressed ov_model stays available for
# other compression modes.
# NOTE(review): newer NNCF releases mark CompressWeightsMode.INT8 as deprecated
# in favour of INT8_ASYM — check against the installed NNCF version.
compressed_model = compress_weights(
    ov_model.clone(),
    mode=CompressWeightsMode.INT8,
)
ov.save_model(compressed_model, 'openvino_model_int8.xml')



INFO:nncf:Statistics of the bitwidth distribution:
+--------------+---------------------------+-----------------------------------+
| Num bits (N) | % all parameters (layers) |    % ratio-defining parameters    |
|              |                           |             (layers)              |
| 8            | 100% (226 / 226)          | 100% (226 / 226)                  |
+--------------+---------------------------+-----------------------------------+


Output()

In [10]:
# 4-bit asymmetric weight-only compression, saved as OpenVINO IR (.xml + .bin).
# Compression runs on a clone so ov_model itself is not modified. As the NNCF
# statistics show, a couple of layers are still kept in 8 bit (presumably the
# embedding / output projection, per NNCF defaults — confirm).
compressed_model = compress_weights(
    ov_model.clone(),
    mode=CompressWeightsMode.INT4_ASYM,
)
ov.save_model(compressed_model, 'openvino_model_int4asym.xml')

INFO:nncf:Statistics of the bitwidth distribution:
+--------------+---------------------------+-----------------------------------+
| Num bits (N) | % all parameters (layers) |    % ratio-defining parameters    |
|              |                           |             (layers)              |
| 8            | 4% (2 / 226)              | 0% (0 / 224)                      |
+--------------+---------------------------+-----------------------------------+
| 4            | 96% (224 / 226)           | 100% (224 / 224)                  |
+--------------+---------------------------+-----------------------------------+


Output()