In [None]:
import platform

# Identify the host: CPU architecture, operating system name, and processor string.
# NOTE: `sys` holds the OS name returned by platform.system(); it is not the stdlib `sys` module.
arch, sys, processor = platform.machine(), platform.system(), platform.processor()
host_summary = f"{arch}\n{sys}\n{processor}"
print(host_summary)

In [None]:
import onnxruntime as ort
import os
import numpy as np

from pathlib import Path
from tokenizers import Tokenizer

In [None]:
# Project root: two directory levels above the current working directory.
root_dir = Path.cwd().parent.parent
root_dir

In [None]:
# Directory that contains the imported onnxruntime package file; used later to locate the QNN backend DLL.
onnx_root = Path(ort.__file__).parent
onnx_root

In [None]:
# Sub-directory (joined under <root_dir>/models later) that holds every artifact used in this notebook.
model_subdirectory = "DeepSeek-R1-Distilled-NPU-Optimized"

# Embeddings model: the entry point of the pipeline (open it in Netron to visualize the graph).
model_name = "deepseek_r1_1_5_embeddings_quant_v2.0.onnx"

# "Context" graph: processes the hidden states of the initial prompt.
context_model = "deepseek_r1_1_5_ctx_v2.2.onnx_ctx.onnx"

# "Iteration" graph: processes the hidden states of each generated token before they are passed to the head.
context_model_iter = "deepseek_r1_1_5_iter_v2.2.onnx_ctx.onnx"

# Model head: the earlier stages only produce hidden states; the head is applied to obtain logits.
head_model = "deepseek_r1_1_5_head_quant_v2.0.onnx"

# GenAI configuration file name.
configuration_json = "genai_config.json"

# Tokenizer definition file name.
tokenizer_json = "tokenizer.json"

In [None]:
# Every model artifact lives in the same directory; resolve it once and join file names onto it.
model_dir = root_dir / "models" / model_subdirectory

model_path = model_dir / model_name
ctx_path = model_dir / context_model
ctx_path_itr = model_dir / context_model_iter
head_path = model_dir / head_model
tokenizer_path = model_dir / tokenizer_json
config_path = model_dir / configuration_json

# QNN HTP backend library shipped inside the onnxruntime package directory.
hexagon_driver = onnx_root / "capi" / "QnnHtp.dll"

In [None]:
# Session options shared by all four QNN sessions.
session_options = ort.SessionOptions()
session_config_entries = (
    ("ep.context_enable", "0"),
    ("ep.context_file_path", str(ctx_path)),
    ("ep.context_embed_mode", "1"),
    ("qnn.backend_config_path", str(config_path)),
)
for entry_key, entry_value in session_config_entries:
    session_options.add_session_config_entry(entry_key, entry_value)

# Point the QNN execution provider at the backend library resolved earlier.
qnn_provider_options = {"backend_path": hexagon_driver}


def _create_qnn_session(onnx_path):
    """Open an inference session for `onnx_path` on the QNN execution provider.

    Uses the shared `session_options` and `qnn_provider_options` defined above.
    """
    return ort.InferenceSession(
        onnx_path,
        providers=[("QNNExecutionProvider", qnn_provider_options)],
        sess_options=session_options,
    )


# Creation order: embeddings, head, context, iteration.
embedding_session = _create_qnn_session(model_path)
head_session = _create_qnn_session(head_path)
ctx_session = _create_qnn_session(ctx_path)
ctx_itr_session = _create_qnn_session(ctx_path_itr)

# Show which execution providers the embeddings session reports.
embedding_session.get_providers()

In [None]:
# Input/output metadata of the embeddings session; keep the first entry of each for inspection.
inputs = embedding_session.get_inputs()
outputs = embedding_session.get_outputs()
input_0 = inputs[0]
output_0 = outputs[0]

In [None]:
# Shape, element type, and name of the embeddings model's first input.
print(f"Expected Input Shape: {input_0.shape}")
print(f"Expected Input Type: {input_0.type}")
print(f"Expected Input Name: {input_0.name}")

In [None]:
# Shape, element type, and name of the embeddings model's first output.
print(f"Expected Output Shape: {output_0.shape}")
print(f"Expected Output Type: {output_0.type}")
print(f"Expected Output Name: {output_0.name}")

In [None]:
# Input/output metadata of the context session; keep the first entry of each for inspection.
inputs_ctx = ctx_session.get_inputs()
outputs_ctx = ctx_session.get_outputs()
input_0_ctx = inputs_ctx[0]
output_0_ctx = outputs_ctx[0]

In [None]:
# Shape, element type, and name of the context graph's first input.
print(f"Expected Input Shape: {input_0_ctx.shape}")
print(f"Expected Input Type: {input_0_ctx.type}")
print(f"Expected Input Name: {input_0_ctx.name}")

In [None]:
# Shape, element type, and name of the context graph's first output.
print(f"Expected Output Shape: {output_0_ctx.shape}")
print(f"Expected Output Type: {output_0_ctx.type}")
print(f"Expected Output Name: {output_0_ctx.name}")

In [None]:
# Input/output metadata of the iteration session; keep the first entry of each for inspection.
inputs_ctx_itr = ctx_itr_session.get_inputs()
outputs_ctx_itr = ctx_itr_session.get_outputs()
input_0_ctx_itr = inputs_ctx_itr[0]
output_0_ctx_itr = outputs_ctx_itr[0]

In [None]:
# Shape, element type, and name of the iteration graph's first input.
print(f"Expected Input Shape: {input_0_ctx_itr.shape}")
print(f"Expected Input Type: {input_0_ctx_itr.type}")
print(f"Expected Input Name: {input_0_ctx_itr.name}")

In [None]:
# Shape, element type, and name of the iteration graph's first output.
print(f"Expected Output Shape: {output_0_ctx_itr.shape}")
print(f"Expected Output Type: {output_0_ctx_itr.type}")
print(f"Expected Output Name: {output_0_ctx_itr.name}")

In [None]:
# Input/output metadata of the head session; keep the first entry of each for inspection.
inputs_head = head_session.get_inputs()
outputs_head = head_session.get_outputs()
input_0_head = inputs_head[0]
output_0_head = outputs_head[0]

In [None]:
# Name, shape, and element type of the head model's first input.
print(f"Expected Input Name: {input_0_head.name}")
print(f"Expected Input Shape: {input_0_head.shape}")
print(f"Expected Input Type: {input_0_head.type}")

In [None]:
# Name, shape, and element type of the head model's first output.
print(f"Expected Output Name: {output_0_head.name}")
print(f"Expected Output Shape: {output_0_head.shape}")
print(f"Expected Output Type: {output_0_head.type}")

In [None]:
# Load the tokenizer from its JSON file (the Path is converted to str for from_file).
tokenizer = Tokenizer.from_file(str(tokenizer_path))

In [None]:
# Prompt wrapped in the User/Assistant markers and ending with an opening <think> tag, then encoded.
my_query = "<｜User｜>\nWhat is it like to be a dog? Please explain step by step.\n<｜Assistant｜> <think>\n"
encoding = tokenizer.encode(my_query)

In [None]:
# Inspect the encoded prompt: the token ids and their string tokens.
print("Token IDs:", encoding.ids)
print("Tokens:", encoding.tokens)

In [None]:
# Token ids of the encoded prompt; a later cell wraps them into a NumPy batch.
input_ids = encoding.ids
input_ids

In [None]:
# pad_token_id = tokenizer.token_to_id("<|pad|>") or 0
# pad_token_id

In [None]:
# Build the int64 batch of token ids (one row) fed to the embeddings model.
# It is built from `encoding.ids` rather than from `input_ids` itself so that re-running
# this cell is idempotent; wrapping `input_ids` again would add one dimension per run.
# No token-level padding happens here: the embeddings are zero-padded to 64 positions later.
input_ids = np.array([encoding.ids], dtype=np.int64)
input_ids.shape

In [None]:
# Run the embeddings session first: token ids in, first output (the embeddings) kept.
embedding_output = embedding_session.run(None, {"input_ids":input_ids})[0]
embedding_output.shape

In [None]:
# Display the embedding values.
embedding_output

In [None]:
# Dimensions used to build the prompt-stage inputs (KV cache and sequence lengths).
batch_size = 1
seq_len = embedding_output.shape[1]  # sequence axis of the embedding output
hidden_size = embedding_output.shape[2]  # last (embedding) axis of the embedding output
num_heads = 2  # head axis of the KV-cache tensors built below; presumably the key/value head count — TODO confirm
attn_head_size = 128  # last axis of the KV-cache tensors; NOTE(review): not hidden_size // num_heads when num_heads = 2 — TODO confirm against the model config
num_layers = 28  # one past_keys/past_values pair is created per layer
max_seq_len = 64  # sequence axis of the KV-cache tensors

In [None]:
# Display the per-head dimension configured above.
attn_head_size

In [None]:
# Display the embedding width read from the embedding output.
hidden_size

In [None]:
# Zero-filled KV cache for the prompt stage: one keys tensor and one values tensor per layer,
# all sharing the same (batch, heads, max sequence, head size) shape.
past_shape = (batch_size, num_heads, max_seq_len, attn_head_size)

empty_kv = {}
for layer in range(num_layers):
    # Each entry gets its own freshly allocated array (no sharing between keys and values).
    empty_kv.update({
        f"past_keys_{layer}": np.zeros(past_shape, dtype=np.float32),
        f"past_values_{layer}": np.zeros(past_shape, dtype=np.float32),
    })

len(empty_kv)

In [None]:
# Names of the KV-cache inputs that were created.
empty_kv.keys()

In [None]:
# The zero-filled cache arrays themselves.
empty_kv.values()

In [None]:
# Re-check the embedding output shape before deriving the sequence lengths.
embedding_output.shape

In [None]:
# Zero-based index of the last prompt position (sequence length minus one), as a (1, 1) int32 array.
last_prompt_index = embedding_output.shape[1] - 1
init_sequence_length = np.array([[last_prompt_index]], dtype=np.int32)

# Maximum window length as a one-element int32 vector.
max_seq_length = np.array((max_seq_len,), dtype=np.int32)

In [None]:
# Sequence-length inputs for the context graph.
# `init_sequence_length` is the value derived from the embedding output in the preceding cell
# (sequence length minus one). A hard-coded override (`np.array(5, ...)`) used to sit here and
# silently discarded that value; it is removed so the length tracks the actual prompt.
seq_lens = {
    "past_seq_len": init_sequence_length,
    "total_seq_len": max_seq_length,  # the maximum window length, not seq_len
}
seq_lens

In [None]:
# Shape of the embeddings before padding.
embedding_output.shape

In [None]:
# Zero-pad the embeddings along the sequence axis up to the expected length of 64 positions.
batch_size, seq_len, embed_dim = embedding_output.shape
target_seq_len = 64

padded_shape = (batch_size, target_seq_len, embed_dim)
padded_embedding = np.zeros(padded_shape, dtype=embedding_output.dtype)
# Copy the real embeddings into the leading positions; everything after them stays zero.
padded_embedding[:, :seq_len] = embedding_output
padded_embedding.shape

In [None]:
# Check to ensure null vectors were added: rows past the prompt length should be all zeros.
# NOTE(review): the slice end 18 is hard-coded — presumably picked for this prompt's length; TODO confirm.
padded_embedding[0,:18,:]

In [None]:
# Assemble the context-graph feed: padded hidden states first, then the empty KV cache,
# then the sequence-length inputs (same insertion order as a single dict literal would give).
init_prompt_inputs = {"input_hidden_states": padded_embedding}
init_prompt_inputs.update(empty_kv)
init_prompt_inputs.update(seq_lens)
init_prompt_inputs

In [None]:
# Input names that will be fed to the context graph.
init_prompt_inputs.keys()

In [None]:
# Run the context graph on the prompt inputs; passing None requests every output.
prompt_outputs = ctx_session.run(None, init_prompt_inputs)
len(prompt_outputs)

In [None]:
# Shape of the first output of the context graph.
prompt_outputs[0].shape

In [None]:
# Extract the final hidden states; they are the first output of the context graph.
# The key/value outputs at the following indices are collected in a later cell.
print("Batch, prompt length (up to max 64 tokens), embedding size")
output_hidden_states = prompt_outputs[0]
output_hidden_states.shape

In [None]:
# Describe the axes, then show the shape of the output at index 1 (the first key/value output).
print("Batch, key/value heads, prompt length (up to max 64 tokens), head dimension (size of projection for each head)")
print("Note: Total embedding size is 1536, this is split amongst 12 attention heads")
prompt_outputs[1].shape

In [None]:
# Same output with the leading (batch) axis indexed away.
prompt_outputs[1][0].shape

In [None]:
# Same output with the first two axes indexed away, leaving one head's 2-D slice.
print("Prompt Length x Head Dimension (Embedding Window)")
prompt_outputs[1][0][0].shape

In [None]:
# Seed the KV cache for generation from the context-graph outputs.
# Index 0 of prompt_outputs is output_hidden_states (see genai_config.json), so the
# key/value tensors start at index 1 and alternate keys, values per layer.
kv_outputs = prompt_outputs[1:]

present_kv = {}
# All keys are inserted first, then all values, which fixes the dict's display order.
for layer in range(num_layers):
    present_kv[f"past_keys_{layer}"] = kv_outputs[2 * layer]
for layer in range(num_layers):
    present_kv[f"past_values_{layer}"] = kv_outputs[2 * layer + 1]
present_kv

In [None]:
# Names of the cache entries that will be fed to the iteration graph.
present_kv.keys()

In [None]:
# Dimension checks: shape of the first layer's keys.
present_kv["past_keys_0"].shape

In [None]:
# Shape of the keys entry with index 27 (the last layer when num_layers is 28).
present_kv["past_keys_27"].shape

In [None]:
# Shape of the hidden states about to be passed to the head.
output_hidden_states.shape

In [None]:
# Apply the model head to the hidden states; its first output is kept as the logits.
logits = head_session.run(None, {"output_hidden_states": output_hidden_states})[0]
logits

In [None]:
# Shape of the full logits array.
logits.shape

In [None]:
# Shape of the logits at the last position of the first batch element.
logits[0,-1].shape

In [None]:
# Greedy inference: take the arg-max over the logits at the last position of the first batch element.
# NOTE(review): the embeddings were zero-padded after the prompt up to 64 positions, so index -1
# may be a padding position rather than the last prompt token — TODO confirm whether
# seq_len - 1 should be used here instead.
next_token_id = int(np.argmax(logits[0, -1]))
next_token_id

In [None]:
# Decode the selected token id back to text.
tokenizer.decode([next_token_id])

In [None]:
# Greedy autoregressive decoding: each step feeds the previous token through
# embeddings -> iteration graph -> head, carrying the KV cache forward.
max_tokens = 50

generated_ids = [next_token_id]
# NOTE(review): assumes the cache returned by the context graph counts as a full
# 64-position window — TODO confirm whether this should be the real prompt length.
prev_seq_len = 64

# Resolve the end-of-sentence id once, outside the loop. `token_to_id` returns None for a
# string that is not in the vocabulary, so a misspelled token silently disables the stop
# condition. The token uses the same full-width "｜" delimiters as the <｜User｜> and
# <｜Assistant｜> markers in the prompt.
eos_token_id = tokenizer.token_to_id("<｜end▁of▁sentence｜>")
if eos_token_id is None:
    print("Warning: end-of-sentence token not found in tokenizer; generation will run for max_tokens steps.")

for _ in range(max_tokens):
    # Embed only the most recently generated token.
    input_ids = np.array([[next_token_id]], dtype=np.int64)
    embedding_output = embedding_session.run(None, {"input_ids": input_ids})[0]

    lengths = {
        "past_seq_len": np.array([[prev_seq_len]], dtype=np.int32),
        "total_seq_len": np.array([prev_seq_len + 1], dtype=np.int32),
    }

    iter_inputs = {
        **present_kv,
        "input_hidden_states": embedding_output,
        **lengths,
    }

    iter_outputs = ctx_itr_session.run(None, iter_inputs)

    # Hidden states are stored in the last index of the iteration outputs.
    output_hidden_states = iter_outputs[-1]

    # In the iteration outputs the key/value tensors start at index 0
    # and alternate keys, values per layer.
    present_kv = {f"past_keys_{i}": iter_outputs[i * 2] for i in range(num_layers)}
    present_kv.update({f"past_values_{i}": iter_outputs[i * 2 + 1] for i in range(num_layers)})
    logits = head_session.run(None, {"output_hidden_states": output_hidden_states})[0]

    next_token_id = int(np.argmax(logits[0, -1]))
    generated_ids.append(next_token_id)

    prev_seq_len += 1

    # Stop as soon as the model emits end-of-sentence; the id stays in generated_ids
    # and is dropped at decode time by skip_special_tokens=True.
    if next_token_id == eos_token_id:
        break

output_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
print("")
print("*"*100)
print("\nInitial Query:\n", my_query)
print("Generated:", output_text)