diff --git a/backends/mediatek/CMakeLists.txt b/backends/mediatek/CMakeLists.txt index 4b233d94f04..744b1193d5a 100644 --- a/backends/mediatek/CMakeLists.txt +++ b/backends/mediatek/CMakeLists.txt @@ -25,9 +25,13 @@ include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/include) # targets add_library(neuron_backend SHARED) -target_link_libraries( - neuron_backend PRIVATE executorch_no_prim_ops portable_ops_lib android log - ${NEURON_BUFFER_ALLOCATOR_LIB} +target_link_libraries(neuron_backend + PRIVATE + executorch_no_prim_ops + portable_ops_lib + android + log + ${NEURON_BUFFER_ALLOCATOR_LIB} ) target_sources( neuron_backend diff --git a/examples/mediatek/CMakeLists.txt b/examples/mediatek/CMakeLists.txt index 2abee59759f..1d411f07ca7 100644 --- a/examples/mediatek/CMakeLists.txt +++ b/examples/mediatek/CMakeLists.txt @@ -75,6 +75,44 @@ if(${ANDROID}) ) target_compile_options(mtk_executor_runner PUBLIC ${_common_compile_options}) + set(_mtk_oss_executor_runner__srcs ${_executor_runner__srcs}) + list( + TRANSFORM + _mtk_oss_executor_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" + ) + list( + FILTER + _mtk_oss_executor_runner__srcs + EXCLUDE REGEX + ".*executor_runner.cpp$" + ) + list( + PREPEND + _mtk_oss_executor_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/mtk_oss_executor_runner.cpp + ) + + add_executable(mtk_oss_executor_runner ${_mtk_oss_executor_runner__srcs}) + + target_include_directories(mtk_oss_executor_runner + PUBLIC + ${_common_include_directories} + ${EXECUTORCH_ROOT}/cmake-android-out/third-party/gflags/include + ) + + target_link_libraries(mtk_oss_executor_runner + ${_executor_runner_libs} + executorch + neuron_backend + gflags + ) + target_compile_options(mtk_oss_executor_runner + PUBLIC + ${_common_compile_options} + ) + set(_mtk_llama_executor_runner__srcs ${_mtk_executor_runner__srcs}) list(FILTER _mtk_llama_executor_runner__srcs EXCLUDE REGEX ".*executor_runner.cpp$" diff --git a/examples/mediatek/README.md b/examples/mediatek/README.md index faca42fb50c..9727f2587fd 100644 --- a/examples/mediatek/README.md +++ b/examples/mediatek/README.md @@ -9,6 +9,8 @@ examples/mediatek ├── preformatter_templates # Model specific prompt preformatter templates ├── prompts # Calibration Prompts ├── tokenizers_ # Model tokenizer scripts + ├── oss_utils # Utils for oss models +├── eval_utils # Utils for eval oss models ├── model_export_scripts # Model specifc export scripts ├── models # Model definitions ├── llm_models # LLM model definitions @@ -44,6 +46,7 @@ pip3 install mtk_converter-8.8.0.dev20240723+public.d1467db9-cp310-cp310-manylin ``` ## AoT Flow +### llama ##### Note: Verify that localhost connection is available before running AoT Flow 1. Exporting Models to `.pte` - In the `examples/mediatek directory`, run: @@ -72,6 +75,14 @@ source shell_scripts/export_llama.sh +``` +- Argument Options: + - `model_name`: deeplabv3/edsr/inceptionv3/inceptionv4/mobilenetv2/mobilenetv3/resnet18/resnet50 + # Runtime ## Supported Chips @@ -100,6 +111,13 @@ adb push .pte Make sure to replace `` with the actual name of your model file. And, replace the `` with the desired detination on the device. +##### Note: For oss models, please push additional files to your Android device +```bash +adb push mtk_oss_executor_runner +adb push input_list.txt +for i in input*bin; do adb push "$i" ; done; +``` + ### Executing the Model Execute the model on your Android device by running: @@ -111,3 +129,21 @@ adb shell "/data/local/tmp/mtk_executor_runner --model_path /data/local/tmp/` with the name of your model file and `` with the desired number of iterations to run the model. ##### Note: For llama models, please use `mtk_llama_executor_runner`. Refer to `examples/mediatek/executor_runner/run_llama3_sample.sh` for reference. +##### Note: For oss models, please use `mtk_oss_executor_runner`. +```bash +adb shell "/data/local/tmp/mtk_oss_executor_runner --model_path /data/local/tmp/.pte --input_list /data/local/tmp/input_list.txt --output_folder /data/local/tmp/output_" +adb pull "/data/local/tmp/output_ ./" +``` + +### Check oss result on PC +```bash +python3 eval_utils/eval_oss_result.py --eval_type --target_f --output_f +``` +For example: +``` +python3 eval_utils/eval_oss_result.py --eval_type piq --target_f edsr --output_f output_edsr +``` +- Argument Options: + - `eval_type`: topk/piq/segmentation + - `target_f`: folder contain golden data files. file name is `golden__0.bin` + - `output_f`: folder contain model output data files. file name is `output__0.bin` diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py new file mode 100755 index 00000000000..f447b2ac68f --- /dev/null +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -0,0 +1,73 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Optional + +import torch +from executorch import exir +from executorch.backends.mediatek import ( + NeuropilotPartitioner, + NeuropilotQuantizer, + Precision, +) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + + +def build_executorch_binary( + model, + inputs, + file_name, + dataset, + quant_dtype: Optional[Precision] = None, +): + if quant_dtype is not None: + quantizer = NeuropilotQuantizer() + quantizer.setup_precision(quant_dtype) + if quant_dtype not in Precision: + raise AssertionError(f"No support for Precision {quant_dtype}.") + + captured_model = torch._export.capture_pre_autograd_graph(model, inputs) + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing the model...") + # calibration + for data in dataset: + annotated_model(*data) + quantized_model = convert_pt2e(annotated_model, fold_quantize=False) + aten_dialect = torch.export.export(quantized_model, inputs) + else: + aten_dialect = torch.export.export(model, inputs) + + from executorch.exir.program._program import to_edge_transform_and_lower + + edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False) + # skipped op names are used for deeplabV3 model + neuro_partitioner = NeuropilotPartitioner( + [], + op_names_to_skip={ + "aten_convolution_default_106", + "aten_convolution_default_107", + }, + ) + edge_prog = to_edge_transform_and_lower( + aten_dialect, + compile_config=edge_compile_config, + partitioner=[neuro_partitioner], + ) + + exec_prog = edge_prog.to_executorch( + config=exir.ExecutorchBackendConfig(extract_constant_segment=False) + ) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog.buffer) + + +def make_output_dir(path: str): + if os.path.exists(path): + for f in os.listdir(path): + os.remove(os.path.join(path, f)) + os.removedirs(path) + os.makedirs(path) diff --git a/examples/mediatek/eval_utils/eval_oss_result.py b/examples/mediatek/eval_utils/eval_oss_result.py new file mode 100755 index 00000000000..3e599330b66 --- /dev/null +++ b/examples/mediatek/eval_utils/eval_oss_result.py @@ -0,0 +1,198 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import json +import os + +import numpy as np +import piq +import torch + + +def check_data(target_f, predict_f): + target_files = os.listdir(target_f) + predict_files = os.listdir(predict_f) + if len(target_files) != len(predict_files): + raise RuntimeError( + "Data number in target folder and prediction folder must be same" + ) + + predict_set = set(predict_files) + for f in target_files: + # target file naming rule is golden_sampleId_outId.bin + # predict file naming rule is output_sampleId_outId.bin + pred_name = f.replace("golden", "output") + try: + predict_set.remove(pred_name) + except KeyError: + raise RuntimeError(f"Cannot find {pred_name} in {predict_f}") + + if predict_set: + target_name = next(predict_set).replace("output", "golden") + raise RuntimeError(f"Cannot find {target_name} in {target_f}") + + +def eval_topk(target_f, predict_f): + def solve(prob, target, k): + _, indices = torch.topk(prob, k=k, sorted=True) + golden = torch.reshape(target, [-1, 1]) + correct = golden == indices + if torch.any(correct): + return 1 + else: + return 0 + + target_files = os.listdir(target_f) + + cnt10 = 0 + cnt50 = 0 + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + + pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.int64)[0] + cnt10 += solve(torch.from_numpy(pred_npy), torch.from_numpy(target_npy), 10) + cnt50 += solve(torch.from_numpy(pred_npy), torch.from_numpy(target_npy), 50) + + print("Top10 acc:", cnt10 * 100.0 / len(target_files)) + print("Top50 acc:", cnt50 * 100.0 / len(target_files)) + + +def eval_piq(target_f, predict_f): + target_files = os.listdir(target_f) + + psnr_list = [] + ssim_list = [] + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + hr = np.fromfile(os.path.join(target_f, target_name), dtype=np.float32) + hr = hr.reshape((1, 448, 448, 3)) + hr = np.moveaxis(hr, 3, 1) + hr = torch.from_numpy(hr) + + sr = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + sr = sr.reshape((1, 448, 448, 3)) + sr = np.moveaxis(sr, 3, 1) + sr = torch.from_numpy(sr).clamp(0, 1) + + psnr_list.append(piq.psnr(hr, sr)) + ssim_list.append(piq.ssim(hr, sr)) + + avg_psnr = sum(psnr_list).item() / len(psnr_list) + avg_ssim = sum(ssim_list).item() / len(ssim_list) + + print(f"Avg of PSNR is: {avg_psnr}") + print(f"Avg of SSIM is: {avg_ssim}") + + +def eval_segmentation(target_f, predict_f): + classes = [ + "Backround", + "Aeroplane", + "Bicycle", + "Bird", + "Boat", + "Bottle", + "Bus", + "Car", + "Cat", + "Chair", + "Cow", + "DiningTable", + "Dog", + "Horse", + "MotorBike", + "Person", + "PottedPlant", + "Sheep", + "Sofa", + "Train", + "TvMonitor", + ] + + target_files = os.listdir(target_f) + + def make_confusion(goldens, predictions, num_classes): + def histogram(golden, predict): + mask = golden < num_classes + hist = np.bincount( + num_classes * golden[mask].astype(int) + predict[mask], + minlength=num_classes**2, + ).reshape(num_classes, num_classes) + return hist + + confusion = np.zeros((num_classes, num_classes)) + for g, p in zip(goldens, predictions): + confusion += histogram(g.flatten(), p.flatten()) + + return confusion + + pred_list = [] + target_list = [] + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.uint8) + target_npy = target_npy.reshape((224, 224)) + target_list.append(target_npy) + + pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + pred_npy = pred_npy.reshape((224, 224, len(classes))) + pred_npy = pred_npy.argmax(2).astype(np.uint8) + pred_list.append(pred_npy) + + eps = 1e-6 + confusion = make_confusion(target_list, pred_list, len(classes)) + + pa = np.diag(confusion).sum() / (confusion.sum() + eps) + mpa = np.mean(np.diag(confusion) / (confusion.sum(axis=1) + eps)) + iou = np.diag(confusion) / ( + confusion.sum(axis=1) + confusion.sum(axis=0) - np.diag(confusion) + eps + ) + miou = np.mean(iou) + cls_iou = dict(zip(classes, iou)) + + print(f"PA : {pa}") + print(f"MPA : {mpa}") + print(f"MIoU : {miou}") + print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--target_f", + help="folder of target data", + type=str, + required=True, + ) + + parser.add_argument( + "--out_f", + help="folder of model prediction data", + type=str, + required=True, + ) + + parser.add_argument( + "--eval_type", + help="Choose eval type from: topk, piq, segmentation", + type=str, + choices=["topk", "piq", "segmentation"], + required=True, + ) + + args = parser.parse_args() + + check_data(args.target_f, args.out_f) + + if args.eval_type == "topk": + eval_topk(args.target_f, args.out_f) + elif args.eval_type == "piq": + eval_piq(args.target_f, args.out_f) + elif args.eval_type == "segmentation": + eval_segmentation(args.target_f, args.out_f) diff --git a/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp b/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp new file mode 100755 index 00000000000..3a1ad1d863b --- /dev/null +++ b/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp @@ -0,0 +1,302 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) 2024 MediaTek Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run ExecuTorch model files that only use operators that + * are covered by the portable kernels, with possible delegate to the + * test_backend_compiler_lib. + * + * It sets all input tensor data to ones, and assumes that the outputs are + * all fp32 tensors. + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +static uint8_t method_allocator_pool[8 * 1024U * 1024U]; // 8 MB + +// Model Path +DEFINE_string( + model_path, + "model.pte", + "Model serialized in flatbuffer format. Default to 'model.pte'"); +DEFINE_string( + input_list, + "input_list.txt", + "Model input list. Default to 'input_list.txt'"); +DEFINE_string( + output_folder, + "outputs", + "Model output folder. Default to 'outputs'"); + +using namespace torch::executor; +using torch::executor::MemoryAllocator; +using torch::executor::util::BufferCleanup; +using torch::executor::util::FileDataLoader; +using namespace std::filesystem; + +int main(int argc, char** argv) { + runtime_init(); + + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (argc != 1) { + std::string msg = "Extra commandline args:"; + for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) { + msg += std::string(" ") + argv[i]; + } + ET_LOG(Error, "%s", msg.c_str()); + return 1; + } + + // Create output folder + create_directories(FLAGS_output_folder); + + // Create a loader to get the data of the program file. There are other + // DataLoaders that use mmap() or point to data that's already in memory, and + // users can create their own DataLoaders to load from arbitrary sources. + const char* model_path = FLAGS_model_path.c_str(); + Result loader = FileDataLoader::from(model_path); + ET_CHECK_MSG( + loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + (uint32_t)loader.error()); + + // Parse the program file. This is immutable, and can also be reused between + // multiple execution invocations across multiple threads. + Result program = Program::load(&loader.get()); + if (!program.ok()) { + ET_LOG(Error, "Failed to parse model file %s", model_path); + return 1; + } + ET_LOG(Info, "Model file %s is loaded.", model_path); + + // Use the first method in the program. + const char* method_name = nullptr; + { + const auto method_name_result = program->get_method_name(0); + ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); + method_name = *method_name_result; + } + ET_LOG(Info, "Using method %s", method_name); + + // MethodMeta describes the memory requirements of the method. + Result method_meta_result = program->method_meta(method_name); + ET_CHECK_MSG( + method_meta_result.ok(), + "Failed to get method_meta for %s: 0x%" PRIx32, + method_name, + (uint32_t)method_meta_result.error()); + + // + // The runtime does not use malloc/new; it allocates all memory using the + // MemoryManger provided by the client. Clients are responsible for allocating + // the memory ahead of time, or providing MemoryAllocator subclasses that can + // do it dynamically. + // + + // The method allocator is used to allocate all dynamic C++ metadata/objects + // used to represent the loaded method. This allocator is only used during + // loading a method of the program, which will return an error if there was + // not enough memory. + // + // The amount of memory required depends on the loaded method and the runtime + // code itself. The amount of memory here is usually determined by running the + // method and seeing how much memory is actually used, though it's possible to + // subclass MemoryAllocator so that it calls malloc() under the hood (see + // MallocMemoryAllocator). + // + // In this example we use a statically allocated memory pool. + MemoryAllocator method_allocator{ + MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)}; + + // The memory-planned buffers will back the mutable tensors used by the + // method. The sizes of these buffers were determined ahead of time during the + // memory-planning pasees. + // + // Each buffer typically corresponds to a different hardware memory bank. Most + // mobile environments will only have a single buffer. Some embedded + // environments may have more than one for, e.g., slow/large DRAM and + // fast/small SRAM, or for memory associated with particular cores. + std::vector> planned_buffers; // Owns the memory + std::vector> planned_spans; // Passed to the allocator + size_t num_memory_planned_buffers = + method_meta_result->num_memory_planned_buffers(); + for (size_t id = 0; id < num_memory_planned_buffers; ++id) { + // .get() will always succeed because id < num_memory_planned_buffers. + size_t buffer_size = static_cast( + method_meta_result->memory_planned_buffer_size(id).get()); + ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size); + planned_buffers.push_back(std::make_unique(buffer_size)); + planned_spans.push_back({planned_buffers.back().get(), buffer_size}); + } + HierarchicalAllocator planned_memory( + {planned_spans.data(), planned_spans.size()}); + + // Assemble all of the allocators into the MemoryManager that the Executor + // will use. + MemoryManager memory_manager(&method_allocator, &planned_memory); + + // + // Load the method from the program, using the provided allocators. Running + // the method can mutate the memory-planned buffers, so the method should only + // be used by a single thread at at time, but it can be reused. + // + Result method = program->load_method(method_name, &memory_manager); + ET_CHECK_MSG( + method.ok(), + "Loading of method %s failed with status 0x%" PRIx32, + method_name, + (uint32_t)method.error()); + ET_LOG(Info, "Method loaded."); + + std::ifstream input_list(FLAGS_input_list); + ET_CHECK_MSG( + input_list.is_open(), + "Error: cannot open input file %s", + FLAGS_input_list.c_str()); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + MethodMeta method_meta = method->method_meta(); + size_t num_inputs = method_meta.num_inputs(); + std::string file_path; + int inference_index = 0; + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + ET_CHECK_MSG( + input_files.size() == num_inputs, + "Model expect %zu inputs but get %zu from input files", + num_inputs, + input_files.size()); + + // Prepare the inputs. + size_t num_allocated = 0; + ET_LOG(Info, "Number of inputs: %zu", num_inputs); + void** inputs = (void**)malloc(num_inputs * sizeof(void*)); + + for (size_t i = 0; i < num_inputs; i++) { + auto tag = method_meta.input_tag(i); + if (tag.get() != Tag::Tensor) { + ET_LOG(Debug, "Skipping malloc non-tensor input %zu", i); + continue; + } + Result tensor_meta = method_meta.input_tensor_meta(i); + const auto nbytes = tensor_meta->nbytes(); + // This input is a tensor. Allocate a buffer for it. + void* data_ptr = malloc(nbytes); + + // Read data from file + std::ifstream fin(input_files[i], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == nbytes, + "Input %zu size mismatch. file bytes: %zu, tensor bytes: %zu", + i, + file_size, + nbytes); + + fin.seekg(0, fin.beg); + fin.read(static_cast(data_ptr), file_size); + fin.close(); + inputs[num_allocated++] = data_ptr; + + // Set backend input + auto scalar_type = tensor_meta->scalar_type(); + auto sizes_raw = tensor_meta->sizes(); + auto dim = sizes_raw.size(); + auto dim_order_raw = tensor_meta->dim_order(); + std::vector sizes(sizes_raw.begin(), sizes_raw.end()); + std::vector dim_order(dim_order_raw.begin(), dim_order_raw.end()); + + TensorImpl impl = TensorImpl( + scalar_type, dim, sizes.data(), data_ptr, dim_order.data()); + + Tensor tensor(&impl); + Error ret = method->set_input(tensor, i); + if (ret != Error::Ok) { + ET_LOG(Error, "Failed to set input %zu: 0x%" PRIx32, i, (uint32_t)ret); + // The BufferCleanup will free the inputs when it goes out of scope. + BufferCleanup cleanup({inputs, num_allocated}); + return 1; + } + } + BufferCleanup({inputs, num_allocated}); + ET_LOG(Info, "Inputs prepared."); + + // Run the model. + auto before_exec = std::chrono::high_resolution_clock::now(); + Error status = Error::Ok; + status = method->execute(); + auto after_exec = std::chrono::high_resolution_clock::now(); + double elapsed_time = std::chrono::duration_cast( + after_exec - before_exec) + .count() / + 1000.0; + + ET_LOG(Info, "Inference took %f ms", elapsed_time); + ET_CHECK_MSG( + status == Error::Ok, + "Execution of method %s failed with status 0x%" PRIx32, + method_name, + (uint32_t)status); + ET_LOG(Info, "Model executed successfully."); + + // Get output data + size_t output_size = method->outputs_size(); + ET_LOG(Info, "Number of outputs: %zu", output_size); + std::vector outputs(output_size); + status = method->get_outputs(outputs.data(), output_size); + ET_CHECK(status == Error::Ok); + for (size_t i = 0; i < output_size; i++) { + auto output_tensor = outputs[i].toTensor(); + auto output_file_name = FLAGS_output_folder + "/output_" + + std::to_string(inference_index) + "_" + std::to_string(i) + ".bin"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write(output_tensor.const_data_ptr(), output_tensor.nbytes()); + fout.close(); + } + + inference_index++; + } + + return 0; +} diff --git a/examples/mediatek/model_export_scripts/deeplab_v3.py b/examples/mediatek/model_export_scripts/deeplab_v3.py new file mode 100755 index 00000000000..da6766c0f54 --- /dev/null +++ b/examples/mediatek/model_export_scripts/deeplab_v3.py @@ -0,0 +1,124 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import random + +import numpy as np + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.deeplabv3 = DeepLabV3ResNet101Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + nchw_output = self.deeplabv3(nchw_input1) + return nchw_output.permute(0, 2, 3, 1) + + +def get_dataset(data_size, dataset_dir, download): + from torchvision import datasets, transforms + + input_size = (224, 224) + preprocess = transforms.Compose( + [ + transforms.Resize(input_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + dataset = list( + datasets.VOCSegmentation( + root=os.path.join(dataset_dir, "voc_image"), + year="2009", + image_set="val", + transform=preprocess, + download=download, + ) + ) + + # prepare input data + random.shuffle(dataset) + inputs, targets, input_list = [], [], "" + for index, data in enumerate(dataset): + if index >= data_size: + break + image, target = data + inputs.append((image.unsqueeze(0).permute(0, 2, 3, 1),)) + targets.append(np.array(target.resize(input_size))) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./deeplab_v3", + default="./deeplab_v3", + type=str, + ) + + parser.add_argument( + "-d", + "--download", + help="If specified, download VOCSegmentation dataset by torchvision API", + action="store_true", + default=False, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + data_size=data_num, dataset_dir=args.artifact, download=args.download + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + if idx == 0: + print("inp shape: ", d.detach().numpy().shape) + print("inp type: ", d.detach().numpy().dtype) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.tofile(file_name) + if idx == 0: + print("golden shape: ", data.shape) + print("golden type: ", data.dtype) + + # build pte + pte_filename = "deeplabV3Resnet101_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/edsr.py b/examples/mediatek/model_export_scripts/edsr.py new file mode 100755 index 00000000000..4192d67e569 --- /dev/null +++ b/examples/mediatek/model_export_scripts/edsr.py @@ -0,0 +1,170 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import numpy as np + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.edsr import EdsrModel + +from PIL import Image +from torch.utils.data import Dataset +from torchsr.datasets import B100 +from torchvision.transforms.functional import to_tensor + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.edsr = EdsrModel().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + nchw_output = self.edsr(nchw_input1) + return nchw_output.permute(0, 2, 3, 1) + + +class SrDataset(Dataset): + def __init__(self, hr_dir: str, lr_dir: str): + self.input_size = np.asanyarray([224, 224]) + self.hr = [] + self.lr = [] + + for file in sorted(os.listdir(hr_dir)): + self.hr.append(self._resize_img(os.path.join(hr_dir, file), 2)) + + for file in sorted(os.listdir(lr_dir)): + self.lr.append(self._resize_img(os.path.join(lr_dir, file), 1)) + + if len(self.hr) != len(self.lr): + raise AssertionError( + "The number of high resolution pics is not equal to low " + "resolution pics" + ) + + def __getitem__(self, idx: int): + return self.hr[idx], self.lr[idx] + + def __len__(self): + return len(self.lr) + + def _resize_img(self, file: str, scale: int): + with Image.open(file) as img: + return ( + to_tensor(img.resize(tuple(self.input_size * scale))) + .unsqueeze(0) + .permute(0, 2, 3, 1) + ) + + def get_input_list(self): + input_list = "" + for i in range(len(self.lr)): + input_list += f"input_{i}_0.bin\n" + return input_list + + +def get_b100( + dataset_dir: str, +): + hr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR" + lr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2" + + if not os.path.exists(hr_dir) or not os.path.exists(lr_dir): + B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True) + + return SrDataset(hr_dir, lr_dir) + + +def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str): + if not (lr_dir and hr_dir) and not default_dataset: + raise RuntimeError( + "Nither custom dataset is provided nor using default dataset." + ) + + if (lr_dir and hr_dir) and default_dataset: + raise RuntimeError("Either use custom dataset, or use default dataset.") + + if default_dataset: + return get_b100(dataset_dir) + + return SrDataset(hr_dir, lr_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./edsr", + default="./edsr", + type=str, + ) + + parser.add_argument( + "-r", + "--hr_ref_dir", + help="Path to the high resolution images", + default="", + type=str, + ) + + parser.add_argument( + "-l", + "--lr_dir", + help="Path to the low resolution image inputs", + default="", + type=str, + ) + + parser.add_argument( + "-d", + "--default_dataset", + help="If specified, download and use B100 dataset by torchSR API", + action="store_true", + default=False, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + dataset = get_dataset( + args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact + ) + + inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list() + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "edsr_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (inputs[0],), + f"{args.artifact}/{pte_filename}", + [(input,) for input in inputs], + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/inception_v3.py b/examples/mediatek/model_export_scripts/inception_v3.py new file mode 100755 index 00000000000..c28bd85b402 --- /dev/null +++ b/examples/mediatek/model_export_scripts/inception_v3.py @@ -0,0 +1,120 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.inception_v3 import InceptionV3Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.inception = InceptionV3Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.inception(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./inceptionV3", + default="./inceptionV3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + pte_filename = "inceptionV3_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/inception_v4.py b/examples/mediatek/model_export_scripts/inception_v4.py new file mode 100755 index 00000000000..ccb2ce16f22 --- /dev/null +++ b/examples/mediatek/model_export_scripts/inception_v4.py @@ -0,0 +1,120 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.inception_v4 import InceptionV4Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.inception = InceptionV4Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.inception(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize((299, 299)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./inceptionV4", + default="./inceptionV4", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "inceptionV4_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 299, 299, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/mobilenet_v2.py b/examples/mediatek/model_export_scripts/mobilenet_v2.py new file mode 100755 index 00000000000..97f2ed884eb --- /dev/null +++ b/examples/mediatek/model_export_scripts/mobilenet_v2.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.mobilenet_v2 import MV2Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.mobilenet = MV2Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.mobilenet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./mobilenetV2", + default="./mobilenetV2", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "mobilenetV2_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/mobilenet_v3.py b/examples/mediatek/model_export_scripts/mobilenet_v3.py new file mode 100755 index 00000000000..fed2497ca26 --- /dev/null +++ b/examples/mediatek/model_export_scripts/mobilenet_v3.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.mobilenet_v3 import MV3Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.mobilenet = MV3Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.mobilenet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./mobilenetV3", + default="./mobilenetV3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "mobilenetV3_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/resnet18.py b/examples/mediatek/model_export_scripts/resnet18.py new file mode 100755 index 00000000000..2f3af57e7f3 --- /dev/null +++ b/examples/mediatek/model_export_scripts/resnet18.py @@ -0,0 +1,122 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.resnet import ResNet18Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.resnet = ResNet18Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.resnet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./resnet18", + default="./resnet18", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + aaa = data.detach().numpy() + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "resnet18_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/resnet50.py b/examples/mediatek/model_export_scripts/resnet50.py new file mode 100755 index 00000000000..ce23842447b --- /dev/null +++ b/examples/mediatek/model_export_scripts/resnet50.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.resnet import ResNet50Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.resnet = ResNet50Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.resnet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./resnet50", + default="./resnet50", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # compile to pte + pte_filename = "resnet50_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/requirements.txt b/examples/mediatek/requirements.txt index 038700059ba..7c3de886e27 100644 --- a/examples/mediatek/requirements.txt +++ b/examples/mediatek/requirements.txt @@ -4,3 +4,5 @@ safetensors sentencepiece tokenizers transformers +piq +pillow diff --git a/examples/mediatek/shell_scripts/export_oss.sh b/examples/mediatek/shell_scripts/export_oss.sh new file mode 100755 index 00000000000..3da5dc41f94 --- /dev/null +++ b/examples/mediatek/shell_scripts/export_oss.sh @@ -0,0 +1,29 @@ +model=$1 + +echo "Export model: $model" + +if [ $model = "deeplabv3" ] +then + python3 model_export_scripts/deeplab_v3.py -d +elif [ $model = "edsr" ] +then + python3 model_export_scripts/edsr.py -d +elif [ $model = "inceptionv3" ] +then + python3 model_export_scripts/inception_v3.py -d PATH_TO_DATASET +elif [ $model = "inceptionv4" ] +then + python3 model_export_scripts/inception_v4.py -d PATH_TO_DATASET +elif [ $model = "mobilenetv2" ] +then + python3 model_export_scripts/mobilenet_v2.py -d PATH_TO_DATASET +elif [ $model = "mobilenetv3" ] +then + python3 model_export_scripts/mobilenet_v3.py -d PATH_TO_DATASET +elif [ $model = "resnet18" ] +then + python3 model_export_scripts/resnet18.py -d PATH_TO_DATASET +elif [ $model = "resnet50" ] +then + python3 model_export_scripts/resnet50.py -d PATH_TO_DATASET +fi