diff --git a/examples/arm/executor_runner/arm_perf_monitor.cpp b/examples/arm/executor_runner/arm_perf_monitor.cpp index 58a47105743..35fd114f777 100644 --- a/examples/arm/executor_runner/arm_perf_monitor.cpp +++ b/examples/arm/executor_runner/arm_perf_monitor.cpp @@ -19,7 +19,7 @@ namespace { #if defined(ETHOSU55) || defined(ETHOSU65) const uint32_t ethosu_pmuCountersUsed = 4; #elif defined(ETHOSU85) -const uint32_t ethosu_pmuCountersUsed = 5; +const uint32_t ethosu_pmuCountersUsed = 7; #else #error No NPU target defined #endif @@ -65,11 +65,14 @@ void ethosu_inference_begin(struct ethosu_driver* drv, void*) { ETHOSU_PMU_Set_EVTYPER(drv, 2, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED); ETHOSU_PMU_Set_EVTYPER(drv, 3, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN); ETHOSU_PMU_Set_EVTYPER(drv, 4, ETHOSU_PMU_NPU_IDLE); - // Enable the 5 counters + ETHOSU_PMU_Set_EVTYPER(drv, 5, ETHOSU_PMU_MAC_ACTIVE); + ETHOSU_PMU_Set_EVTYPER(drv, 6, ETHOSU_PMU_WD_ACTIVE); + // Enable the 7 counters ETHOSU_PMU_CNTR_Enable( drv, ETHOSU_PMU_CNT1_Msk | ETHOSU_PMU_CNT2_Msk | ETHOSU_PMU_CNT3_Msk | - ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk); + ETHOSU_PMU_CNT4_Msk | ETHOSU_PMU_CNT5_Msk | ETHOSU_PMU_CNT6_Msk | + ETHOSU_PMU_CNT7_Msk); #else #error No NPU target defined #endif @@ -214,7 +217,7 @@ void StopMeasurements(int num_inferences) { #elif defined(ETHOSU85) ET_LOG( Info, - "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE]"); + "Ethos-U PMU Events:[ETHOSU_PMU_SRAM_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_SRAM_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_EXT_RD_DATA_BEAT_RECEIVED, ETHOSU_PMU_EXT_WR_DATA_BEAT_WRITTEN, ETHOSU_PMU_NPU_IDLE, ETHOSU_PMU_MAC_ACTIVE, ETHOSU_PMU_WD_ACTIVE]"); #else #error No NPU target defined #endif diff --git a/examples/arm/pruning_minimal_example.ipynb b/examples/arm/pruning_minimal_example.ipynb new file mode 100644 index 00000000000..78bb3f06b5b --- /dev/null +++ b/examples/arm/pruning_minimal_example.ipynb @@ -0,0 +1,566 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c0156802", + "metadata": {}, + "source": [ + "# Copyright 2025 Arm Limited and/or its affiliates.\n", + "#\n", + "# This source code is licensed under the BSD-style license found in the\n", + "# LICENSE file in the root directory of this source tree." + ] + }, + { + "cell_type": "markdown", + "id": "26b849fd", + "metadata": {}, + "source": [ + "# Introduction\n", + "Model conditioning techniques like pruning modify the weights of a Machine Learning model and in some cases allow significant speed-up of the inference execution, reduction of the memory footprint and reduction in the overall power consumption of the system. Assuming you can optimise your workload without loss in accuracy and you target an Arm® Ethos™ NPU or a GPU with a Neural Engine, you should consider pruning the neural network before compiling it in the to_edge_transform_and_lower stage." + ] + }, + { + "cell_type": "markdown", + "id": "9a7d6d97", + "metadata": {}, + "source": [ + "# Why apply model conditioning?\n", + "The Ethos-U hardware has a dedicated weight decoder to process the model weights. At the same time, the compiler arranges the weights into blocks and the blocks are then fed to the hardware weight decoder. As part of the block arrangement process, the compiler compresses sequences of zero weights and clusters of weights. 
To avoid any doubt, the compression by the compiler is lossless - for the same input tensor, irrespective of whether compression was applied or not, the output tensor from execution on the NPU will be the same. If the model you provide in the to_edge_transform_and_lower stage is optimised to have sequences of zero weights and/or clusters of the same weights, the compiler will be able to compress these weights very efficiently. Good compression results in a lower number of memory accesses by the NPU at runtime, which means the MAC engines spend less time waiting on memory accesses, resulting in better overall performance. In other words, if you have a memory-bound model, you should consider pruning and clustering your neural network before lowering it in the to_edge_transform_and_lower stage.\n", + "\n", + "The Ethos-U85 hardware also has hardware support for 2:4 sparse weights - if you have 2:4 sparse weights, the MAC array will skip multiplications where the result will be 0. The 2:4 sparsity allows power savings for all configurations and provides a speed-up on compute-bound neural networks (a minimal illustration of the 2:4 pattern is included just before the model definition below).\n", + "\n", + "Before we begin, make sure you are running the Jupyter notebook from the correct Python virtual environment." + ] + }, + { + "cell_type": "markdown", + "id": "d6532247", + "metadata": {}, + "source": [ + "# Prerequisites\n", + "Let's import the Python packages you will need to run through the Jupyter notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a191d7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torchvision import datasets, transforms\n", + "from torch import nn\n", + "import torch.nn.utils.prune as prune\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader, Subset\n", + "import random\n", + "\n", + "from executorch.backends.arm.ethosu import EthosUPartitioner\n", + "from executorch.exir import (\n", + " EdgeCompileConfig,\n", + " ExecutorchBackendConfig,\n", + " to_edge_transform_and_lower,\n", + ")\n", + "from executorch.backends.arm.ethosu import EthosUCompileSpec\n", + "from executorch.backends.arm.quantizer import (\n", + " EthosUQuantizer,\n", + " get_symmetric_quantization_config,\n", + ")\n", + "from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e\n", + "from executorch.extension.export_util.utils import save_pte_program" + ] + }, + { + "cell_type": "markdown", + "id": "6af794bc", + "metadata": {}, + "source": [ + "# Model conditioning with PyTorch and deployment with ExecuTorch \n", + "We'll define a simple model with 3 back-to-back Linear layers. We will execute the model on the Ethos-U85 NPU, then we will prune the model and execute the pruned variant on the Ethos-U85 and compare the performance."
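+ , + "\n", + "As an aside, the 2:4 sparsity mentioned in the previous section is a structured pattern: in every contiguous group of 4 weights, at most 2 are non-zero. Below is a minimal, illustrative sketch of that pattern - the helper name is hypothetical, it is not used in the rest of this notebook, and it is not how you would train a 2:4 sparse network in practice:\n", + "```python\n", + "import torch\n", + "\n", + "def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:\n", + "    # Illustrative only: keep the 2 largest-magnitude values in every group of 4\n", + "    w = weight.reshape(-1, 4)\n", + "    idx = w.abs().topk(2, dim=1, largest=False).indices  # 2 smallest entries per group\n", + "    mask = torch.ones_like(w)\n", + "    mask.scatter_(1, idx, 0.0)\n", + "    return (w * mask).reshape(weight.shape)\n", + "\n", + "w_sparse = apply_2_4_sparsity(torch.randn(8, 16))\n", + "assert (w_sparse.reshape(-1, 4) != 0).sum(dim=1).max() <= 2\n", + "```"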
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e37c2ce", + "metadata": {}, + "outputs": [], + "source": [ + "LR = 1e-3\n", + "NUM_EPOCHS = 1\n", + "BATCH_SIZE = 128\n", + "\n", + "# Data\n", + "transform = transforms.Compose([transforms.ToTensor()])\n", + "train_ds = datasets.MNIST(\"./data\", train=True, download=True, transform=transform)\n", + "test_ds = datasets.MNIST(\"./data\", train=False, transform=transform)\n", + "\n", + "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)\n", + "test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)\n", + "\n", + "class Simple_NN(nn.Module): \n", + " def __init__(self):\n", + " super().__init__()\n", + " self.flatten = nn.Flatten()\n", + " self.fc1 = nn.Linear(28 * 28, 512)\n", + " self.fc2 = nn.Linear(512, 256)\n", + " self.fc3 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x):\n", + " x = self.flatten(x)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.relu(self.fc2(x))\n", + " x = self.fc3(x)\n", + " return x\n", + " \n", + " def prunable_parameters(self):\n", + " return (\n", + " (self.fc1, \"weight\"),\n", + " (self.fc2, \"weight\"),\n", + " (self.fc3, \"weight\"),\n", + " )\n", + "\n", + " def prune(self, pruning_method: prune.BasePruningMethod, amount: float = 0.1):\n", + " # reference https://pytorch.org/tutorials/intermediate/pruning_tutorial.html\n", + "\n", + " # produces a mask that is multiplied with the parameter\n", + " prune.global_unstructured(\n", + " self.prunable_parameters(),\n", + " pruning_method=pruning_method,\n", + " amount=amount,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "6db1e58d", + "metadata": {}, + "source": [ + "We define a simple model with 3 back-to-back linear layers. Linear is a highly memory-bound operation because every weight is read only once from the external memory. It is impossible to buffer the weights on-chip (you usually have far more weights in the external memory than space in the SRAM) and reuse them for the computation. In comparison, a convolution usually has small filter sizes (e.g. a 3x3 filter), which means you can buffer all the convolution weights in memory and reuse them for the computation. If your model, or a module within the model, is composed entirely of Linear layers, the workload will be memory bound and pruning is likely to provide a good speed-up.\n", + "\n", + "Next, let's define a simple function to train the network and a function to evaluate the accuracy of the model."
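+ , + "\n", + "Before doing that, a rough back-of-the-envelope count makes the \"each weight is read only once\" point concrete. The numbers below are illustrative and assume int8 weights for the fc1 layer defined above:\n", + "```python\n", + "# fc1: Linear(28 * 28, 512) -> 784 * 512 weights\n", + "num_weights = 28 * 28 * 512\n", + "macs_per_inference = num_weights    # each weight contributes exactly one MAC per inference\n", + "bytes_per_inference = num_weights   # int8: one byte fetched per weight, with no reuse\n", + "print(f\"fc1 arithmetic intensity ~ {macs_per_inference / bytes_per_inference:.1f} MAC/byte\")\n", + "# A 3x3 convolution reuses each weight at every output position, so its MAC/byte is much higher.\n", + "```"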
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "477312ae", + "metadata": {}, + "outputs": [], + "source": [ + "# Training loop\n", + "def train(model):\n", + " # The model is simple enough that we can train it on CPU\n", + " device = \"cpu\"\n", + " for epoch in range(NUM_EPOCHS):\n", + " # ---- Training ----\n", + " model.train()\n", + " opt = torch.optim.Adam(model.parameters(), lr=LR)\n", + " criterion = torch.nn.CrossEntropyLoss()\n", + " for step, (inp, out_real) in enumerate(train_loader):\n", + " inp, out_real = inp.to(device), out_real.to(device)\n", + " opt.zero_grad()\n", + " out_pred = model(inp)\n", + " loss = criterion(out_pred, out_real)\n", + " #print(f\"Loss: {loss.item():.4f}\")\n", + " loss.backward()\n", + " opt.step()\n", + "\n", + "def evaluate(model):\n", + " # ---- Evaluation ----\n", + " correct, total = 0, 0\n", + " with torch.no_grad():\n", + " for inp, out_real in test_loader:\n", + " out_pred = model(inp)\n", + " preds = out_pred.argmax(1)\n", + " correct += (preds == out_real).sum().item()\n", + " total += out_real.size(0)\n", + "\n", + " acc = 100 * correct / total\n", + " print(f\"Top 1 accuracy = {acc:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "id": "a4750eaf", + "metadata": {}, + "source": [ + "Let's instantiate the model and train it. In order to get reproducible results, we will fix the seed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc68a7d9", + "metadata": {}, + "outputs": [], + "source": [ + "SEED = 123\n", + "torch.manual_seed(SEED)\n", + "model = Simple_NN()\n", + "train(model)\n", + "print(\"Evaluate FP32 model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "9837d9ba", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the FP32 model.\n", + "\n", + "Next, we would like to apply post-training quantization with ExecuTorch and evaluate the accuracy of the quantized model. It is important to calibrate the quantized model on a few real samples from the MNIST dataset to get good quantization parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "855c542f", + "metadata": {}, + "outputs": [], + "source": [ + "# MNIST images are 28x28 in greyscale, hence the shape is 1x1x28x28\n", + "example_inputs = (torch.randn(1,1,28,28),)\n", + "exported_program = torch.export.export(model, example_inputs)\n", + "graph_module = exported_program.module(check_guards=False)\n", + "\n", + "# Create a compilation spec describing the target for configuring the quantizer\n", + "compile_spec = EthosUCompileSpec(\n", + " target=\"ethos-u85-128\",\n", + " system_config=\"Ethos_U85_SYS_Flash_High\",\n", + " memory_mode=\"Shared_Sram\",\n", + " extra_flags=[\"--output-format=raw\", \"--debug-force-regor --verbose-weights\"]\n", + " )\n", + "\n", + "# Create and configure quantizer to use a symmetric quantization config globally on all nodes\n", + "quantizer = EthosUQuantizer(compile_spec)\n", + "operator_config = get_symmetric_quantization_config()\n", + "quantizer.set_global(operator_config)\n", + "\n", + "# Post training quantization, need a few example images to obtain good quantization parameters\n", + "subset_indices = random.sample(range(len(train_ds)), 50)\n", + "calibration_set = Subset(train_ds, subset_indices)\n", + "calibration_loader = DataLoader(calibration_set, shuffle=False)\n", + "\n", + "quantized_graph_module = prepare_pt2e(graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_graph_module = convert_pt2e(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "996faefd", + "metadata": {}, + "source": [ + "Next, let us evaluate the accuracy of the quantized model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63da2b30", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Accuracy of the quantized model\")\n", + "evaluate(quantized_graph_module)" + ] + }, + { + "cell_type": "markdown", + "id": "2ff3462c", + "metadata": {}, + "source": [ + "We maintain the 96% top1 accuracy for the quantized model. Next, let's compile the model for the Ethos-U backend. We will define a function `generate_pte` that calls `to_edge_transform_and_lower` and saves the pte file on device." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aa8259f4", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_pte(quantized_exported_program,compile_spec,name):\n", + " # Create partitioner from compile spec\n", + " partitioner = EthosUPartitioner(compile_spec)\n", + "\n", + " # Lower the exported program to the Ethos-U backend\n", + " edge_program_manager = to_edge_transform_and_lower(\n", + " quantized_exported_program,\n", + " partitioner=[partitioner],\n", + " compile_config=EdgeCompileConfig(\n", + " _check_ir_validity=False,\n", + " ),\n", + " )\n", + "\n", + " # Convert edge program to executorch\n", + " executorch_program_manager = edge_program_manager.to_executorch(\n", + " config=ExecutorchBackendConfig(extract_delegate_segments=False)\n", + " )\n", + "\n", + " # Save pte file\n", + " save_pte_program(executorch_program_manager, f\"{name}.pte\")\n", + "\n", + "# Create a new exported program using the quantized_graph_module\n", + "quantized_exported_program = torch.export.export(quantized_graph_module, example_inputs)\n", + "generate_pte(quantized_exported_program,compile_spec,\"original_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b6cae04", + "metadata": {}, + "source": [ + "Note that as part of the compilation process in `to_edge_transform_and_lower`, we get Weight Compression information:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 507.44 KiB\n", + "```\n", + "In other words, the original weights are 522 KiB and after compilation and encoding by the compiler, we get 507 KiB of weights that will be read by the NPU at runtime. Remember, this is for the case where we have not applied pruning or clustering. This step also generates the original_model.pte file that we will deploy on device later on. \n", + "\n", + "Next, let's move on to prune the model and evaluate its accuracy. We have a lot of weights in the original network, so we will apply a 95% pruning rate." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "493eed60", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Prune the model\")\n", + "model.prune(pruning_method=prune.L1Unstructured, amount=0.95)\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "82460ba6", + "metadata": {}, + "source": [ + "We obtain 37% top1 accuracy for the pruned model. That can seem surprising at first sight, but remember that the pruning has just set 95% of the weights to 0 (with `prune.L1Unstructured`, the 95% of weights with the smallest magnitudes). It is normal to lose accuracy when applying pruning. We need to retrain the model in order to recover the accuracy we've lost from the pruning. We can do that easily by calling the train function one more time. Once we are done with the retraining, it is important to call `prune.remove` to make the pruning permanent, which removes the masks and re-parametrization added by the pruning utility." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c816ad25", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"Train the pruned model to recover the lost information\")\n", + "train(model)\n", + "# Remove the pruning re-parametrization once we've retrained the model and recovered the lost accuracy\n", + "for a,b in model.prunable_parameters():\n", + " prune.remove(a, b)\n", + "\n", + "print(\"Evaluate pruned model accuracy\")\n", + "evaluate(model)" + ] + }, + { + "cell_type": "markdown", + "id": "fbb70d47", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy for the pruned workload, so we have recovered the accuracy we lost with the pruning. 
Let's quantize the pruned model, evaluate the accuracy of the int8 network and obtain a pte file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3cdb0f59", + "metadata": {}, + "outputs": [], + "source": [ + "pruned_exported_program = torch.export.export(model, example_inputs)\n", + "pruned_graph_module = pruned_exported_program.module(check_guards=False)\n", + "quantized_pruned_graph_module = prepare_pt2e(pruned_graph_module, quantizer)\n", + "for batch_images,label in calibration_loader:\n", + " quantized_pruned_graph_module(*batch_images) # Calibrate the graph module with the example input\n", + "quantized_pruned_graph_module = convert_pt2e(quantized_pruned_graph_module)\n", + "print(\"Accuracy of the pruned quantized model\")\n", + "evaluate(quantized_pruned_graph_module)\n", + "\n", + "quantized_ep_pruned = torch.export.export(quantized_pruned_graph_module, example_inputs)\n", + "generate_pte(quantized_ep_pruned,compile_spec,\"pruned_model\")" + ] + }, + { + "cell_type": "markdown", + "id": "4263714e", + "metadata": {}, + "source": [ + "We obtain 96% top1 accuracy of the quantized pruned model. What is interesting is that this time, the NPU encoded weights size shrank considerably:\n", + "```\n", + "Original Weights Size 522.50 KiB\n", + "NPU Encoded Weights Size 46.12 KiB\n", + "```\n", + "In other words, we are now solving the MNIST classification problem with just 46KB of encoded weights. This is a significant reduction from the 507KB we had in the original model.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "562fdb16", + "metadata": {}, + "source": [ + "# NPU performance\n", + "In the sections above, we generated two pte files - one pte for the original model and another pte for the pruned model. These models perform very similarly in terms of accuracy. 
Let's run both of these models on the NPU and analyse the performance at runtime.\n", + "\n", + "# Performance of the original model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bdd91dc", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "# Ensure the arm-none-eabi-gcc toolchain and FVPs are available on $PATH\n", + "source ethos-u-scratch/setup_path.sh\n", + "\n", + "# Build executorch libraries cross-compiled for arm baremetal to executorch/cmake-out-arm\n", + "cmake --preset arm-baremetal \\\n", + "-DCMAKE_BUILD_TYPE=Release \\\n", + "-B../../cmake-out-arm ../..\n", + "cmake --build ../../cmake-out-arm --target install -j$(nproc) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "756ab779", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Build example executor runner application to examples/arm/ethos_u_original_model\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=original_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_original_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_original_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a525a09", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Run the original model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_original_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "23ebdc46", + "metadata": {}, + "source": [ + "We obtain a total of 99k NPU Active cycles. The MAC engines of the NPU are active for 8k cycles and the Weight Decoder is active for 74k NPU cycles. It's worth noting that the data flow in the Ethos-U is pipelined. In other words, the MAC array and the Weight Decoder are working at the same time. Having a total of 99k NPU cycles, with only 8k active MAC cycles and 74k Weight Decoder active cycles, means that the NPU is spending most of the time decoding weights and the MAC array is underutilized. Pruning is designed to alleviate that bottleneck. 
Let's analyse the performance of the pruned workload.\n", + "\n", + "# Performance of the pruned model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7c09926", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "\n", + "# Build example executor runner application to examples/arm/ethos_u_pruned_model\n", + "cmake -DCMAKE_TOOLCHAIN_FILE=$(pwd)/ethos-u-setup/arm-none-eabi-gcc.cmake \\\n", + " -DCMAKE_BUILD_TYPE=Release \\\n", + " -DET_PTE_FILE_PATH=pruned_model.pte \\\n", + " -DTARGET_CPU=cortex-m55 \\\n", + " -DETHOSU_TARGET_NPU_CONFIG=ethos-u85-128 \\\n", + " -DMEMORY_MODE=Shared_Sram \\\n", + " -DSYSTEM_CONFIG=Ethos_U85_SYS_DRAM_Mid \\\n", + " -Bethos_u_pruned_model \\\n", + " executor_runner\n", + "cmake --build ethos_u_pruned_model -j$(nproc) -- arm_executor_runner" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "891947f7", + "metadata": {}, + "outputs": [], + "source": [ + "%%bash \n", + "source ethos-u-scratch/setup_path.sh\n", + "# Run the pruned model\n", + "../../backends/arm/scripts/run_fvp.sh --elf=ethos_u_pruned_model/arm_executor_runner --target=ethos-u85-128" + ] + }, + { + "cell_type": "markdown", + "id": "e55ae929", + "metadata": {}, + "source": [ + "On the pruned model, the inference completes in 22k NPU cycles. The MAC engines are still active for 8k cycles, but this time the number of cycles during which the weight decoder is active has dropped to 17k cycles. \n", + "It's also worth noting that the size of the pte file has been reduced significantly - from 518 KB for the original model to 57 KB for the pruned workload. " + ] + }, + { + "cell_type": "markdown", + "id": "d934fe41", + "metadata": {}, + "source": [ + "# Conclusion\n", + "We defined a simple model to classify the MNIST dataset. The model uses Linear layers and is heavily memory-bound on the external memory. We pruned the model and obtained similar int8 accuracy between the original workload and the pruned counterpart. Let us put the results from the runtime in a table and draw a few conclusions: \n", + "\n", + "| Model |NPU_ACTIVE cycles | NPU Encoded Weight Size | Weight Decoder Active Cycles | External memory beats read | Size of the pte file |\n", + "| ----------------------------------------|----------------- | ------------------------- | -----------------------------|---------------------------------|-----------------------|\n", + "| Original model | 97k | 506 KB | 74k | 32k | 517 KB |\n", + "| Pruned model | 22k | 46 KB | 8k | 3k | 57 KB |\n", + "\n", + "For the pruned network, we obtain a significant uplift - over 4x improvement in inference speed and a drastic reduction in the number of cycles during which the Weight Decoder is active. The NPU will consume less power, and the size of the pruned model that we save on-device is significantly smaller compared to the original network." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv_py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}