diff --git a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py index dabe0243a47..1ee71d42bd4 100644 --- a/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py +++ b/backends/qualcomm/_passes/convert_conv1d_to_conv2d.py @@ -99,7 +99,8 @@ def call(self, graph_module: torch.fx.GraphModule): ) num_args = len(node.args) - bias_node = node.args[2] + + bias_node = node.args[2] if num_args > 2 else None stride = [1] + node.args[3] if num_args > 3 else [1, 1] padding = [0] + node.args[4] if num_args > 4 else [0, 0] if node.target == torch.ops.aten.conv1d.default: diff --git a/backends/qualcomm/_passes/lift_constant_scalar_operands.py b/backends/qualcomm/_passes/lift_constant_scalar_operands.py index 757d5cee2c4..0e65396dbcf 100644 --- a/backends/qualcomm/_passes/lift_constant_scalar_operands.py +++ b/backends/qualcomm/_passes/lift_constant_scalar_operands.py @@ -40,10 +40,13 @@ class TensorOpInfo: aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False), aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False), aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False), + # For below cases, refer to LiftAddTensor Model in UT for sample + aten.add.Tensor: TensorOpInfo(aten.add.Tensor, False, False), aten.div.Scalar: TensorOpInfo(aten.div.Tensor, False, False), aten.mul.Scalar: TensorOpInfo(aten.mul.Tensor, False, False), aten.rsub.Scalar: TensorOpInfo(aten.rsub.Tensor, False, False), aten.sub.Scalar: TensorOpInfo(aten.sub.Tensor, False, False), + aten.sub.Tensor: TensorOpInfo(aten.sub.Tensor, False, False), aten.pow.Tensor_Scalar: TensorOpInfo(aten.pow.Tensor_Tensor, False, False), # The scalar number arg[1] is missing when using default. Result in a corner case to deal aten.leaky_relu.default: TensorOpInfo(aten.prelu.default, True, False), diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 025c0bee171..8be289871e1 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -998,6 +998,15 @@ def forward(self, x): return self.constant < x +class LiftAddTensor(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + N = 2 - 1 + return x + N + + class Linear(torch.nn.Module): def __init__(self, use_bias: bool = True): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 081dda7187b..748b2f43059 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -1060,6 +1060,12 @@ def test_qnn_backend_einsum_outer_product_relu(self): ) self.lower_module_and_test_output(module, sample_input) + # TODO: Create a new UT class for passes specific checks + def test_qnn_backend_lift_add_tensor(self): + module = LiftAddTensor() # noqa: F405 + sample_input = (torch.Tensor([1, 2, 3, 4]).to(torch.int32),) + self.lower_module_and_test_output(module, sample_input) + @unittest.skip("Fail because of bad accuracy") def test_qnn_backend_moe_feed_forward(self): args = ModelArgs() diff --git a/examples/qualcomm/CMakeLists.txt b/examples/qualcomm/CMakeLists.txt index 4f338a23044..7936528d610 100644 --- a/examples/qualcomm/CMakeLists.txt +++ b/examples/qualcomm/CMakeLists.txt @@ -93,6 +93,9 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner) # build qnn_llama_runner for llama add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama) +# build qnn_mimi_decoder_runner 
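+# (target sources and link settings live in oss_scripts/moshi/CMakeLists.txt)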
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/moshi)
+
 # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
diff --git a/examples/qualcomm/oss_scripts/moshi/CMakeLists.txt b/examples/qualcomm/oss_scripts/moshi/CMakeLists.txt
new file mode 100644
index 00000000000..70356e54906
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/moshi/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set(_qnn_mimi_decoder_runner__srcs
+    ${CMAKE_CURRENT_LIST_DIR}/qnn_mimi_decoder_runner.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
+)
+
+# build mimi decoder runner
+add_executable(qnn_mimi_decoder_runner ${_qnn_mimi_decoder_runner__srcs})
+target_include_directories(
+  qnn_mimi_decoder_runner PUBLIC ${_common_include_directories}
+)
+target_link_libraries(
+  qnn_mimi_decoder_runner
+  qnn_executorch_backend
+  executorch_core
+  extension_module
+  extension_data_loader
+  extension_flat_tensor
+  gflags
+)
+
+target_compile_options(
+  qnn_mimi_decoder_runner PUBLIC ${_common_compile_options}
+)
+
+set_target_properties(
+  qnn_mimi_decoder_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+)
diff --git a/examples/qualcomm/oss_scripts/moshi/mimi.py b/examples/qualcomm/oss_scripts/moshi/mimi.py
index 70e339a32d6..1dba9bc8da1 100644
--- a/examples/qualcomm/oss_scripts/moshi/mimi.py
+++ b/examples/qualcomm/oss_scripts/moshi/mimi.py
@@ -4,16 +4,16 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-# import argparse import io import json +import logging import os -import random from multiprocessing.connection import Client +import moshi + import numpy as np import requests - import sphn import torch @@ -24,7 +24,18 @@ annotate_mimi_decoder, ) from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype - +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + get_soc_to_chipset_map, + to_edge_transform_and_lower_to_qnn, +) +from executorch.examples.qualcomm.oss_scripts.moshi.model.static_mimi import ( + _transformer_kwargs, + DEFAULT_REPO, + get_static_mimi, + MIMI_NAME, +) from executorch.examples.qualcomm.utils import ( build_executorch_binary, make_output_dir, @@ -33,22 +44,20 @@ setup_common_args_and_variables, SimpleADB, ) +from executorch.exir.capture._config import ExecutorchBackendConfig +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass from huggingface_hub import hf_hub_download from moshi.models import loaders from torchao.quantization.pt2e import MinMaxObserver +from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) +logging.getLogger().setLevel(logging.INFO) -def seed_all(seed): - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) # for multi-GPU setups - random.seed(seed) - np.random.seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False +MOSHI_VERSION = "0.2.3" def read_mp3_from_url(url): @@ -64,98 +73,87 @@ def read_mp3_from_url(url): return waveform.numpy(), sample_rate -def compute_scores(cpu_decode_res: torch.Tensor, htp_decode_res: torch.Tensor): - assert cpu_decode_res.shape == htp_decode_res.shape, "Tensor shapes do not match" - abs_diff = torch.abs(cpu_decode_res - htp_decode_res) +def compute_scores(cpu_decode_res: torch.Tensor, qnn_decode_res: torch.Tensor): + assert cpu_decode_res.shape == qnn_decode_res.shape, "Tensor shapes do not match" + abs_diff = torch.abs(cpu_decode_res - qnn_decode_res) atol = torch.max(abs_diff) - print("Atol: ", atol) + logging.info("Atol: {:.3f}".format(atol)) cpu_decode_res = cpu_decode_res.float() - htp_decode_res = htp_decode_res.float() - error = cpu_decode_res - htp_decode_res + qnn_decode_res = qnn_decode_res.float() + error = cpu_decode_res - qnn_decode_res original_power = torch.mean(torch.pow(cpu_decode_res, 2)) error_power = torch.mean(torch.pow(error, 2)) sqnr = 10 * torch.log10(original_power / error_power) - print("SQNR: ", sqnr) + logging.info("SQNR: {:.3f}".format(sqnr)) -def test_decoder_with_emb_input(mimi, args): - class MimiDecode(nn.Module): - def __init__(self, mimi: nn.Module): - super().__init__() - self.mimi_model = mimi +def init_inputs(): + num_layers = _transformer_kwargs["num_layers"] + batch_size = 1 # 1 chunk per batch + num_heads = _transformer_kwargs["num_heads"] + context = _transformer_kwargs["context"] + head_dim = _transformer_kwargs["d_model"] // _transformer_kwargs["num_heads"] - def forward(self, x): - x = x.transpose(1, 2) - x = self.mimi_model.upsample(x) - (emb,) = self.mimi_model.decoder_transformer(x) - emb.transpose(1, 2) - with self.mimi_model._context_for_encoder_decoder: - out = self.mimi_model.decoder(emb) - return out - - emb_input = torch.rand(1, 1, 512, device="cpu") - mimi_decode = MimiDecode(mimi).eval() - 
cpu_res = mimi_decode(emb_input) - pte_filename = "mimi_decoder_emb_qnn" - - quantizer = make_quantizer( - quant_dtype=QuantDtype.use_16a8w, - per_channel_conv=True, - per_channel_linear=True, - act_observer=MinMaxObserver, + k_cache = torch.zeros( + (num_layers, batch_size, num_heads, context, head_dim), + device="cpu", + dtype=torch.float32, ) - quantizer.add_custom_quant_annotations((annotate_mimi_decoder,)) - - emb_inputs = [(emb_input,)] - build_executorch_binary( - mimi_decode, - emb_inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - emb_inputs, - custom_quantizer=quantizer, - quant_dtype=QuantDtype.use_16a8w, - shared_buffer=args.shared_buffer, + v_cache = torch.zeros( + (num_layers, batch_size, num_heads, context, head_dim), + device="cpu", + dtype=torch.float32, ) - adb = SimpleADB( - qnn_sdk=os.getenv("QNN_SDK_ROOT"), - build_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", - device_id=args.device, - host_id=args.host, - soc_model=args.model, - shared_buffer=args.shared_buffer, + # end_index will perform end_index % context + # end_offset will keep increment even after end_offset > context + end_index = torch.zeros((num_layers, 1), device="cpu", dtype=torch.long) + end_offset = torch.zeros((num_layers, 1), device="cpu", dtype=torch.long) + + # partial for transpose conv layers, please refer to StaticRawStreamingConvTranspose1d under static_convtr.py for more info + # There are total of 5 transpose conv in mimi decoder, 1 is under to_encoder_framerate, and 4 is under SeanetDecoder + partial_convtr_0 = torch.zeros((1, 512, 2), device="cpu", dtype=torch.float32) + partial_convtr_1 = torch.zeros((1, 512, 8), device="cpu", dtype=torch.float32) + partial_convtr_2 = torch.zeros((1, 256, 6), device="cpu", dtype=torch.float32) + partial_convtr_3 = torch.zeros((1, 128, 5), device="cpu", dtype=torch.float32) + partial_convtr_4 = torch.zeros((1, 64, 4), device="cpu", dtype=torch.float32) + + # Some index for naming are skipped on purpose as those conv_layers have empty previous + previous_conv_0 = torch.zeros((1, 512, 6), device="cpu", dtype=torch.float32) + previous_conv_1 = torch.zeros((1, 512, 2), device="cpu", dtype=torch.float32) + previous_conv_3 = torch.zeros((1, 256, 2), device="cpu", dtype=torch.float32) + previous_conv_5 = torch.zeros((1, 128, 2), device="cpu", dtype=torch.float32) + previous_conv_7 = torch.zeros((1, 64, 2), device="cpu", dtype=torch.float32) + previous_conv_9 = torch.zeros((1, 64, 2), device="cpu", dtype=torch.float32) + + return ( + k_cache, + v_cache, + end_index, + end_offset, + partial_convtr_0, + partial_convtr_1, + partial_convtr_2, + partial_convtr_3, + partial_convtr_4, + previous_conv_0, + previous_conv_1, + previous_conv_3, + previous_conv_5, + previous_conv_7, + previous_conv_9, ) - adb.push(inputs=emb_inputs, input_list="input_0_0.raw\n") - adb.execute() - # collect output data - output_data_folder = f"{args.artifact}/outputs" - make_output_dir(output_data_folder) - - adb.pull(output_path=args.artifact) - - emb_predictions = [] - for i in range(len(emb_inputs)): - np_arr = np.fromfile( - os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 - ) - emb_predictions.append(torch.from_numpy(np_arr).view(1, 1, 1920)) - print("Emb input test results") - compute_scores(cpu_res, emb_predictions[0]) - -def mimi_encode( - mimi, - encode_inputs, - encoder_input_list, - pcm_chunk_size, +def compile_mimi_encoder( + args, + orig_mimi, + 
encoder_inputs, skip_node_id_set, skip_node_op_set, -) -> torch.Tensor: + encoder_pte_filename, +): class MimiEncode(nn.Module): def __init__(self, mimi: nn.Module): super().__init__() @@ -164,32 +162,34 @@ def __init__(self, mimi: nn.Module): def forward(self, x): return self.mimi_model.encode(x) - mimi_encode_model = MimiEncode(mimi) - - pte_filename = "mimi_encoder_qnn" + mimi_encoder_model = MimiEncode(orig_mimi) build_executorch_binary( - mimi_encode_model.eval(), - encode_inputs[0], + mimi_encoder_model.eval(), + encoder_inputs[0], args.model, - f"{args.artifact}/{pte_filename}", - encode_inputs, + f"{args.artifact}/{encoder_pte_filename}", + encoder_inputs, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, quant_dtype=QuantDtype.use_8a8w, shared_buffer=args.shared_buffer, ) + +def inference_mimi_encoder( + args, encoder_inputs, encoder_input_list, encoder_pte_filename +): adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), build_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", + pte_path=f"{args.artifact}/{encoder_pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{encoder_pte_filename}", device_id=args.device, host_id=args.host, soc_model=args.model, shared_buffer=args.shared_buffer, ) - adb.push(inputs=encode_inputs, input_list=encoder_input_list) + adb.push(inputs=encoder_inputs, input_list=encoder_input_list) adb.execute() # collect output data @@ -199,35 +199,89 @@ def forward(self, x): adb.pull(output_path=args.artifact) encoder_predictions = [] - # Num chunks should align with args.chunks_per_batch - num_chunks = encode_inputs[0][0].shape[-1] // pcm_chunk_size - for i in range(len(encode_inputs)): + for i in range(len(encoder_inputs)): np_arr = np.fromfile( os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.int64 ) - encoder_predictions.append(torch.from_numpy(np_arr).view(1, 8, num_chunks)) + encoder_predictions.append(torch.from_numpy(np_arr).view(1, 8, 1)) return encoder_predictions -def mimi_decode( - mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set -) -> torch.Tensor: - class MimiDecode(nn.Module): - def __init__(self, mimi: nn.Module): - super().__init__() - self.mimi_model = mimi +def export_mimi_encoder( + args, orig_mimi, sample_pcm, pcm_chunk_size, skip_node_id_set, skip_node_op_set +): + encoder_inputs, encoder_input_list = [], "" + count = 0 + cpu_encoded_results = [] + logging.info("streaming encoding...") + for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size): + end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size) + chunk = sample_pcm[..., start_idx:end_idx] + # Preparing QNN inputs + encoder_inputs.append((chunk,)) + encoder_input_list += f"input_{count}_0.raw\n" + count += 1 + # Performing cpu encoding for golden + codes = orig_mimi.encode(chunk) + if codes.shape[-1]: + cpu_encoded_results.append(codes) + + encoder_pte_filename = "mimi_encoder_qnn" + if args.use_cpu_encoder: + logging.info("Using CPU Encoder, Skip Compile and Inference for QNN Encoder") + elif args.compile_only: + logging.info("Compile only for QNN Encoder") + compile_mimi_encoder( + args, + orig_mimi, + encoder_inputs, + skip_node_id_set, + skip_node_op_set, + encoder_pte_filename, + ) + elif args.pre_gen_pte: + logging.info("Inference only for QNN Encoder") + qnn_encoded_results = inference_mimi_encoder( + args, + encoder_inputs, + encoder_input_list, + encoder_pte_filename, + ) + else: + logging.info("Compile and 
Inference for QNN Encoder") + compile_mimi_encoder( + args, + orig_mimi, + encoder_inputs, + skip_node_id_set, + skip_node_op_set, + encoder_pte_filename, + ) + qnn_encoded_results = inference_mimi_encoder( + args, + encoder_inputs, + encoder_input_list, + encoder_pte_filename, + ) - def forward(self, x): - return self.mimi_model.decode(x) + encoded_results = ( + cpu_encoded_results + if (args.use_cpu_encoder or args.compile_only) + else qnn_encoded_results + ) - mimi_decode_model = MimiDecode(mimi) - decode_inputs, decode_input_list = [], "" - for index, encoder_res in enumerate(encode_res_list): - decode_inputs.append((encoder_res.to(torch.int32),)) - decode_input_list += f"input_{index}_0.raw\n" + # These 2 returned values will be same if use cpu_encoder instead of QNN encoder. + return encoded_results, cpu_encoded_results - pte_filename = "mimi_decoder_qnn" +def compile_static_mimi_decoder( + args, + static_mimi_decoder, + encoded_results, + skip_node_id_set, + skip_node_op_set, + static_decoder_pte_filename, +): quantizer = make_quantizer( quant_dtype=QuantDtype.use_16a8w, per_channel_conv=True, @@ -236,31 +290,85 @@ def forward(self, x): ) quantizer.add_custom_quant_annotations((annotate_mimi_decoder,)) - build_executorch_binary( - mimi_decode_model.eval(), - decode_inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - decode_inputs, + static_states = init_inputs() + + with torch.no_grad(): + static_mimi_decoder(encoded_results[0], *static_states) + + fx_graph_module = torch.export.export( + static_mimi_decoder, + ( + encoded_results[0], + *static_states, + ), + strict=False, + ).module() + + annotated_model = prepare_pt2e(fx_graph_module, quantizer) + logging.info("Quantizing the model...") + for codes in encoded_results: + _out, *static_states = annotated_model(codes, *static_states) + quantized_model = convert_pt2e(annotated_model) + + backend_options = generate_htp_compiler_spec(use_fp16=False) + compiler_spec = generate_qnn_executorch_compiler_spec( + soc_model=get_soc_to_chipset_map()[args.model], + backend_options=backend_options, + ) + + edge_prog_mgr = to_edge_transform_and_lower_to_qnn( + quantized_model, + ( + encoded_results[0], + *static_states, + ), + compiler_spec, skip_node_id_set=skip_node_id_set, skip_node_op_set=skip_node_op_set, - custom_quantizer=quantizer, - quant_dtype=QuantDtype.use_16a8w, - shared_buffer=args.shared_buffer, + ) + + executorch_config = ExecutorchBackendConfig( + memory_planning_pass=MemoryPlanningPass( + alloc_graph_input=False, + alloc_graph_output=False, + ), + ) + exec_prog_mgr = edge_prog_mgr.to_executorch(config=executorch_config) + with open(f"{args.artifact}/{static_decoder_pte_filename}.pte", "wb") as file: + exec_prog_mgr.write_to_file(file) + + +def inference_static_mimi_decoder( + args, + encoded_results, + encoded_results_list, + pcm_chunk_size, + static_decoder_pte_filename, +): + workspace = f"/data/local/tmp/executorch/{static_decoder_pte_filename}" + pte_path = f"{args.artifact}/{static_decoder_pte_filename}.pte" + runner_cmd = " ".join( + [ + f"cd {workspace} &&", + "./qnn_mimi_decoder_runner", + f"--model_path {workspace}/{static_decoder_pte_filename}.pte", + f"--output_folder_path {workspace}/outputs", + ] ) adb = SimpleADB( qnn_sdk=os.getenv("QNN_SDK_ROOT"), build_path=f"{args.build_folder}", - pte_path=f"{args.artifact}/{pte_filename}.pte", - workspace=f"/data/local/tmp/executorch/{pte_filename}", + pte_path=pte_path, + workspace=workspace, device_id=args.device, host_id=args.host, soc_model=args.model, 
shared_buffer=args.shared_buffer, + runner="examples/qualcomm/oss_scripts/moshi/qnn_mimi_decoder_runner", ) - adb.push(inputs=decode_inputs, input_list=decode_input_list) - adb.execute() + adb.push(inputs=encoded_results, input_list=encoded_results_list) + adb.execute(custom_runner_cmd=runner_cmd) # collect output data output_data_folder = f"{args.artifact}/outputs" @@ -268,102 +376,140 @@ def forward(self, x): adb.pull(output_path=args.artifact) - decoder_predictions = [] - # Num chunks should align with args.chunks_per_batch - num_chunks = decode_inputs[0][0].shape[-1] + num_chunks = len(encoded_results) shape = num_chunks * pcm_chunk_size - for i in range(len(decode_inputs)): - np_arr = np.fromfile( - os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + qnn_decode_res = torch.from_numpy( + np.fromfile( + os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32 ) - decoder_predictions.append(torch.from_numpy(np_arr).view(1, 1, shape)) - htp_decode_res = torch.cat(decoder_predictions, dim=-1) + ).view(1, 1, shape) + + return qnn_decode_res + + +def export_mimi_decoder( + args, + static_mimi_decoder, + encoded_results, + pcm_chunk_size, + skip_node_id_set, + skip_node_op_set, +): + encoded_results_list = "" + for index, encoder_result in enumerate(encoded_results): + encoded_results[index] = encoder_result.to(torch.int32) + encoded_results_list += f"input_{index}_0.raw\n" + + logging.info("streaming decoding...") + qnn_decode_res = None + static_decoder_pte_filename = "static_mimi_decoder_qnn" + with static_mimi_decoder.streaming(1): + if args.compile_only: + logging.info("Compile only for QNN Static Decoder") + compile_static_mimi_decoder( + args, + static_mimi_decoder, + encoded_results, + skip_node_id_set, + skip_node_op_set, + static_decoder_pte_filename, + ) + elif args.pre_gen_pte: + logging.info("Inference only for QNN Static Decoder") + qnn_decode_res = inference_static_mimi_decoder( + args, + encoded_results, + encoded_results_list, + pcm_chunk_size, + static_decoder_pte_filename, + ) + else: + logging.info("Compile and Inference for QNN Static Decoder") + compile_static_mimi_decoder( + args, + static_mimi_decoder, + encoded_results, + skip_node_id_set, + skip_node_op_set, + static_decoder_pte_filename, + ) + qnn_decode_res = inference_static_mimi_decoder( + args, + encoded_results, + encoded_results_list, + pcm_chunk_size, + static_decoder_pte_filename, + ) + return qnn_decode_res - return htp_decode_res +def main(args): + assert ( + moshi.__version__ == MOSHI_VERSION + ), f"Please ensure Moshi version == {MOSHI_VERSION}, current version is {moshi.__version__}" + + if args.compile_only and args.pre_gen_pte: + exit("Cannot set both compile_only and pre_gen_pte as true") + + logging.info("loading mimi") + if args.mimi_weight is None: + args.mimi_weight = hf_hub_download(args.hf_repo, MIMI_NAME) + orig_mimi = loaders.get_mimi(args.mimi_weight, "cpu") # For encoder + static_mimi = get_static_mimi(args.mimi_weight, "cpu") # For static decoder + logging.info("mimi loaded") -def export_mimi(mimi, args, max_duration_sec=10.0): skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) os.makedirs(args.artifact, exist_ok=True) - if args.emb_input_test: - test_decoder_with_emb_input(mimi, args) - return - - sample_rate = mimi.sample_rate + sample_rate = orig_mimi.sample_rate url = "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3" sample_pcm, sample_sr = read_mp3_from_url(url) - sample_rate = mimi.sample_rate + sample_rate = 
orig_mimi.sample_rate sample_pcm = torch.tensor(sample_pcm, device="cpu") - max_duration_len = int(sample_rate * max_duration_sec) + max_duration_len = int(sample_rate * args.max_duration_sec) if sample_pcm.shape[-1] > max_duration_len: sample_pcm = sample_pcm[..., :max_duration_len] sample_pcm = sample_pcm[None].to(device="cpu") - - encoder_inputs, encoder_input_list = [], "" # 1920 chunk_size = 0.08sec - pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate) - batch_size = pcm_chunk_size * args.chunks_per_batch - count = 0 - for start_idx in range(0, sample_pcm.shape[-1], batch_size): - end_idx = min(sample_pcm.shape[-1], start_idx + batch_size) - chunk = sample_pcm[..., start_idx:end_idx] - encoder_inputs.append((chunk,)) - encoder_input_list += f"input_{count}_0.raw\n" - count += 1 - - print("streaming encoding...") - cpu_encode_res = mimi.encode(sample_pcm) - htp_encode_res = mimi_encode( - mimi, - encoder_inputs, - encoder_input_list, - pcm_chunk_size, - skip_node_id_set, - skip_node_op_set, - ) - - # Leave it here for now, uncomment this to check htp_encoder with cpu_decoder - # htp_res = torch.cat(htp_encode_res, dim=-1) - # cpu_decode_htp_encode = mimi.decode(htp_res) - # sphn.write_wav("cpu_decode_htp_encode.wav", cpu_decode_htp_encode[0, 0].cpu().numpy(), sample_rate) - - print("streaming decoding...") - cpu_decode_res = mimi.decode(cpu_encode_res) - # TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time. - # with mimi.streaming(1): - htp_decode_res = mimi_decode( - mimi, htp_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set - ) - compute_scores(cpu_decode_res, htp_decode_res) - - sphn.write_wav( - f"{args.artifact}/cpu_decode_res.wav", - cpu_decode_res[0, 0].cpu().numpy(), - sample_rate, - ) - sphn.write_wav( - f"{args.artifact}/htp_decode_res.wav", - htp_decode_res[0, 0].cpu().numpy(), - sample_rate, - ) + pcm_chunk_size = int(orig_mimi.sample_rate / orig_mimi.frame_rate) + qnn_decode_res = None + with torch.no_grad(): + encoded_results, cpu_encoded_results = export_mimi_encoder( + args, + orig_mimi, + sample_pcm, + pcm_chunk_size, + skip_node_id_set, + skip_node_op_set, + ) + qnn_decode_res = export_mimi_decoder( + args, + static_mimi, + encoded_results, + pcm_chunk_size, + skip_node_id_set, + skip_node_op_set, + ) -def main(args): - seed_all(42424242) + if args.compile_only: + exit(f"Finish compile_only and saved to {args.artifact}") - print("loading mimi") - if args.mimi_weight is None: - args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) - mimi = loaders.get_mimi(args.mimi_weight, "cpu") - print("mimi loaded") + pcm_ref = orig_mimi.decode(torch.cat(cpu_encoded_results, dim=-1)) + logging.info("PCM ref V.S. QNN streaming mode") + compute_scores(pcm_ref, qnn_decode_res) - with torch.no_grad(): - export_mimi(mimi, args) + sphn.write_wav( + f"{args.artifact}/pcm_ref.wav", pcm_ref[0, 0].cpu().numpy(), sample_rate + ) + sphn.write_wav( + f"{args.artifact}/qnn_decode_res.wav", + qnn_decode_res[0, 0].cpu().numpy(), + sample_rate, + ) if __name__ == "__main__": - parser = setup_common_args_and_variables() parser.add_argument( @@ -375,21 +521,27 @@ def main(args): ) parser.add_argument( - "--chunks_per_batch", - help="Number of chunks to process per time. 
Default is 1 chunk per batch, which equals to 0.08 second", - default=1, - type=int, + "--max_duration_sec", + help="Max duration seconds for the audio to be processed.", + type=float, + default=10.0, + ) + + parser.add_argument( + "--pre_gen_pte", + help="Run the pre-generated mimi encoder/decoder in the given directory.", + type=str, ) parser.add_argument( - "--emb_input_test", - help="This is just a metrics used to compute accuracy scores, not recommended for general users.", + "--use_cpu_encoder", + help="Enable this flag to perform encoder with cpu.", action="store_true", default=False, ) parser.add_argument("--mimi-weight", type=str) - parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO) + parser.add_argument("--hf-repo", type=str, default=DEFAULT_REPO) args = parser.parse_args() try: diff --git a/examples/qualcomm/oss_scripts/moshi/model/static_conv.py b/examples/qualcomm/oss_scripts/moshi/model/static_conv.py new file mode 100644 index 00000000000..c74c8b73422 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/model/static_conv.py @@ -0,0 +1,145 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import math +import typing as tp +import warnings + +import torch +from moshi.modules.conv import apply_parametrization_norm, NormConv1d, StreamingConv1d +from moshi.modules.streaming import RawStreamingConv1d + + +class StaticRawStreamingConv1d(RawStreamingConv1d): + def __init__(self, ignore_previous: bool = False, *args, **kwargs): + super(RawStreamingConv1d, self).__init__(*args, **kwargs) + assert self.padding[0] == 0, "Padding should be handled outside." + assert ( + self.stride[0] <= self.kernel_size[0] + ), "stride must be less than kernel_size." + + self.ignore_previous = ignore_previous + + # Static Mimi Changes + # 1) If ignore_previous, return only output but not previous since it is an empty tensor. + # Refer to StaticStreamingConv1d's forward() comments for more detail + # 2) Remove all states related variables + # 3) Create previous tensor ahead of time, shape should be constant throughout execution. + def forward(self, input: torch.Tensor, previous: torch.Tensor = None): + stride = self.stride[0] + # Effective kernel size accounting for dilation. + kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1 + + if not self.ignore_previous: + input = torch.cat([previous, input], dim=-1) + B, C, T = input.shape + # We now compute the number of full convolution frames, i.e. the frames + # that are ready to be computed. + num_frames = max(0, int(math.floor((T - kernel) / stride) + 1)) + offset = num_frames * stride + # We will compute `num_frames` outputs, and we are advancing by `stride` + # for each of the frame, so we know the data before `stride * num_frames` + # will never be used again. + if num_frames > 0: + out = super(RawStreamingConv1d, self).forward(input) + else: + # Not enough data as this point to output some new frames. 
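+            # Returning an empty (B, out_channels, 0) tensor keeps the rank and
+            # dtype of the output consistent for callers.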
+ out = torch.empty( + B, self.out_channels, 0, device=input.device, dtype=input.dtype + ) + if self.ignore_previous: + return out + else: + previous = input[..., offset:] + return out, previous + + +class StaticNormConv1d(NormConv1d): + """Wrapper around Conv1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. + """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, # noqa: B006 + ignore_previous: bool = False, + **kwargs, + ): + super(NormConv1d, self).__init__() + self.conv = apply_parametrization_norm( + StaticRawStreamingConv1d(ignore_previous, *args, **kwargs), norm + ) + self.norm_type = norm + self.ignore_previous = ignore_previous + + def forward(self, x, previous=None): + if self.ignore_previous: + return self.conv(x) + else: + return self.conv(x, previous) + + +class StaticStreamingConv1d(StreamingConv1d): + """Conv1d with some builtin handling of asymmetric or causal padding + and normalization. + """ + + # Static Mimi Change + # 1) Add ignore_previous variable. Some previous tensor has shape such as (1, 512, 0), which is an empty tensor. + # These shapes does not really makes sense in QNN, so just ignore for conv_layers that has empty tensor. + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, # noqa: B006 + pad_mode: str = "reflect", + ignore_previous: bool = False, + ): + super(StreamingConv1d, self).__init__() + # warn user on unusual setup between dilation and stride + if stride > 1 and dilation > 1: + warnings.warn( # noqa: B028 + "StreamingConv1d has been initialized with stride > 1 and dilation > 1" + f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})." + ) + self.conv = StaticNormConv1d( + in_channels, + out_channels, + kernel_size, + stride, + dilation=dilation, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ignore_previous=ignore_previous, + ) + self.causal = causal + self.pad_mode = pad_mode + self.ignore_previous = ignore_previous + + # Static Mimi Changes + # 1) Remove all non streaming mode logic. + # 2) Always perform padding since we need constant output shape + def forward(self, x, previous=None): + if self.ignore_previous: + return self.conv(x) + else: + return self.conv(x, previous) diff --git a/examples/qualcomm/oss_scripts/moshi/model/static_convtr.py b/examples/qualcomm/oss_scripts/moshi/model/static_convtr.py new file mode 100644 index 00000000000..dc97794c8e2 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/model/static_convtr.py @@ -0,0 +1,175 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import typing as tp + +import torch +from moshi.modules.conv import ( + apply_parametrization_norm, + NormConvTranspose1d, + StreamingConvTranspose1d, +) +from moshi.modules.resample import ConvTrUpsample1d +from moshi.modules.streaming import RawStreamingConvTranspose1d + + +class StaticRawStreamingConvTranspose1d(RawStreamingConvTranspose1d): + def __init__(self, *args, **kwargs): + super(RawStreamingConvTranspose1d, self).__init__(*args, **kwargs) + assert self.padding[0] == 0, "Padding should be handled outside." + assert self.dilation[0] == 1, "No dilation for now" + assert ( + self.stride[0] <= self.kernel_size[0] + ), "stride must be less than kernel_size." + assert self.output_padding[0] == 0, "Output padding not supported." + + # Static Mimi Changes: + # 1) Remove non streaming mode logic + # 2) Remove all states related variables + # 3) Create partial tensor ahead of time, shape should be constant throughout execution. + def forward(self, x: torch.Tensor, partial: torch.Tensor): # type: ignore + # Batch, Channel, Temp_Dimension + B, C, T = x.shape + stride = self.stride[0] + kernel = self.kernel_size[0] + if T == 0: + return torch.empty(B, self.out_channels, 0, device=x.device, dtype=x.dtype) + out = super(RawStreamingConvTranspose1d, self).forward(x) + + OT = out.shape[-1] + # Due to the potential overlap, the rightmost output of the conv transpose is not + # ready to be output, as it will receive contributions from the next input frames. + # Here we recover those `partial` output frames. We know that the first time step + # of the `partial` tensor corresponds to the first time step of `out` as anything + # coming before the first time step of `out` would have been already flushed. + PT = partial.shape[-1] + if self.bias is not None: + condition = torch.all(partial != 0) + updated_part = torch.where( + condition, (partial - self.bias[:, None]), partial + ) + out = torch.cat((out[..., :PT] + updated_part, out[..., PT:]), dim=-1) + else: + updated_part = out[..., :PT] + partial + out = torch.cat((updated_part, out[..., PT:]), dim=-1) + + # The input is T, the output is S * (T - 1) + K. + # The offset of the left of the next frame will be S * T + # so everything between 0 and S * T is ready to be output, and we need + # to keep in the internal state everything beyond that, i.e. S (T - 1) + K - S T = K - S + invalid_steps = kernel - stride + + start_idx = OT - invalid_steps + + partial = out[..., start_idx:] + out = out[..., :start_idx] + return out, partial + + +class StaticNormConvTranspose1d(NormConvTranspose1d): + """Wrapper around ConvTranspose1d and normalization applied to this conv + to provide a uniform interface across normalization approaches. 
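+
+    The `partial` overlap state is threaded through forward() instead of being
+    stored on the module.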
+ """ + + def __init__( + self, + *args, + causal: bool = False, + norm: str = "none", + norm_kwargs: tp.Dict[str, tp.Any] = {}, # noqa: B006 + **kwargs, + ): + super(NormConvTranspose1d, self).__init__() + self.convtr = apply_parametrization_norm( + StaticRawStreamingConvTranspose1d(*args, **kwargs), norm + ) + self.norm_type = norm + + def forward(self, x, partial): + return self.convtr(x, partial) + + +class StaticStreamingConvTranspose1d(StreamingConvTranspose1d): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + groups: int = 1, + bias: bool = True, + causal: bool = False, + norm: str = "none", + trim_right_ratio: float = 1.0, + norm_kwargs: tp.Dict[str, tp.Any] = {}, # noqa: B006 + ): + super(StreamingConvTranspose1d, self).__init__() + self.convtr = StaticNormConvTranspose1d( + in_channels, + out_channels, + kernel_size, + stride, + groups=groups, + bias=bias, + causal=causal, + norm=norm, + norm_kwargs=norm_kwargs, + ) + self.causal = causal + self.trim_right_ratio = trim_right_ratio + assert ( + self.causal or self.trim_right_ratio == 1.0 + ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions" + assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0 + + # Static Mimi Changes: + # 1) Remove non streaming mode logic + def forward(self, x, partial): + return self.convtr(x, partial) + + +class StaticConvTrUpsample1d(ConvTrUpsample1d): + """ + Upsample by some integer amount `stride` using transposed convolutions. + """ + + def __init__( + self, + stride: int, + dimension: tp.Optional[int] = None, + causal: bool = False, + learnt: bool = False, + channel_wise: bool = False, + ): + super(ConvTrUpsample1d, self).__init__() + self.learnt = learnt + self.channel_wise = channel_wise + groups = 1 + + assert dimension is not None, "Dimension required for learnt convolutions." + in_channels = dimension + out_channels = dimension + if channel_wise: + groups = dimension + + self.convtr = StaticStreamingConvTranspose1d( + in_channels, + out_channels, + kernel_size=2 * stride, + stride=stride, + causal=causal, + groups=groups, + bias=False, + ) + + # Static Mimi Changes: + # 1) Remove not self.learnt logic since it doesn't go in. + def forward(self, x: torch.Tensor, partial: torch.Tensor): + return self.convtr(x, partial) diff --git a/examples/qualcomm/oss_scripts/moshi/model/static_mimi.py b/examples/qualcomm/oss_scripts/moshi/model/static_mimi.py new file mode 100644 index 00000000000..a887fc40878 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/model/static_mimi.py @@ -0,0 +1,679 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+"""Retrieves the pretrained models for Moshi and Mimi.""" +from dataclasses import dataclass +from pathlib import Path + +import torch.nn as nn + +try: + from huggingface_hub.errors import EntryNotFoundError +except ImportError: + from huggingface_hub.utils import EntryNotFoundError # noqa: F401 +import typing as tp +from contextlib import ExitStack + +import torch +from einops import rearrange +from executorch.examples.qualcomm.oss_scripts.moshi.model.static_convtr import ( + StaticConvTrUpsample1d, +) +from executorch.examples.qualcomm.oss_scripts.moshi.model.static_seanet_decoder import ( + StaticSEANetDecoder, +) + +from moshi.models.compression import MimiModel +from moshi.models.loaders import ( + _is_safetensors, + _quantizer_kwargs, + _seanet_kwargs, + _transformer_kwargs, +) +from moshi.modules import SEANetEncoder, transformer +from moshi.modules.resample import ConvDownsample1d +from moshi.modules.rope import RotaryEmbedding +from moshi.modules.streaming import State, StreamingModule +from moshi.modules.transformer import ( + create_norm_fn, + KVCacheResult, + LayerScale, + ProjectedTransformer, + RingKVCache, + StreamingTransformer, + StreamingTransformerLayer, +) +from moshi.quantization import BaseQuantizer, SplitResidualVectorQuantizer +from moshi.utils import quantize +from moshi.utils.compile import no_compile +from safetensors.torch import load_model +from torch.nn import functional as F + +SAMPLE_RATE = 24000 +FRAME_RATE = 12.5 + +TEXT_TOKENIZER_NAME = "tokenizer_spm_32k_3.model" +MOSHI_NAME = "model.safetensors" +MOSHI_Q8_NAME = "model.q8.safetensors" +MIMI_NAME = "tokenizer-e351c8d8-checkpoint125.safetensors" +DEFAULT_REPO = "kyutai/moshiko-pytorch-bf16" + + +class StaticRingKVCache(RingKVCache): + # Static Mimi Changes: + # 1) Remove all inplace kv_cache & index updates, perform non inplace updates and return updated output as next execution's input + # 2) Use end_offset to keep track of nth iteration * 2. Different from end_index, this number does not reset even after end_index resets + def complete( + self, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + ) -> KVCacheResult: + assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape) + B, H, T, D = k.shape + + end_index = torch.where( + end_index >= self.capacity, end_index - self.capacity, end_index + ) + + assert T > 0 + + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + k_cache = k_cache.permute(0, 2, 1, 3) + v_cache = v_cache.permute(0, 2, 1, 3) + k_0, k_1 = torch.split(k, 1, dim=1) + v_0, v_1 = torch.split(v, 1, dim=1) + + index_0 = torch.tensor([0], dtype=torch.int64) + end_index + index_1 = torch.tensor([1], dtype=torch.int64) + end_index + + k_cache = torch.ops.aten.index_put(k_cache, [None, index_0], k_0) + k_cache = torch.ops.aten.index_put(k_cache, [None, index_1], k_1) + + v_cache = torch.ops.aten.index_put(v_cache, [None, index_0], v_0) + v_cache = torch.ops.aten.index_put(v_cache, [None, index_1], v_1) + k_cache = k_cache.permute(0, 2, 1, 3) + v_cache = v_cache.permute(0, 2, 1, 3) + + indexes = torch.arange( + self.capacity, device=end_offset.device, dtype=torch.long + ) + + # end_index correspond to the actual index where the last value was written. + offset = T - 1 + last_offset = end_offset + offset + delta = indexes - (end_index + offset) + # We know that if `index == end_index`, then we should output `self.end_offset`. 
+ # If `index = end_index - 1` we should output `self.end_offset - 1` + # If `index = end_index - n` we should output `self.end_offset - n` + # Now, for `index == end_index + 1` , we actually have the oldest entry in the cache, + # so we should output `end_index + 1 - self.capacity` + positions = torch.where( + delta <= 0, + last_offset + delta, + last_offset + delta - self.capacity, + ) + end_offset = end_offset.add(T) + end_index = end_index.add(T) + invalid = indexes >= end_offset + positions = torch.where(invalid, torch.full_like(positions, -1), positions) + return ( + KVCacheResult(k_cache, v_cache, positions), + k_cache, + v_cache, + end_index, + end_offset, + ) + + +@dataclass +class _StaticMHAState(State): + kv_cache: StaticRingKVCache + offset: torch.Tensor + offset_cpu: int + + def reset(self): + self.kv_cache.reset() + self.offset.zero_() + self.offset_cpu = 0 + + +class StaticStreamingMultiheadAttention(StreamingModule[_StaticMHAState]): + _fsdp_final = True + + def __init__( + self, + embed_dim: int, + num_heads: int, + causal: bool = False, + context: tp.Optional[int] = None, + rope: tp.Optional[RotaryEmbedding] = None, + weights_per_step: int = 0, + weights_per_step_schedule: list[int] | None = None, + device=None, + dtype=None, + ): + super().__init__() + factory_kwargs = {"device": device, "dtype": dtype} + + self.embed_dim = embed_dim + self.causal = causal + self.context = context + self.rope = rope + self.num_heads = num_heads + self.weights_per_step = weights_per_step + self.weights_per_step_schedule = weights_per_step_schedule + + out_dim = embed_dim + out_dim = 3 * embed_dim + mult = 1 + in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs) + # We try to follow the default PyTorch MHA convention, to easily compare results. 
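+        # Exposing the projection weight/bias as plain attributes lets
+        # quantize.linear/multi_linear resolve them by name ("in_proj_weight")
+        # in forward().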
+ self.in_proj_weight = in_proj.weight + self.in_proj_bias = in_proj.bias + self.out_proj = nn.Linear( + embed_dim, mult * embed_dim, bias=False, **factory_kwargs + ) + + def _init_streaming_state(self, batch_size: int) -> _StaticMHAState: + capacity = self.context + device = self.in_proj_weight.device + dtype = self.in_proj_weight.dtype + dim_per_head = self.embed_dim // self.num_heads + self.kv_cache = StaticRingKVCache( + batch_size, self.num_heads, dim_per_head, capacity, device, dtype + ) + return _StaticMHAState( + self.kv_cache, + offset=torch.zeros(1, device=device, dtype=torch.long), + offset_cpu=0, + ) + + def _complete_kv( + self, k, v, k_cache, v_cache, end_index, end_offset + ) -> KVCacheResult: + state = self._streaming_state + # Check here, since we did not override methods used when streaming_state == None + assert state is not None + return self.kv_cache.complete(k, v, k_cache, v_cache, end_index, end_offset) + + # Static Mimi Changes: + # 1) use end_offset to replace state.offset when assigning it to offset variable + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + ): + T = query.shape[1] + assert self.causal, "Streaming only available for causal" + offset = end_offset + + if self.weights_per_step: + projected = quantize.multi_linear( + self.weights_per_step, + self.weights_per_step_schedule, + self, + query, + offset_cpu=0, + name="in_proj_weight", + ) + else: + projected = quantize.linear(self, query, "in_proj_weight") + q, k, v = rearrange( + projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads + ) + + q, k = self.rope(q, k, offset, time_before_heads=False) + kv_cache_result, k_cache, v_cache, end_index, end_offset = self._complete_kv( + k, v, k_cache, v_cache, end_index, end_offset + ) + k, v, pos_k = kv_cache_result + + pos_k = pos_k.view(1, -1) + pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view(-1, 1) + delta = pos_q - pos_k + attn_bias = (pos_k >= 0) & (delta >= 0) + attn_bias = attn_bias & (delta < self.context) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0) + + x = rearrange(x, "b h t d -> b t (h d)") + + x = quantize.linear(self.out_proj, x) + + return x, k_cache, v_cache, end_index, end_offset + + +class StaticStreamingTransformerLayer(StreamingTransformerLayer): + def __init__( + self, + d_model: int, + num_heads: int, + dim_feedforward: int | list[int] = 2048, + causal: bool = False, + context: tp.Optional[int] = None, + rope: tp.Optional[RotaryEmbedding] = None, + norm: str = "layer_norm", + layer_scale: tp.Optional[float] = None, + gating: str = "none", + weights_per_step: int = 0, + weights_per_step_schedule: list[int] | None = None, + activation=F.gelu, + skip_self_attn: bool = False, + device=None, + dtype=None, + ): + # Skip parent class init and call grandparent directly + super(StreamingTransformerLayer, self).__init__() + factory_kwargs = {"device": device, "dtype": dtype} + # Redefine self_attn to our streaming multi-head attention + attn_kwargs: tp.Dict[str, tp.Any] = { + "embed_dim": d_model, + "num_heads": num_heads, + } + if not skip_self_attn: + self.self_attn: ( + StaticStreamingMultiheadAttention + ) = StaticStreamingMultiheadAttention( + causal=causal, + context=context, + rope=rope, + weights_per_step=weights_per_step, + weights_per_step_schedule=weights_per_step_schedule, + **attn_kwargs, # type: ignore + **factory_kwargs, # type: ignore 
+ ) # type: ignore + self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) + self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) + # Redefine feedforward layers to expose bias parameter + self.weights_per_step = weights_per_step + self.weights_per_step_schedule = weights_per_step_schedule + self.gating: tp.Optional[nn.Module] = None + self.linear1: tp.Optional[nn.Module] = None + self.linear2: tp.Optional[nn.Module] = None + self.activation = activation + self.skip_self_attn = skip_self_attn + + assert ( + not weights_per_step + ), "weights_per_step without gating not supported for now." + assert not isinstance( + dim_feedforward, list + ), "List dim_feedforward without gating not supported for now." + self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False, **factory_kwargs) + self.linear2 = nn.Linear(dim_feedforward, d_model, bias=False, **factory_kwargs) + + self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore + self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore + + def _sa_block( + self, + x: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + ): + x_orig = x + x = self.norm1(x) + update, k_cache, v_cache, end_index, end_offset = self.self_attn( + x, x, x, k_cache, v_cache, end_index, end_offset + ) + x = x_orig.to(update) + self.layer_scale_1(update) + return x, k_cache, v_cache, end_index, end_offset + + def forward( + self, + x: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + ): + with ExitStack() as stack: + if x.device.type != "cuda": + stack.enter_context(no_compile()) + x, k_cache, v_cache, end_index, end_offset = self._sa_block( + x, k_cache, v_cache, end_index, end_offset + ) + x = self._ff_block(x) + return x, k_cache, v_cache, end_index, end_offset + + +class StaticStreamingTransformer(StreamingTransformer): + # Static Mimi Changes: + # 1) After static variables are passed in, unbind them and pass them to corresponding transformer layers. + # After function returned, stack all layers back to 1 tensor and return it. + # 2) Remove other positional embeddings logic such as "sin", "sine_rope" since "rope" is used. 
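+    # Note: each stacked state tensor carries one slice per layer along dim 0,
+    # matching the (num_layers, ...) shapes built by init_inputs() in mimi.py.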
+ def forward( + self, + x: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + *args, + **kwargs, + ): + B, T, C = x.shape + + dtype_input = x.dtype + k_cache_list = torch.unbind(k_cache, dim=0) + v_cache_list = torch.unbind(v_cache, dim=0) + end_index_list = torch.unbind(end_index, dim=0) + end_offset_list = torch.unbind(end_offset, dim=0) + new_k_cache = [] + new_v_cache = [] + new_end_index = [] + new_end_offset = [] + for i, layer in enumerate(self.layers): + x, k_cache_res, v_cache_res, end_index_res, end_offset_res = layer( + x, + k_cache_list[i], + v_cache_list[i], + end_index_list[i], + end_offset_list[i], + *args, + **kwargs, + ) + new_k_cache.append(k_cache_res) + new_v_cache.append(v_cache_res) + new_end_index.append(end_index_res) + new_end_offset.append(end_offset_res) + new_k_cache = torch.stack(new_k_cache, dim=0) + new_v_cache = torch.stack(new_v_cache, dim=0) + new_end_index = torch.stack(new_end_index, dim=0) + new_end_offset = torch.stack(new_end_offset, dim=0) + return ( + x.to(dtype_input), + new_k_cache, + new_v_cache, + new_end_index, + new_end_offset, + ) + + +class StaticProjectedTransformer(ProjectedTransformer): + """Transformer with optional projections of the input and output to different dimensions when needed. + Supports multiple outputs. + + Args: + input_dimension (int): dimension of the input. + output_dimensions (tuple[int]): dimensions of the outputs. + d_model (int): inner dimension of the Transformer. + conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`. + Similarly, the output will have the same layout. + """ + + def __init__( + self, + input_dimension: int, + output_dimensions: tp.Tuple[int, ...], + d_model: int, + *, + conv_layout: bool = False, + **kwargs, + ): + super(ProjectedTransformer, self).__init__() + self.transformer = StaticStreamingTransformer(d_model=d_model, **kwargs) + self.input_dimension = input_dimension + self.output_dimensions = output_dimensions + self.conv_layout = conv_layout + self.input_proj = None + if d_model != input_dimension: + self.input_proj = nn.Linear(input_dimension, d_model, bias=False) + + self.output_projs = nn.ModuleList() + for output_dimension in output_dimensions: + if d_model == output_dimension: + self.output_projs.append(nn.Identity()) + else: + self.output_projs.append( + nn.Linear(d_model, output_dimension, bias=False) + ) + + def forward( + self, + x, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + end_index: torch.Tensor, + end_offset: torch.Tensor, + *args, + **kwargs, + ): + if self.conv_layout: + x = x.transpose(1, 2) + if self.input_proj is not None: + x = self.input_proj(x) + z, k_cache, v_cache, end_index, end_offset = self.transformer( + x, k_cache, v_cache, end_index, end_offset, *args, **kwargs + ) + ys = [] + for output_proj in self.output_projs: + y = output_proj(z) + if self.conv_layout: + y = y.transpose(1, 2) + ys.append(y) + return ys, k_cache, v_cache, end_index, end_offset + + +class StaticMimiModel(MimiModel): + """ + Static Mimi Model does not keep track of any states inside the model. + It moves all the state related variables to I/O since lowered model does not support keep track of the states in the model. + Static variables includes indices, offsets, kv cache, partials, and previous. 
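+    All of these are returned by the forward pass and must be fed back in as
+    inputs on the next call.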
+ """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + quantizer: BaseQuantizer, + frame_rate: float, + encoder_frame_rate: float, + sample_rate: int, + channels: int, + causal: bool = False, + encoder_transformer: tp.Optional[nn.Module] = None, + decoder_transformer: tp.Optional[nn.Module] = None, + resample_method: str = "interpolate", + upsample_channel_wise_bug: bool = True, + freeze_encoder: bool = False, + freeze_quantizer: bool = False, + freeze_quantizer_level: int = -1, + torch_compile_encoder_decoder: bool = False, + ): + super(MimiModel, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.encoder_transformer = encoder_transformer + self.decoder_transformer = decoder_transformer + self.quantizer = quantizer + self._frame_rate = frame_rate + self._sample_rate = sample_rate + self._channels = channels + self.encoder_frame_rate = encoder_frame_rate + self.torch_compile_encoder_decoder = torch_compile_encoder_decoder + self.freeze_quantizer = freeze_quantizer + self.freeze_quantizer_level = ( + freeze_quantizer_level + if freeze_quantizer_level > 0 + else self.quantizer.num_codebooks + ) + + # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder + # which exposes a `dimension` attribute. + dimension = encoder.dimension + assert isinstance( + dimension, int + ), f"Dimension should be int, got {dimension} of type {type(dimension)}." + self.dimension = dimension + + assert resample_method in [ + "interpolate", + "conv", + "avg_pool", + ], f"Invalid resample_method {resample_method}" + self.resample_method = resample_method + if encoder_frame_rate != frame_rate: + assert not ( + causal and resample_method == "interpolate" + ), "Cannot interpolate with causal model." + if resample_method in ["conv", "avg_pool"]: + assert ( + self.encoder_frame_rate > self.frame_rate + ), "Cannot upsample with conv." + downsample_stride = self.encoder_frame_rate / self.frame_rate + assert downsample_stride == int( + downsample_stride + ), f"Only integer strides are supported, got {downsample_stride}" + learnt = resample_method == "conv" + self.downsample = ConvDownsample1d( + int(downsample_stride), + dimension=dimension, + learnt=learnt, + causal=causal, + ) + self.upsample = StaticConvTrUpsample1d( + int(downsample_stride), + dimension=dimension, + learnt=learnt, + causal=causal, + channel_wise=upsample_channel_wise_bug, + ) + + def _static_to_encoder_framerate(self, x: torch.Tensor, partial: torch.Tensor): + # Convert from overall framerate to the encoder frame rate. 
+ x, partial = self.upsample(x, partial) + return x, partial + + def decode( + self, + codes, + k_cache, + v_cache, + end_index, + end_offset, + partial_convtr_0, + partial_convtr_1, + partial_convtr_2, + partial_convtr_3, + partial_convtr_4, + previous_conv_0, + previous_conv_1, + previous_conv_3, + previous_conv_5, + previous_conv_7, + previous_conv_9, + ): + state = self._streaming_state + emb = self.decode_latent(codes) + emb, partial_convtr_0 = self._static_to_encoder_framerate(emb, partial_convtr_0) + assert state is not None + (emb,), k_cache, v_cache, end_index, end_offset = self.decoder_transformer( + emb, k_cache, v_cache, end_index, end_offset + ) + with self._context_for_encoder_decoder: + ( + out, + partial_convtr_1, + partial_convtr_2, + partial_convtr_3, + partial_convtr_4, + previous_conv_0, + previous_conv_1, + previous_conv_3, + previous_conv_5, + previous_conv_7, + previous_conv_9, + ) = self.decoder( + emb, + partial_convtr_1, + partial_convtr_2, + partial_convtr_3, + partial_convtr_4, + previous_conv_0, + previous_conv_1, + previous_conv_3, + previous_conv_5, + previous_conv_7, + previous_conv_9, + ) + return ( + out, + k_cache, + v_cache, + end_index, + end_offset, + partial_convtr_0, + partial_convtr_1, + partial_convtr_2, + partial_convtr_3, + partial_convtr_4, + previous_conv_0, + previous_conv_1, + previous_conv_3, + previous_conv_5, + previous_conv_7, + previous_conv_9, + ) + + +class StaticMimiDecoderModel(StaticMimiModel): + def forward(self, *args): + return super().decode(*args) + + +def get_static_mimi( + filename: str | Path, device: torch.device | str = "cpu", num_codebooks: int = 8 +) -> MimiModel: + """Return a pretrained Mimi model.""" + encoder = SEANetEncoder(**_seanet_kwargs) + decoder = StaticSEANetDecoder(**_seanet_kwargs) + _transformer_kwargs["layer_class"] = StaticStreamingTransformerLayer + encoder_transformer = transformer.ProjectedTransformer( + device=device, **_transformer_kwargs + ) + decoder_transformer = StaticProjectedTransformer( + device=device, **_transformer_kwargs + ) + quantizer = SplitResidualVectorQuantizer( + **_quantizer_kwargs, + ) + model = StaticMimiDecoderModel( + encoder, + decoder, + quantizer, + channels=1, + sample_rate=SAMPLE_RATE, + frame_rate=FRAME_RATE, + encoder_frame_rate=SAMPLE_RATE / encoder.hop_length, + causal=True, + resample_method="conv", + encoder_transformer=encoder_transformer, + decoder_transformer=decoder_transformer, + ).to(device=device) + model.eval() + if _is_safetensors(filename): + load_model(model, filename, strict=True) + else: + pkg = torch.load(filename, "cpu") # noqa: TOR102 + model.load_state_dict(pkg["model"]) + model.set_num_codebooks(num_codebooks) + return model diff --git a/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py b/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py new file mode 100644 index 00000000000..15f1963a2b4 --- /dev/null +++ b/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py @@ -0,0 +1,313 @@ +# Copyright (c) Kyutai, all rights reserved. +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py b/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py
new file mode 100644
index 00000000000..15f1963a2b4
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/moshi/model/static_seanet_decoder.py
@@ -0,0 +1,313 @@
+# Copyright (c) Kyutai, all rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+from executorch.examples.qualcomm.oss_scripts.moshi.model.static_conv import (
+    StaticStreamingConv1d,
+)
+from executorch.examples.qualcomm.oss_scripts.moshi.model.static_convtr import (
+    StaticStreamingConvTranspose1d,
+)
+from moshi.modules.seanet import SEANetDecoder, SEANetResnetBlock
+from moshi.modules.streaming import StreamingAdd
+from moshi.utils.compile import torch_compile_lazy
+from torch import nn
+
+
+class StaticSEANetResnetBlock(SEANetResnetBlock):
+    """Residual block from SEANet model.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+
+    # Static Mimi Changes:
+    # 1) Replace Conv with Static Conv
+    def __init__(
+        self,
+        dim: int,
+        kernel_sizes: tp.List[int] = [3, 1],  # noqa: B006
+        dilations: tp.List[int] = [1, 1],  # noqa: B006
+        activation: str = "ELU",
+        activation_params: dict = {"alpha": 1.0},  # noqa: B006
+        norm: str = "none",
+        norm_params: tp.Dict[str, tp.Any] = {},  # noqa: B006
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        compress: int = 2,
+        true_skip: bool = True,
+    ):
+        super(SEANetResnetBlock, self).__init__()
+        assert len(kernel_sizes) == len(
+            dilations
+        ), "Number of kernel sizes should match number of dilations"
+        act = getattr(nn, activation)
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                act(**activation_params),
+                StaticStreamingConv1d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    ignore_previous=(i == 1),
+                ),
+            ]
+        self.block = nn.Sequential(*block)
+        self.add = StreamingAdd()
+        self.shortcut: nn.Module
+        if true_skip:
+            self.shortcut = nn.Identity()
+        else:
+            self.shortcut = StaticStreamingConv1d(
+                dim,
+                dim,
+                kernel_size=1,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            )
+
+    # Static Mimi Changes:
+    # 1) Pass in previous to Conv. Return the output and updated previous.
+    def forward(self, x, previous):
+        block_list = list(self.block.children())
+        assert (
+            len(block_list) == 4
+        ), "Expect block list to have 4 modules, check if model is changed"
+        assert isinstance(block_list[1], StaticStreamingConv1d)
+        assert isinstance(block_list[3], StaticStreamingConv1d)
+
+        u = self.shortcut(x)
+
+        x = block_list[0](x)
+        x, previous = block_list[1](x, previous)
+        x = block_list[2](x)
+        v = block_list[3](x)
+        return self.add(u, v), previous
+
+
+class StaticSEANetDecoder(SEANetDecoder):
+    """SEANet decoder.
+
+    Args:
+        channels (int): Audio channels.
+        dimension (int): Intermediate representation dimension.
+        n_filters (int): Base width for the model.
+        n_residual_layers (int): Number of residual layers.
+        ratios (Sequence[int]): kernel size and stride ratios.
+        activation (str): Activation function.
+        activation_params (dict): Parameters to provide to the activation function.
+        final_activation (str): Final activation function after all convolutions.
+        final_activation_params (dict): Parameters to provide to the activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        kernel_size (int): Kernel size for the initial convolution.
+        last_kernel_size (int): Kernel size for the final convolution.
+        residual_kernel_size (int): Kernel size for the residual layers.
+        dilation_base (int): How much to increase the dilation with each layer.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection in the residual network blocks.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+            For the decoder, it corresponds to the N last blocks.
+        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+            If equal to 1.0, it means that all the trimming is done at the right.
+    """
+
+    # Static Mimi Changes:
+    # 1) Replace Conv and ConvTranspose with Static Conv and Static ConvTranspose.
+    #    The main difference is that the static versions store no state inside the
+    #    model; each state tensor is returned as an output and fed back in on the
+    #    next execution.
+    def __init__(
+        self,
+        channels: int = 1,
+        dimension: int = 128,
+        n_filters: int = 32,
+        n_residual_layers: int = 3,
+        ratios: tp.List[int] = [8, 5, 4, 2],  # noqa: B006
+        activation: str = "ELU",
+        activation_params: dict = {"alpha": 1.0},  # noqa: B006
+        final_activation: tp.Optional[str] = None,
+        final_activation_params: tp.Optional[dict] = None,
+        norm: str = "none",
+        norm_params: tp.Dict[str, tp.Any] = {},  # noqa: B006
+        kernel_size: int = 7,
+        last_kernel_size: int = 7,
+        residual_kernel_size: int = 3,
+        dilation_base: int = 2,
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        true_skip: bool = True,
+        compress: int = 2,
+        disable_norm_outer_blocks: int = 0,
+        trim_right_ratio: float = 1.0,
+    ):
+        super(SEANetDecoder, self).__init__()
+        self.dimension = dimension
+        self.channels = channels
+        self.n_filters = n_filters
+        self.ratios = ratios
+        self.n_residual_layers = n_residual_layers
+        self.hop_length = int(np.prod(self.ratios))
+        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
+        self.disable_norm_outer_blocks = disable_norm_outer_blocks
+        assert (
+            self.disable_norm_outer_blocks >= 0
+            and self.disable_norm_outer_blocks <= self.n_blocks
+        ), (
+            "Number of blocks for which to disable norm is invalid."
+            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+        )
+
+        act = getattr(nn, activation)
+        mult = int(2 ** len(self.ratios))
+        model = [
+            StaticStreamingConv1d(
+                dimension,
+                mult * n_filters,
+                kernel_size,
+                norm=(
+                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
+                ),
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            )
+        ]
+
+        # Upsample to raw audio scale
+        for i, ratio in enumerate(self.ratios):
+            block_norm = (
+                "none"
+                if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
+                else norm
+            )
+            # Add upsampling layers
+            model += [
+                act(**activation_params),
+                StaticStreamingConvTranspose1d(
+                    mult * n_filters,
+                    mult * n_filters // 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=block_norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    trim_right_ratio=trim_right_ratio,
+                ),
+            ]
+            # Add residual layers
+            for j in range(n_residual_layers):
+                model += [
+                    StaticSEANetResnetBlock(
+                        mult * n_filters // 2,
+                        kernel_sizes=[residual_kernel_size, 1],
+                        dilations=[dilation_base**j, 1],
+                        activation=activation,
+                        activation_params=activation_params,
+                        norm=block_norm,
+                        norm_params=norm_params,
+                        causal=causal,
+                        pad_mode=pad_mode,
+                        compress=compress,
+                        true_skip=true_skip,
+                    )
+                ]
+
+            mult //= 2
+
+        # Add final layers
+        model += [
+            act(**activation_params),
+            StaticStreamingConv1d(
+                n_filters,
+                channels,
+                last_kernel_size,
+                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+            ),
+        ]
+        self.model = nn.Sequential(*model)
+
+    # Static Mimi Changes:
+    # 1) Pass all state-related variables (e.g., partial, previous) as inputs
+    #    and outputs.
+    @torch_compile_lazy
+    def forward(
+        self,
+        z,
+        partial_convtr_1,
+        partial_convtr_2,
+        partial_convtr_3,
+        partial_convtr_4,
+        previous_conv_0,
+        previous_conv_1,
+        previous_conv_3,
+        previous_conv_5,
+        previous_conv_7,
+        previous_conv_9,
+    ):
+        model_list = list(self.model.children())
+        assert (
+            len(model_list) == 15
+        ), "Expect to have 15 submodules, check if model is changed."
+        z, previous_conv_0 = model_list[0](z, previous_conv_0)
+        z = model_list[1](z)
+        z, partial_convtr_1 = model_list[2](z, partial_convtr_1)
+        z, previous_conv_1 = model_list[3](z, previous_conv_1)
+        z = model_list[4](z)
+        z, partial_convtr_2 = model_list[5](z, partial_convtr_2)
+        z, previous_conv_3 = model_list[6](z, previous_conv_3)
+        z = model_list[7](z)
+        z, partial_convtr_3 = model_list[8](z, partial_convtr_3)
+        z, previous_conv_5 = model_list[9](z, previous_conv_5)
+        z = model_list[10](z)
+        z, partial_convtr_4 = model_list[11](z, partial_convtr_4)
+        z, previous_conv_7 = model_list[12](z, previous_conv_7)
+        z = model_list[13](z)
+        z, previous_conv_9 = model_list[14](z, previous_conv_9)
+
+        return (
+            z,
+            partial_convtr_1,
+            partial_convtr_2,
+            partial_convtr_3,
+            partial_convtr_4,
+            previous_conv_0,
+            previous_conv_1,
+            previous_conv_3,
+            previous_conv_5,
+            previous_conv_7,
+            previous_conv_9,
+        )
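The hard-coded indices in `forward()` depend on the `nn.Sequential` layout built in `__init__`. The 15-submodule assert only holds when the decoder is configured with four upsampling ratios and `n_residual_layers == 1` (which the Mimi `_seanet_kwargs` presumably set; they are not shown here). A quick sanity check of that arithmetic:

```python
# Sanity arithmetic for the hard-coded 15-submodule layout (assumes the Mimi
# SEANet config: four upsampling ratios and n_residual_layers == 1).
n_ratios = 4
n_residual_layers = 1
n_submodules = (
    1                                          # initial StaticStreamingConv1d
    + n_ratios * (1 + 1 + n_residual_layers)   # act + convtr + resblocks per stage
    + 2                                        # final activation + StaticStreamingConv1d
)
assert n_submodules == 15
```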
+ */
+
+/**
+ * @file
+ *
+ * This tool runs the static Mimi decoder with Qualcomm AI Engine Direct.
+ */
+
+#include <executorch/examples/qualcomm/oss_scripts/moshi/runner/runner.h>
+#include <executorch/runtime/platform/log.h>
+#include <gflags/gflags.h>
+
+DEFINE_string(
+    model_path,
+    "mimi_decoder_qnn.pte",
+    "Model serialized in flatbuffer format.");
+DEFINE_string(
+    output_folder_path,
+    "outputs",
+    "Executorch inference data output path.");
+DEFINE_string(
+    input_list_path,
+    "input_list.txt",
+    "Input list storing the file names of encoded results.");
+
+using executorch::runtime::Error;
+
+int main(int argc, char** argv) {
+  gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+  example::Runner runner(FLAGS_model_path, FLAGS_output_folder_path);
+
+  ET_CHECK_MSG(
+      runner.generate(FLAGS_input_list_path) == Error::Ok,
+      "Runner failed to generate");
+
+  return 0;
+}
diff --git a/examples/qualcomm/oss_scripts/moshi/runner/runner.cpp b/examples/qualcomm/oss_scripts/moshi/runner/runner.cpp
new file mode 100644
index 00000000000..e7120b1cd77
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/moshi/runner/runner.cpp
@@ -0,0 +1,279 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// A simple Mimi decoder runner that takes the encoder's results as input.
+
+#include <executorch/examples/qualcomm/oss_scripts/moshi/runner/runner.h>
+#include <executorch/extension/llm/runner/util.h>
+#include <executorch/runtime/platform/log.h>
+
+#include <fstream>
+#include <memory>
+
+#include <string>
+#include <vector>
+
+using executorch::aten::Tensor;
+using executorch::aten::TensorImpl;
+using executorch::extension::Module;
+using executorch::extension::llm::time_in_ms;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Result;
+using executorch::runtime::TensorInfo;
+
+namespace example {
+
+Runner::Runner(const std::string& model_path, const std::string& output_path)
+    : output_path_(output_path) {
+  module_ = std::make_unique<Module>(
+      model_path, Module::LoadMode::MmapUseMlockIgnoreErrors);
+  ET_LOG(Info, "creating module: model_path=%s", model_path.c_str());
+}
+
+Error Runner::parse_input_list(std::string& path) {
+  // Fill in data for input
+  std::ifstream input_list(path);
+  ET_CHECK_MSG(input_list.is_open(), "Input list error opening file");
+  std::string encoded_input;
+
+  while (std::getline(input_list, encoded_input)) {
+    std::ifstream is;
+    is.open(encoded_input, std::ios::binary);
+    is.seekg(0, std::ios::end);
+    size_t filesize = is.tellg();
+    is.seekg(0, std::ios::beg);
+    std::vector<int32_t> encoded_in;
+    encoded_in.resize(filesize / sizeof(int32_t));
+    is.read(reinterpret_cast<char*>(encoded_in.data()), filesize);
+    encoded_input_list_.first.push_back(encoded_in);
+  }
+  return Error::Ok;
+}
+
+Error Runner::init_io() {
+  Result<MethodMeta> method_meta = module_->method_meta(method_name_);
+  k_cache_.first.resize(
+      (method_meta->input_tensor_meta(1)->nbytes() / sizeof(float)), 0);
+  v_cache_.first.resize(
+      (method_meta->input_tensor_meta(2)->nbytes() / sizeof(float)), 0);
+  end_index_.first.resize(
+      (method_meta->input_tensor_meta(3)->nbytes() / sizeof(float)), 0);
+  end_offset_.first.resize(
+      (method_meta->input_tensor_meta(4)->nbytes() / sizeof(float)), 0);
+
+  for (int i = 5; i < 10; i++) {
+    auto size = (method_meta->input_tensor_meta(i)->nbytes() / sizeof(float));
+    std::vector<float> convtr_partial(size, 0);
+    convtr_partials_.emplace_back(convtr_partial, nullptr);
+  }
+
+  for (int i = 10; i < method_meta->num_inputs(); i++) {
+    auto size = (method_meta->input_tensor_meta(i)->nbytes() / sizeof(float));
+    std::vector<float> conv_prev(size, 0);
+    conv_previous_.emplace_back(conv_prev, nullptr);
+  }
+
+  auto per_decode_output_size =
+      (method_meta->output_tensor_meta(0)->nbytes() / sizeof(float));
+  for (int i = 0; i < encoded_input_list_.first.size(); i++) {
+    std::vector<float> output(per_decode_output_size, 0);
+    decoded_output_list_.first.emplace_back(output);
+  }
+  return Error::Ok;
+}
+
+Error Runner::prepare_io() {
+  Result<MethodMeta> method_meta = module_->method_meta(method_name_);
+
+  // in[0]
+  Result<TensorInfo> encoded_input_meta = method_meta->input_tensor_meta(0);
+  encoded_input_list_.second = std::make_shared<TensorImpl>(
+      encoded_input_meta->scalar_type(),
+      encoded_input_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(encoded_input_meta->sizes().data()),
+      encoded_input_list_.first[0].data(),
+      const_cast<TensorImpl::DimOrderType*>(
+          encoded_input_meta->dim_order().data()));
+  input_tensors_.emplace_back(encoded_input_list_.second.get());
+
+  // out[0]
+  Result<TensorInfo> decoded_output_meta = method_meta->output_tensor_meta(0);
+  decoded_output_list_.second = std::make_shared<TensorImpl>(
+      decoded_output_meta->scalar_type(),
+      decoded_output_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(decoded_output_meta->sizes().data()),
+      decoded_output_list_.first[0].data(),
+      const_cast<TensorImpl::DimOrderType*>(
+          decoded_output_meta->dim_order().data()));
+  output_tensors_.emplace_back(decoded_output_list_.second.get());
+
+  // in[1] and out[1]
+  Result<TensorInfo> k_cache_meta = method_meta->input_tensor_meta(1);
+  k_cache_.second = std::make_shared<TensorImpl>(
+      k_cache_meta->scalar_type(),
+      k_cache_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(k_cache_meta->sizes().data()),
+      k_cache_.first.data(),
+      const_cast<TensorImpl::DimOrderType*>(k_cache_meta->dim_order().data()));
+  input_tensors_.emplace_back(k_cache_.second.get());
+  output_tensors_.emplace_back(k_cache_.second.get());
+
+  // in[2] and out[2]
+  Result<TensorInfo> v_cache_meta = method_meta->input_tensor_meta(2);
+  v_cache_.second = std::make_shared<TensorImpl>(
+      v_cache_meta->scalar_type(),
+      v_cache_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(v_cache_meta->sizes().data()),
+      v_cache_.first.data(),
+      const_cast<TensorImpl::DimOrderType*>(v_cache_meta->dim_order().data()));
+  input_tensors_.emplace_back(v_cache_.second.get());
+  output_tensors_.emplace_back(v_cache_.second.get());
+
+  // in[3] and out[3]
+  Result<TensorInfo> end_index_meta = method_meta->input_tensor_meta(3);
+  end_index_.second = std::make_shared<TensorImpl>(
+      end_index_meta->scalar_type(),
+      end_index_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(end_index_meta->sizes().data()),
+      end_index_.first.data(),
+      const_cast<TensorImpl::DimOrderType*>(
+          end_index_meta->dim_order().data()));
+  input_tensors_.emplace_back(end_index_.second.get());
+  output_tensors_.emplace_back(end_index_.second.get());
+
+  // in[4] and out[4]
+  Result<TensorInfo> end_offset_meta = method_meta->input_tensor_meta(4);
+  end_offset_.second = std::make_shared<TensorImpl>(
+      end_offset_meta->scalar_type(),
+      end_offset_meta->sizes().size(),
+      const_cast<TensorImpl::SizesType*>(end_offset_meta->sizes().data()),
+      end_offset_.first.data(),
+      const_cast<TensorImpl::DimOrderType*>(
+          end_offset_meta->dim_order().data()));
+  input_tensors_.emplace_back(end_offset_.second.get());
+  output_tensors_.emplace_back(end_offset_.second.get());
+
+  // in[5-9] and out[5-9]
+  for (int i = 0, convtr_partials_start = 5; i < convtr_partials_.size();
+       i++, convtr_partials_start++) {
+    Result<TensorInfo> convtr_partial_meta =
+        method_meta->input_tensor_meta(convtr_partials_start);
+    convtr_partials_[i].second = std::make_shared<TensorImpl>(
+        convtr_partial_meta->scalar_type(),
+        convtr_partial_meta->sizes().size(),
+        const_cast<TensorImpl::SizesType*>(convtr_partial_meta->sizes().data()),
+        convtr_partials_[i].first.data(),
+        const_cast<TensorImpl::DimOrderType*>(
+            convtr_partial_meta->dim_order().data()));
+    input_tensors_.emplace_back(convtr_partials_[i].second.get());
+    output_tensors_.emplace_back(convtr_partials_[i].second.get());
+  }
+
+  // in[10-15] and out[10-15]
+  for (int i = 0, conv_previous_start = 10; i < conv_previous_.size();
+       i++, conv_previous_start++) {
+    Result<TensorInfo> conv_previous_meta =
+        method_meta->input_tensor_meta(conv_previous_start);
+    conv_previous_[i].second = std::make_shared<TensorImpl>(
+        conv_previous_meta->scalar_type(),
+        conv_previous_meta->sizes().size(),
+        const_cast<TensorImpl::SizesType*>(conv_previous_meta->sizes().data()),
+        conv_previous_[i].first.data(),
+        const_cast<TensorImpl::DimOrderType*>(
+            conv_previous_meta->dim_order().data()));
+    input_tensors_.emplace_back(conv_previous_[i].second.get());
+    output_tensors_.emplace_back(conv_previous_[i].second.get());
+  }
+
+  // Prepare the vector of EValue to run inference
+  inputs_.reserve(input_tensors_.size());
+  for (auto& input_tensor : input_tensors_) {
+    inputs_.emplace_back(std::move(input_tensor));
+  }
+
+  for (int i = 0; i < output_tensors_.size(); i++) {
+    ET_CHECK_MSG(
+        module_->set_output(method_name_, output_tensors_[i], i) == Error::Ok,
+        "failed to set output tensor for module %u'th output",
+        i);
+  }
+
+  return Error::Ok;
+}
+
+Error Runner::load(std::string& input_list) {
+  if (module_->is_loaded()) {
+    return Error::Ok;
+  }
+  method_name_ = *module_->method_names()->begin();
+  ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(method_name_));
+  ET_CHECK_OK_OR_RETURN_ERROR(parse_input_list(input_list));
+  ET_CHECK_OK_OR_RETURN_ERROR(init_io());
+  ET_CHECK_OK_OR_RETURN_ERROR(prepare_io());
+
+  return Error::Ok;
+}
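Note how prepare_io() registers the same TensorImpl as both an input and the matching output for every state tensor, so whatever one execution writes is what the next execution reads, with no per-step copies. A conceptual Python analogue of that aliasing (illustrative only; the runner does this in C++ via shared TensorImpls and `set_output`):

```python
# Conceptual analogue of the aliased state buffers: one preallocated buffer is
# wired as both input and output, so updates persist across calls for free.
import numpy as np

state = np.zeros(16, dtype=np.float32)  # stands in for a k_cache/partial buffer


def fake_decode_step(codes, state_buf):
    # Writes the updated state in place; the next call sees it automatically.
    state_buf += codes.sum()
    return codes.astype(np.float32)  # stands in for the decoded chunk


for codes in [np.array([1, 2], dtype=np.int32), np.array([3, 4], dtype=np.int32)]:
    out = fake_decode_step(codes, state)  # no per-step copies of the state
```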
+Error Runner::generate(std::string& input_list) {
+  std::vector<EValue> inputs;
+  if (!module_->is_loaded()) {
+    stats_.model_load_start_ms = time_in_ms();
+    ET_CHECK_OK_OR_RETURN_ERROR(load(input_list));
+    stats_.model_load_end_ms = time_in_ms();
+  }
+
+  ET_LOG(Info, "Start generating");
+  stats_.decode_start_ms = time_in_ms();
+  for (int i = 0; i < encoded_input_list_.first.size(); i++) {
+    // encoded_input_list_ stores the inputs for all executions, so each
+    // iteration only needs to repoint the input tensor at the next chunk's
+    // address. The same holds for decoded_output_list_, where memory is
+    // reserved for all outputs up front: each iteration points the output
+    // tensor at the corresponding output buffer. After the loop, all outputs
+    // are dumped to a raw file.
+    encoded_input_list_.second->set_data(encoded_input_list_.first[i].data());
+    decoded_output_list_.second->set_data(decoded_output_list_.first[i].data());
+    ET_CHECK_MSG(
+        module_->set_output(method_name_, output_tensors_[0], 0) == Error::Ok,
+        "failed to set output tensor for module 0'th output");
+
+    module_->execute(method_name_, inputs_);
+  }
+  stats_.decode_end_ms = time_in_ms();
+
+  ET_LOG(
+      Info,
+      "\tModel Load Time:\t\t\t\t%ld (ms)",
+      stats_.model_load_end_ms - stats_.model_load_start_ms);
+
+  auto decode_duration = stats_.decode_end_ms - stats_.decode_start_ms;
+  ET_LOG(
+      Info,
+      "\tTotal inference time for %zu chunks:\t\t%ld (ms)",
+      encoded_input_list_.first.size(),
+      decode_duration);
+
+  ET_LOG(
+      Info,
+      "\tAverage inference time per chunk:\t\t%f (ms)",
+      ((double)decode_duration / encoded_input_list_.first.size()));
+
+  auto output_file_name = output_path_ + "/output_0_0.raw";
+  std::ofstream fout(output_file_name.c_str(), std::ios::binary);
+  for (const auto& decoded_output : decoded_output_list_.first) {
+    fout.write(
+        reinterpret_cast<const char*>(decoded_output.data()),
+        decoded_output.size() * sizeof(float));
+  }
+  fout.close();
+
+  return Error::Ok;
+}
+
+} // namespace example
diff --git a/examples/qualcomm/oss_scripts/moshi/runner/runner.h b/examples/qualcomm/oss_scripts/moshi/runner/runner.h
new file mode 100644
index 00000000000..88f150d4410
--- /dev/null
+++ b/examples/qualcomm/oss_scripts/moshi/runner/runner.h
@@ -0,0 +1,79 @@
+/*
+ * Copyright (c) Qualcomm Innovation Center, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <executorch/extension/module/module.h>
+
+namespace example {
+
+class Runner {
+ public:
+  explicit Runner(
+      const std::string& model_path,
+      const std::string& output_path);
+
+  struct Stats {
+    // Total time to load the model and init the inputs
+    long model_load_start_ms;
+    long model_load_end_ms;
+
+    // Total time to decode all chunks
+    long decode_start_ms = 0;
+    long decode_end_ms = 0;
+  };
+
+  executorch::runtime::Error parse_input_list(std::string& input_list);
+  executorch::runtime::Error init_io();
+  executorch::runtime::Error prepare_io();
+  executorch::runtime::Error load(std::string& input_list);
+  executorch::runtime::Error generate(std::string& input_list);
+
+ private:
+  Stats stats_;
+  std::unique_ptr<executorch::extension::Module> module_;
+  std::string method_name_;
+  std::string output_path_;
+
+  // Pair that stores IO data with its TensorImpl pointer
+  std::pair<
+      std::vector<std::vector<int32_t>>,
+      std::shared_ptr<executorch::aten::TensorImpl>>
+      encoded_input_list_;
+  std::pair<std::vector<float>, std::shared_ptr<executorch::aten::TensorImpl>>
+      k_cache_;
+  std::pair<std::vector<float>, std::shared_ptr<executorch::aten::TensorImpl>>
+      v_cache_;
+  std::pair<std::vector<float>, std::shared_ptr<executorch::aten::TensorImpl>>
+      end_index_;
+  std::pair<std::vector<float>, std::shared_ptr<executorch::aten::TensorImpl>>
+      end_offset_;
+  std::vector<std::pair<
+      std::vector<float>,
+      std::shared_ptr<executorch::aten::TensorImpl>>>
+      convtr_partials_;
+  std::vector<std::pair<
+      std::vector<float>,
+      std::shared_ptr<executorch::aten::TensorImpl>>>
+      conv_previous_;
+  std::pair<
+      std::vector<std::vector<float>>,
+      std::shared_ptr<executorch::aten::TensorImpl>>
+      decoded_output_list_;
+
+  std::vector<executorch::runtime::EValue> inputs_;
+  std::vector<executorch::aten::Tensor> input_tensors_;
+  std::vector<executorch::aten::Tensor> output_tensors_;
+};
+
+} // namespace example
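For reference, `Runner::parse_input_list()` treats each line of input_list.txt as a path to a headerless binary file of int32 codes, and `generate()` concatenates all decoded float32 chunks into `<output_folder>/output_0_0.raw`. A hypothetical host-side sketch that prepares inputs in that layout (file names and the code-tensor shape are assumptions, not fixed by the runner):

```python
# Hypothetical prep for qnn_mimi_decoder_runner: write each chunk of int32 RVQ
# codes to its own .raw file and list the paths in input_list.txt, matching
# what Runner::parse_input_list() reads.
import numpy as np

chunks = [np.zeros((1, 8, 1), dtype=np.int32) for _ in range(3)]  # dummy codes

with open("input_list.txt", "w") as f:
    for i, chunk in enumerate(chunks):
        path = f"encoded_{i}.raw"
        chunk.tofile(path)  # raw little-endian int32, no header
        f.write(path + "\n")

# After running the binary, outputs/output_0_0.raw holds the decoded float32
# PCM for all chunks back to back, e.g.:
#   audio = np.fromfile("outputs/output_0_0.raw", dtype=np.float32)
```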