From 9aceb177e394b52082e1061367ff12a888d95e6e Mon Sep 17 00:00:00 2001 From: haowhsu-quic Date: Fri, 4 Oct 2024 11:40:49 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - wav2letter e2e example Summary - enable wav2letter e2e example --- backends/qualcomm/tests/test_qnn_delegate.py | 38 +++ .../qualcomm/scripts/install_requirement.sh | 2 + examples/qualcomm/scripts/wav2letter.py | 226 ++++++++++++++++++ 3 files changed, 266 insertions(+) create mode 100644 examples/qualcomm/scripts/install_requirement.sh create mode 100644 examples/qualcomm/scripts/wav2letter.py diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 8abed68c630..74af8feb383 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -2681,6 +2681,44 @@ def test_ptq_mobilebert(self): for k, v in cpu.items(): self.assertLessEqual(abs(v[0] - htp[k][0]), 5) + def test_wav2letter(self): + if not self.required_envs([self.pretrained_weight]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/scripts/wav2letter.py", + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--pretrained_weight", + self.pretrained_weight, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + if self.shared_buffer: + cmds.extend(["--shared_buffer"]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertLessEqual(msg["wer"], 0.5) + self.assertLessEqual(msg["cer"], 0.25) + def test_export_example(self): if not self.required_envs([self.model_name]): self.skipTest("missing required envs") diff --git a/examples/qualcomm/scripts/install_requirement.sh b/examples/qualcomm/scripts/install_requirement.sh new file mode 100644 index 00000000000..c961467a8a5 --- /dev/null +++ b/examples/qualcomm/scripts/install_requirement.sh @@ -0,0 +1,2 @@ +pip install soundfile +pip install torchmetrics diff --git a/examples/qualcomm/scripts/wav2letter.py b/examples/qualcomm/scripts/wav2letter.py new file mode 100644 index 00000000000..e377c6d7e90 --- /dev/null +++ b/examples/qualcomm/scripts/wav2letter.py @@ -0,0 +1,226 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+import json
+import os
+import sys
+from multiprocessing.connection import Client
+
+import numpy as np
+
+import torch
+from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
+from executorch.examples.models.wav2letter import Wav2LetterModel
+from executorch.examples.qualcomm.utils import (
+    build_executorch_binary,
+    make_output_dir,
+    parse_skip_delegation_node,
+    setup_common_args_and_variables,
+    SimpleADB,
+)
+
+
+class Conv2D(torch.nn.Module):
+    def __init__(self, stride, padding, weight, bias=None):
+        super().__init__()
+        use_bias = bias is not None
+        self.conv = torch.nn.Conv2d(
+            in_channels=weight.shape[1],
+            out_channels=weight.shape[0],
+            kernel_size=[weight.shape[2], 1],
+            stride=[*stride, 1],
+            padding=[*padding, 0],
+            bias=use_bias,
+        )
+        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
+        if use_bias:
+            self.conv.bias = torch.nn.Parameter(bias)
+
+    def forward(self, x):
+        return self.conv(x)
+
+
+def get_dataset(data_size, artifact_dir):
+    from torch.utils.data import DataLoader
+    from torchaudio.datasets import LIBRISPEECH
+
+    def collate_fun(batch):
+        waves, labels = [], []
+
+        for wave, _, text, *_ in batch:
+            waves.append(wave.squeeze(0))
+            labels.append(text)
+        # need padding here for static output shape
+        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
+        return waves, labels
+
+    dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True)
+    data_loader = DataLoader(
+        dataset=dataset,
+        batch_size=data_size,
+        shuffle=True,
+        collate_fn=lambda x: collate_fun(x),
+    )
+    # prepare input data
+    inputs, targets, input_list = [], [], ""
+    for wave, label in data_loader:
+        for index in range(data_size):
+            # reshape input tensor to NCHW
+            inputs.append((wave[index].reshape(1, 1, -1, 1),))
+            targets.append(label[index])
+            input_list += f"input_{index}_0.raw\n"
+        # here we only take the first batch, i.e. 'data_size' tensors
+        break
+
+    return inputs, targets, input_list
+
+
+def eval_metric(pred, target_str):
+    from torchmetrics.text import CharErrorRate, WordErrorRate
+
+    def parse(ids):
+        vocab = " abcdefghijklmnopqrstuvwxyz'*"
+        return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids]
+
+    pred_str = parse(
+        [
+            torch.unique_consecutive(pred[i, :, :].argmax(0))
+            for i in range(pred.shape[0])
+        ]
+    )
+    wer, cer = WordErrorRate(), CharErrorRate()
+    return wer(pred_str, target_str), cer(pred_str, target_str)
+
+
+def main(args):
+    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
+
+    # ensure the working directory exists
+    os.makedirs(args.artifact, exist_ok=True)
+
+    if not args.compile_only and args.device is None:
+        raise RuntimeError(
+            "device serial is required if not compile only. "
+            "Please specify a device serial by -s/--device argument."
+        )
+
+    instance = Wav2LetterModel()
+    # target labels " abcdefghijklmnopqrstuvwxyz'*"
+    instance.vocab_size = 29
+    model = instance.get_eager_model().eval()
+    model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))
+
+    # converting conv1d to conv2d at the nn.Module level only introduces 2 permute
+    # nodes around the graph input & output, which is more quantization friendly
+    for i in range(len(model.acoustic_model)):
+        for j in range(len(model.acoustic_model[i])):
+            module = model.acoustic_model[i][j]
+            if isinstance(module, torch.nn.Conv1d):
+                model.acoustic_model[i][j] = Conv2D(
+                    stride=module.stride,
+                    padding=module.padding,
+                    weight=module.weight,
+                    bias=module.bias,
+                )
+
+    # retrieve dataset; downloading will take some time
+    data_num = 100
+    inputs, targets, input_list = get_dataset(
+        data_size=data_num, artifact_dir=args.artifact
+    )
+    pte_filename = "w2l_qnn"
+    build_executorch_binary(
+        model,
+        inputs[0],
+        args.model,
+        f"{args.artifact}/{pte_filename}",
+        inputs,
+        skip_node_id_set=skip_node_id_set,
+        skip_node_op_set=skip_node_op_set,
+        quant_dtype=QuantDtype.use_8a8w,
+        shared_buffer=args.shared_buffer,
+    )
+
+    if args.compile_only:
+        sys.exit(0)
+
+    adb = SimpleADB(
+        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
+        build_path=f"{args.build_folder}",
+        pte_path=f"{args.artifact}/{pte_filename}.pte",
+        workspace=f"/data/local/tmp/executorch/{pte_filename}",
+        device_id=args.device,
+        host_id=args.host,
+        soc_model=args.model,
+        shared_buffer=args.shared_buffer,
+    )
+    adb.push(inputs=inputs, input_list=input_list)
+    adb.execute()
+
+    # collect output data
+    output_data_folder = f"{args.artifact}/outputs"
+    make_output_dir(output_data_folder)
+    adb.pull(output_path=args.artifact)
+
+    predictions = []
+    for i in range(data_num):
+        predictions.append(
+            np.fromfile(
+                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
+            )
+        )
+
+    # evaluate metrics
+    wer, cer = 0, 0
+    for i, pred in enumerate(predictions):
+        pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1)
+        wer_eval, cer_eval = eval_metric(pred, targets[i])
+        wer += wer_eval
+        cer += cer_eval
+
+    if args.ip and args.port != -1:
+        with Client((args.ip, args.port)) as conn:
+            conn.send(
+                json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num})
+            )
+    else:
+        print(f"wer: {wer / data_num}\ncer: {cer / data_num}")
+
+
+if __name__ == "__main__":
+    parser = setup_common_args_and_variables()
+
+    parser.add_argument(
+        "-a",
+        "--artifact",
+        help="path for storing artifacts generated by this example. "
+        "Default: ./wav2letter",
+        default="./wav2letter",
+        type=str,
+    )
+
+    parser.add_argument(
+        "-p",
+        "--pretrained_weight",
+        help=(
+            "Location of the pretrained weight; please download it via "
+            "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch"
+            " for the torchaudio.models.Wav2Letter version"
+        ),
+        default=None,
+        type=str,
+        required=True,
+    )
+
+    args = parser.parse_args()
+    try:
+        main(args)
+    except Exception as e:
+        if args.ip and args.port != -1:
+            with Client((args.ip, args.port)) as conn:
+                conn.send(json.dumps({"Error": str(e)}))
+        else:
+            raise Exception(e)
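
Reviewer note (not part of the patch): the core trick in wav2letter.py is replacing every nn.Conv1d in the acoustic model with the Conv2D wrapper, so the QNN backend sees a 4D NCHW convolution and only two extra permute nodes appear at the graph boundary. The snippet below is a hypothetical standalone check, not something this patch adds, showing that the wrapper is numerically a drop-in replacement for the original layer; the Conv1d parameters are illustrative only (roughly the first acoustic layer of torchaudio's Wav2Letter in waveform mode).

# conv1d_vs_conv2d_check.py (hypothetical snippet, for illustration only)
import torch


class Conv2D(torch.nn.Module):
    # copied verbatim from wav2letter.py above
    def __init__(self, stride, padding, weight, bias=None):
        super().__init__()
        use_bias = bias is not None
        self.conv = torch.nn.Conv2d(
            in_channels=weight.shape[1],
            out_channels=weight.shape[0],
            kernel_size=[weight.shape[2], 1],
            stride=[*stride, 1],
            padding=[*padding, 0],
            bias=use_bias,
        )
        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
        if use_bias:
            self.conv.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        return self.conv(x)


# illustrative layer configuration; any Conv1d setup should behave the same
conv1d = torch.nn.Conv1d(in_channels=1, out_channels=250, kernel_size=48, stride=2, padding=23)
conv2d = Conv2D(conv1d.stride, conv1d.padding, conv1d.weight, conv1d.bias)

x = torch.randn(1, 1, 16000)                # (N, C, T) waveform, Conv1d layout
y_1d = conv1d(x)                            # (N, C_out, T')
y_2d = conv2d(x.unsqueeze(-1)).squeeze(-1)  # same data as NCHW with W == 1, W dropped after
print(torch.allclose(y_1d, y_2d, atol=1e-5))  # expected: True

With that equivalence in place, the example is driven the same way test_wav2letter does it, e.g. python examples/qualcomm/scripts/wav2letter.py --build_folder <build dir> --device <serial> --model <SoC model> --pretrained_weight <w2l checkpoint>; the flag names are taken from the test above, and --ip/--port are only needed when reporting results back to the CI harness.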