From e2606fe9ab1e4322f4c899d89543e3d673f4be73 Mon Sep 17 00:00:00 2001 From: winskuo-quic Date: Tue, 1 Oct 2024 13:41:53 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - ConvFormer Enablement --- .../_passes/fuse_consecutive_transpose.py | 85 ++++++++--- backends/qualcomm/_passes/layout_transform.py | 1 + backends/qualcomm/builders/__init__.py | 2 + .../builders/op_adaptive_avg_pool2d.py | 125 ++++++++++++++++ backends/qualcomm/builders/op_layer_norm.py | 24 +-- backends/qualcomm/builders/op_rms_norm.py | 2 +- backends/qualcomm/quantizer/annotators.py | 5 + backends/qualcomm/tests/models.py | 22 ++- backends/qualcomm/tests/test_qnn_delegate.py | 61 +++++++- backends/qualcomm/utils/utils.py | 10 +- examples/qualcomm/oss_scripts/conv_former.py | 139 ++++++++++++++++++ 11 files changed, 428 insertions(+), 48 deletions(-) create mode 100644 backends/qualcomm/builders/op_adaptive_avg_pool2d.py create mode 100644 examples/qualcomm/oss_scripts/conv_former.py diff --git a/backends/qualcomm/_passes/fuse_consecutive_transpose.py b/backends/qualcomm/_passes/fuse_consecutive_transpose.py index c81818e00e8..16ce3803076 100644 --- a/backends/qualcomm/_passes/fuse_consecutive_transpose.py +++ b/backends/qualcomm/_passes/fuse_consecutive_transpose.py @@ -15,8 +15,18 @@ class FuseConsecutiveTranspose(ExportPass): """ - This pass fuses consecutive transpose / permute into one to reduce runtime - overhead + This pass fuses consecutive transpose / permute into one or none to reduce runtime + overhead. + To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node + by cloning transpose. 
+ Example: + Before clone transpose: + relu -> permute1 ─> permute2 + |──────> permute3 + + After clone transpose: + relu ─> permute1 ──────> permute2 + |───> permute4(new) ─> permute3 """ def __init__(self): @@ -27,6 +37,30 @@ def __init__(self): self.visited = set() self.nodes = [] + def _clone_transpose( + self, graph_module: torch.fx.GraphModule + ) -> torch.fx.GraphModule: + graph = graph_module.graph + for n in graph_module.graph.nodes: + if n.target in self.op_map: + users = [user for user in list(n.users) if user.target in self.op_map] + if len(users) > 1: + for i in range(1, len(users)): + with graph.inserting_after(n): + clone_permute_node = graph.create_node( + "call_function", + exir_ops.edge.aten.permute_copy.default, + (n.args[0], n.args[1]), + ) + clone_permute_node.meta = n.meta + users[i].replace_input_with(n, clone_permute_node) + + def _is_dispensable(self, axis_order): + for index, value in enumerate(axis_order): + if index != value: + return False + return True + def _traverse(self, node): if node in self.visited or node.target not in self.op_map: return @@ -34,47 +68,50 @@ def _traverse(self, node): self.nodes.append(node) self.visited.add(node) next_users = [n for n in list(node.users) if n.target in self.op_map] + + assert ( + len(next_users) <= 1 + ), "Each permute node should have at most 1 permute output node after _clone_transpose" if not next_users: return - - if len(next_users) == 1: - self._traverse(list(node.users)[0]) else: - raise NotImplementedError( - f"Check the node {node}, wich encounter mutilple permute output case" - ) + self._traverse(list(node.users)[0]) def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: graph = graph_module.graph for n in graph_module.graph.nodes: self._traverse(n) if len(self.nodes) > 1: - permute_order = [] input_node, output_node = self.nodes[0].args[0], self.nodes[-1] input_shape = input_node.meta["val"].shape axis_order = torch.arange(len(input_shape)).tolist() for node in 
self.nodes: - permute_order.append(node.args[1]) axis_order = [axis_order[i] for i in node.args[1]] - with graph.inserting_after(input_node): - permute_op = exir_ops.edge.aten.permute_copy.default - permute_node = graph.create_node( - "call_function", permute_op, (input_node, axis_order) - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, permute_node) - - # copy metadata - permute_node.meta = output_node.meta - # Without "qnn_permute", we might obtain wrong input shape - if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]: - permute_node.meta[QCOM_INSERTED_PERMUTE] = True + # If axis order is just [0,1,2,3], we ignore permute node + if self._is_dispensable(axis_order): + for user in output_node.users.copy(): + user.replace_input_with(output_node, n.args[0]) + else: + with graph.inserting_after(input_node): + permute_op = exir_ops.edge.aten.permute_copy.default + permute_node = graph.create_node( + "call_function", permute_op, (input_node, axis_order) + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, permute_node) + + # copy metadata + permute_node.meta = output_node.meta + # Without "qnn_permute", we might obtain wrong input shape + if [pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes]: + permute_node.meta[QCOM_INSERTED_PERMUTE] = True # clear current stack self.nodes = [] def call(self, graph_module: torch.fx.GraphModule): + self._clone_transpose(graph_module) self._fuse(graph_module) graph_module.recompile() dead_code_elimination_pass(graph_module) diff --git a/backends/qualcomm/_passes/layout_transform.py b/backends/qualcomm/_passes/layout_transform.py index 098910ed86f..ccc34d3a528 100644 --- a/backends/qualcomm/_passes/layout_transform.py +++ b/backends/qualcomm/_passes/layout_transform.py @@ -30,6 +30,7 @@ class LayoutTransform(ExportPass): """ layout_sensitive_ops = { + exir_ops.edge.aten.adaptive_avg_pool2d.default, 
exir_ops.edge.aten.avg_pool2d.default, exir_ops.edge.aten.convolution.default, exir_ops.edge.aten.max_pool2d_with_indices.default, diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index 61ed30679e1..7a4d6d764b6 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -7,6 +7,7 @@ from . import ( node_visitor, op_abs, + op_adaptive_avg_pool2d, op_add, op_arange, op_avg_pool2d, @@ -78,6 +79,7 @@ __all__ = [ node_visitor, op_abs, + op_adaptive_avg_pool2d, op_add, op_arange, op_avg_pool2d, diff --git a/backends/qualcomm/builders/op_adaptive_avg_pool2d.py b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py new file mode 100644 index 00000000000..c944e1646e7 --- /dev/null +++ b/backends/qualcomm/builders/op_adaptive_avg_pool2d.py @@ -0,0 +1,125 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+import warnings
+from typing import Dict
+
+import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
+import numpy as np
+
+import torch
+
+from .node_visitor import NodeVisitor, register_node_visitor
+from .qnn_constants import OpPoolAvg2d, QNN_OP_PACKAGE_NAME_QTI_AISW
+
+
+@register_node_visitor
+class AdaptiveAvgPool2D(NodeVisitor):
+    target = ["aten.adaptive_avg_pool2d.default"]
+
+    def __init__(self, *args) -> None:
+        super().__init__(*args)
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
+    ) -> PyQnnWrapper.PyQnnOpWrapper:
+
+        input_node = node.args[0]
+        input_tensor = self.get_tensor(input_node, node)
+        input_tensor_wrapper = self.define_tensor(
+            input_node,
+            node,
+            input_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+
+        input_height = input_tensor.shape[1]
+        input_width = input_tensor.shape[2]
+
+        output_height = node.args[1][0]
+        output_width = node.args[1][1]
+
+        filter_height = input_height // output_height
+        filter_width = input_width // output_width
+        filter = [filter_height, filter_width]
+        filter_shape = [len(filter)]
+
+        stride_height = filter_height
+        stride_width = filter_width
+        stride = [stride_height, stride_width]
+        stride_shape = [len(stride)]
+
+        height = (output_height - 1) * stride_height + filter_height - input_height
+        width = (output_width - 1) * stride_width + filter_width - input_width
+        if height % 2 != 0 or width % 2 != 0:
+            warnings.warn(
+                "[QNN Delegate Op Builder]: Height or Width is not divisible by 2 with no remainder, fall back op",
+                stacklevel=1,
+            )
+            return
+
+        padding_height = height / 2
+        padding_width = width / 2
+        padding = [padding_height, padding_width]
+        padding_shape = [2, 2]
+
+        out_tensor = self.get_tensor(node, node)
+        output_tensor_wrapper = self.define_tensor(
+            node,
+            node,
+            out_tensor,
+            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
+            nodes_to_wrappers,
+        )
+ + adaptive_avg_pool2d_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpPoolAvg2d.op_name, + ) + + adaptive_avg_pool2d_op.AddInputTensors([input_tensor_wrapper]) + adaptive_avg_pool2d_op.AddOutputTensors([output_tensor_wrapper]) + + adaptive_avg_pool2d_op.AddTensorParam( + OpPoolAvg2d.param_filter_size, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(filter_shape), + filter_shape, + np.array( + filter, + dtype=np.uint32, + ), + True, + ) + + adaptive_avg_pool2d_op.AddTensorParam( + OpPoolAvg2d.param_stride, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(stride_shape), + stride_shape, + np.array( + stride, + dtype=np.uint32, + ), + True, + ) + + adaptive_avg_pool2d_op.AddTensorParam( + OpPoolAvg2d.param_pad_amount, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(padding_shape), + padding_shape, + np.array( + [[padding[0], padding[0]], [padding[1], padding[1]]], + dtype=np.uint32, + ), + True, + ) + + return adaptive_avg_pool2d_op diff --git a/backends/qualcomm/builders/op_layer_norm.py b/backends/qualcomm/builders/op_layer_norm.py index 2006c716489..06f822014ed 100644 --- a/backends/qualcomm/builders/op_layer_norm.py +++ b/backends/qualcomm/builders/op_layer_norm.py @@ -63,15 +63,19 @@ def define_node( nodes_to_wrappers, ) + layer_norm_input_tensors = [input_tensor_wrapper, weight_tensor_wrapper] + bias_node = node.args[3] - bias_tensor = get_parameter(bias_node, self.edge_program) - bias_tensor_wrapper = self.define_tensor( - bias_node, - node, - bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - ) + if bias_node is not None: + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + ) + layer_norm_input_tensors.append(bias_tensor_wrapper) epsilon = node.args[4] @@ -89,9 +93,7 @@ def define_node( 
QNN_OP_PACKAGE_NAME_QTI_AISW, OpLayerNorm.op_name, ) - layer_norm_op.AddInputTensors( - [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper] - ) + layer_norm_op.AddInputTensors(layer_norm_input_tensors) layer_norm_op.AddOutputTensors([output_tensor_wrapper]) layer_norm_op.AddScalarParam( OpLayerNorm.param_epsilon, diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py index d1daa6c1e54..e5b4778312e 100644 --- a/backends/qualcomm/builders/op_rms_norm.py +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -66,7 +66,7 @@ def define_node( nodes_to_wrappers, ) - # Fake node, nn moudle seems to be inconsistant with document + # Fake node, nn module seems to be inconsistant with document bias_tensor = torch.zeros(weight_tensor.shape) bias_node = torch.fx.Node( node.graph, diff --git a/backends/qualcomm/quantizer/annotators.py b/backends/qualcomm/quantizer/annotators.py index e1792cb1830..8bf2265fb5b 100644 --- a/backends/qualcomm/quantizer/annotators.py +++ b/backends/qualcomm/quantizer/annotators.py @@ -512,6 +512,11 @@ def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.square.default]) +def annotate_square(node: Node, quantization_config: QuantizationConfig) -> None: + annotate_single_in_single_out(node, quantization_config) + + @register_annotator([torch.ops.aten.gelu.default]) def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index d66aa34e5af..3ad183c2c26 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -16,6 +16,15 @@ def forward(self, x): return torch.abs(x) +class AdaptiveAvgPool2D(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + 
adaptive_avg_pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + return adaptive_avg_pool(x) + + class Add(torch.nn.Module): def __init__(self): super().__init__() @@ -685,15 +694,24 @@ def forward(self, x): class LayerNorm(torch.nn.Module): - def __init__(self): + def __init__(self, bias=True): super().__init__() - self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6) + self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6, bias=bias) self.linear = torch.nn.Linear(768, 196) def forward(self, x): return self.linear(self.layer_norm(x)) +class LayerNormAdd(torch.nn.Module): + def __init__(self): + super().__init__() + self.layer_norm = torch.nn.LayerNorm([512], eps=1e-6, bias=False) + + def forward(self, x, y): + return self.layer_norm(x) + y + + class LeakyReLUDefault(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 30ed34032f4..4ac73ed39b1 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -97,6 +97,11 @@ def test_qnn_backend_abs(self): sample_input = (torch.randn(1, 2, 3, 4),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool2d(self): + module = AdaptiveAvgPool2D() # noqa: F405 + sample_input = (torch.randn(1, 512, 7, 7),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_arange(self): modules = [ Arange(start=1, end=11, step=1, dtype=torch.int32), # noqa: F405 @@ -432,9 +437,11 @@ def test_qnn_backend_interpolate_nearest_2d(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_layer_norm(self): - module = LayerNorm() # noqa: F405 + modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405 sample_input = (torch.randn(196, 768),) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + 
self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_leaky_relu(self): test_comb = [ @@ -915,6 +922,12 @@ def test_qnn_backend_abs(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_adaptive_avg_pool2d(self): + module = AdaptiveAvgPool2D() # noqa: F405 + sample_input = (torch.randn(1, 512, 7, 7),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_arange(self): modules = [ Arange(start=1, end=6, step=0.5, dtype=torch.float32), # noqa: F405 @@ -1280,10 +1293,12 @@ def test_qnn_backend_interpolate_nearest_2d(self): self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_layer_norm(self): - module = LayerNorm() # noqa: F405 + modules = [LayerNorm(), LayerNorm(bias=False)] # noqa: F405 sample_input = (torch.randn(196, 768),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) + for i, module in enumerate(modules): + with self.subTest(i=i): + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) def test_qnn_backend_leaky_relu(self): test_comb = [ @@ -2675,6 +2690,42 @@ def required_envs(self, conditions=None) -> bool: ] ) + def test_conv_former(self): + if not self.required_envs([self.image_dataset]): + self.skipTest("missing required envs") + + cmds = [ + "python", + f"{self.executorch_root}/examples/qualcomm/oss_scripts/conv_former.py", + "--dataset", + self.image_dataset, + "--artifact", + self.artifact_dir, + "--build_folder", + self.build_folder, + "--device", + self.device, + "--model", + self.model, + "--ip", + self.ip, + "--port", + str(self.port), + ] + if self.host: + cmds.extend(["--host", self.host]) + + p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL) + with Listener((self.ip, self.port)) as listener: + conn = listener.accept() + 
p.communicate() + msg = json.loads(conn.recv()) + if "Error" in msg: + self.fail(msg["Error"]) + else: + self.assertGreaterEqual(msg["top_1"], 60) + self.assertGreaterEqual(msg["top_5"], 80) + def test_dino_v2(self): if not self.required_envs([self.image_dataset]): self.skipTest("missing required envs") diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index a4acae9585b..e15050fe4c2 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -90,7 +90,7 @@ from executorch.exir.capture import ExecutorchBackendConfig from executorch.exir.lowered_backend_module import LoweredBackendModule from executorch.exir.program._program import _get_updated_graph_signature -from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions +from torch._decomp import core_aten_decompositions, remove_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes from torch.fx.passes.operator_support import OperatorSupportBase @@ -283,9 +283,10 @@ def set_spec(module, options): def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: - source_decompositions = torch_core_aten_decompositions() + source_decompositions = core_aten_decompositions() # The below super ops are supported by QNN - remove_decompositions = [ + skip_decompositions = [ + torch.ops.aten.adaptive_avg_pool2d.default, torch.ops.aten.pixel_shuffle.default, torch.ops.aten.pixel_unshuffle.default, torch.ops.aten.hardsigmoid.default, @@ -293,8 +294,7 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: torch.ops.aten._safe_softmax.default, ] - for key in remove_decompositions: - source_decompositions.pop(key) + remove_decompositions(source_decompositions, skip_decompositions) return source_decompositions diff --git a/examples/qualcomm/oss_scripts/conv_former.py b/examples/qualcomm/oss_scripts/conv_former.py new file mode 100644 index 00000000000..76131d659df --- /dev/null +++ 
b/examples/qualcomm/oss_scripts/conv_former.py @@ -0,0 +1,139 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import sys +from multiprocessing.connection import Client + +import numpy as np +import timm +import torch +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.utils.constants import ( + QCOM_PASS_EXPAND_BROADCAST_SHAPE, +) +from executorch.examples.qualcomm.utils import ( + build_executorch_binary, + get_imagenet_dataset, + make_output_dir, + parse_skip_delegation_node, + setup_common_args_and_variables, + SimpleADB, + topk_accuracy, +) + + +def main(args): + skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args) + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + if not args.compile_only and args.device is None: + raise RuntimeError( + "device serial is required if not compile only. " + "Please specify a device serial by -s/--device argument." 
+ ) + + data_num = 100 + if args.compile_only: + inputs = [(torch.rand(1, 3, 224, 224),)] + else: + inputs, targets, input_list = get_imagenet_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + image_shape=(256, 256), + crop_size=224, + ) + + pte_filename = "conv_former" + model = timm.create_model("convformer_s18.sail_in1k", pretrained=True) + + model = model.eval() + + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=QuantDtype.use_8a8w, + custom_pass_config={QCOM_PASS_EXPAND_BROADCAST_SHAPE}, + ) + + if args.compile_only: + sys.exit(0) + + adb = SimpleADB( + qnn_sdk=os.getenv("QNN_SDK_ROOT"), + build_path=f"{args.build_folder}", + pte_path=f"{args.artifact}/{pte_filename}.pte", + workspace=f"/data/local/tmp/executorch/{pte_filename}", + device_id=args.device, + host_id=args.host, + soc_model=args.model, + shared_buffer=args.shared_buffer, + ) + adb.push(inputs=inputs, input_list=input_list) + adb.execute() + + # collect output data + output_data_folder = f"{args.artifact}/outputs" + make_output_dir(output_data_folder) + + adb.pull(output_path=args.artifact) + + # top-k analysis + predictions = [] + for i in range(data_num): + predictions.append( + np.fromfile( + os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32 + ) + ) + + k_val = [1, 5] + topk = [topk_accuracy(predictions, targets, k).item() for k in k_val] + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({f"top_{k}": topk[i] for i, k in enumerate(k_val)})) + else: + for i, k in enumerate(k_val): + print(f"top_{k}->{topk[i]}%") + + +if __name__ == "__main__": + parser = setup_common_args_and_variables() + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. 
Default ./conv_former", + default="./conv_former", + type=str, + ) + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + args = parser.parse_args() + try: + main(args) + except Exception as e: + if args.ip and args.port != -1: + with Client((args.ip, args.port)) as conn: + conn.send(json.dumps({"Error": str(e)})) + else: + raise Exception(e)