From e98e7fe42811b656bd272e92044df67ccb2d3880 Mon Sep 17 00:00:00 2001 From: Linbin Yu Date: Fri, 24 Jun 2022 21:51:20 +0000 Subject: [PATCH] [4] move pt_operator_library to shared BUCK file (#80170) Summary: Move pt_operator_library to pt_ops.bzl and make it shared with OSS BUCK build This will replace D36912042. I will update all load statements in future diffs. Test Plan: sandcaslte, OSS CI Differential Revision: D37390060 Pull Request resolved: https://github.com/pytorch/pytorch/pull/80170 Approved by: https://github.com/JacobSzwejbka --- .github/workflows/_buck-build-test.yml | 18 +-- BUCK.oss | 18 ++- buckbuild.bzl | 10 +- pt_defs.oss.bzl | 177 ------------------------- pt_ops.bzl | 111 ++++++++++++++++ 5 files changed, 126 insertions(+), 208 deletions(-) delete mode 100644 pt_defs.oss.bzl diff --git a/.github/workflows/_buck-build-test.yml b/.github/workflows/_buck-build-test.yml index 1d54f55c2424..b005224650a1 100644 --- a/.github/workflows/_buck-build-test.yml +++ b/.github/workflows/_buck-build-test.yml @@ -62,29 +62,17 @@ jobs: command: | sh scripts/buck_setup.sh - - name: Build glog - run: | - buck build third_party:glog - - name: Build C10 run: | buck build c10:c10 - - name: Build cpuinfo - run: | - buck build third_party:cpuinfo - - - name: Build pthreadpool - run: | - buck build third_party:pthreadpool - - name: Build XNNPACK run: | buck build third_party:XNNPACK - name: Build QNNPACK run: | - buck build aten/src/ATen/native/quantized/cpu/qnnpack/... --keep-going + buck build aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack - name: Build aten_cpu run: | @@ -94,9 +82,9 @@ jobs: run: | buck build :torch_mobile_core - - name: Build torch_mobile_all_ops + - name: Build pt_ops_full run: | - buck build :torch_mobile_all_ops + buck build :pt_ops_full - name: Build mobile benchmark run: | diff --git a/BUCK.oss b/BUCK.oss index 62868fbb08a5..13d5518801a2 100644 --- a/BUCK.oss +++ b/BUCK.oss @@ -1,15 +1,17 @@ load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load( - ":pt_defs.oss.bzl", + ":pt_ops.bzl", "pt_operator_library", - "get_pt_ops_deps", ) load(":buckbuild.bzl", "define_buck_targets", + "get_pt_operator_registry_dict", ) +# define shared buck targets define_buck_targets() +# define OSS only targets cxx_library( name = "pthreadpool", srcs = ['caffe2/utils/threadpool/pthreadpool.cc', 'caffe2/utils/threadpool/pthreadpool_impl.cc', 'caffe2/utils/threadpool/pthreadpool-cpp.cc', 'caffe2/utils/threadpool/thread_pool_guard.cpp', 'caffe2/utils/threadpool/ThreadPool.cc'], @@ -76,21 +78,17 @@ cxx_library( pt_operator_library( name = "torch_mobile_ops_full_dev", - check_decl = False, include_all_operators = True, ) cxx_library( - name = "torch_mobile_all_ops", - visibility = ["PUBLIC"], - deps = get_pt_ops_deps( + name = "pt_ops_full", + **get_pt_operator_registry_dict( name = "pt_ops_full", - train = False, deps = [ ":torch_mobile_ops_full_dev", ], - enable_flatbuffer = False, - ), + ) ) cxx_binary( @@ -118,7 +116,7 @@ cxx_binary( ], deps = [ ":torch_mobile_core", - ":torch_mobile_all_ops", + ":pt_ops_full", "//c10:c10", ], ) diff --git a/buckbuild.bzl b/buckbuild.bzl index 55e22851d48d..a2e1a6d5e391 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -25,6 +25,10 @@ load( "jit_core_sources", "libtorch_profiler_sources", ) +load( + ":pt_ops.bzl", + "USED_PT_BACKENDS", +) load( ":pt_template_srcs.bzl", "METAL_MASKRCNN_SOURCE_LIST", @@ -235,12 +239,6 @@ def get_pt_preprocessor_flags(): PT_PREPROCESSOR_FLAGS.append("-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS") return PT_PREPROCESSOR_FLAGS -USED_PT_BACKENDS = [ - "CPU", - "QuantizedCPU", - "SparseCPU", # brings ~20 kb size regression -] - # This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892 PT_BACKEND_HEADERS = [ "CPU", diff --git a/pt_defs.oss.bzl b/pt_defs.oss.bzl deleted file mode 100644 index 8710b1bd7c30..000000000000 --- a/pt_defs.oss.bzl +++ /dev/null @@ -1,177 +0,0 @@ -load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") -load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") -load( - ":buckbuild.bzl", - "get_pt_operator_registry_dict", -) - -PT_BASE_OPS = [ - "aten::_coalesced_", - "aten::_copy_from", - "aten::_empty_affine_quantized", - "aten::_empty_per_channel_affine_quantized", - "aten::_indices", - "aten::_nnz", - "aten::_values", - "aten::add", - "aten::add_", - "aten::arange", - "aten::as_strided", - "aten::as_strided_", - "aten::cat", - "aten::clone", - "aten::coalesce", - "aten::contiguous", - "aten::copy_", - "aten::copy_sparse_to_sparse_", - "aten::dense_dim", - "aten::dequantize", - "aten::div", - "aten::div_", - "aten::empty", - "aten::empty_like", - "aten::empty_strided", - "aten::empty.memory_format", - "aten::eq", - "aten::equal", - "aten::expand", - "aten::fill_", - "aten::is_coalesced", - "aten::is_complex", - "aten::is_floating_point", - "aten::is_leaf", - "aten::is_nonzero", - "aten::item", - "aten::max", - "aten::min", - "aten::mul", - "aten::mul_", - "aten::narrow", - "aten::ne", - "aten::permute", - "aten::q_per_channel_axis", - "aten::q_per_channel_scales", - "aten::q_per_channel_zero_points", - "aten::q_scale", - "aten::q_zero_point", - "aten::qscheme", - "aten::quantize_per_tensor", - "aten::reshape", - "aten::_reshape_alias", - "aten::resize_", - "aten::resize_as_", - "aten::scalar_tensor", - "aten::select", - "aten::set_", - "aten::size", - "aten::slice", - "aten::sparse_dim", - "aten::sparse_resize_and_clear_", - "aten::squeeze", - "aten::squeeze_", - "aten::stride", - "aten::sub", - "aten::sub_", - "aten::sum", - "aten::t", - "aten::to", - "aten::_to_copy", - "aten::unsqueeze", - "aten::view", - "aten::zero_", - "aten::zeros", - "aten::zeros_like", -] - -######### selective build ######### - -def pt_operator_registry( - name, - deps = [], - train = False, - labels = [], - env = [], - template_select = True, - enforce_traced_op_list = False, - pt_allow_forced_schema_registration = True, - enable_flatbuffer = False, - **kwargs): - args = get_pt_operator_registry_dict( - name, - deps, - train, - labels, - env, - template_select, - enforce_traced_op_list, - pt_allow_forced_schema_registration, - enable_flatbuffer = True, - **kwargs - ) - - fb_xplat_cxx_library( - name = name, - **args - ) - -def get_pt_ops_deps(name, deps, train = False, enforce_traced_op_list = False, enable_flatbuffer = False, **kwargs): - pt_operator_registry( - name, - deps, - train = train, - enforce_traced_op_list = enforce_traced_op_list, - enable_flatbuffer = enable_flatbuffer, - **kwargs - ) - return deps + [":" + name] - -def pt_operator_library( - name, - ops = [], - exported_deps = [], - check_decl = True, - train = False, - model = None, - include_all_operators = False, - **kwargs): - model_name = name - - ops = [op.strip() for op in ops] - - # If ops are specified, then we are in static selective build mode, so we append - # base ops to this list to avoid additional special case logic in subsequent code. - if len(ops) > 0: - ops.extend(PT_BASE_OPS) - - visibility = kwargs.pop("visibility", ["PUBLIC"]) - - fb_xplat_genrule( - name = name, - out = "model_operators.yaml", - cmd = ( - "$(exe :gen_operators_yaml) " + - "{optionally_root_ops} " + - "{optionally_training_root_ops} " + - "--rule_name {rule_name} " + - "--output_path \"${{OUT}}\" " + - "--model_name {model_name} " + - "--dep_graph_yaml_path pytorch_op_deps.yaml " + - "--models_yaml_path all_mobile_model_configs.yaml " + - #"{optionally_model_versions} " + - #"{optionally_model_assets} " + - #"{optionally_model_traced_backends} " + - "{optionally_include_all_operators}" - ).format( - rule_name = name, - model_name = model_name, - optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", - optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", - #optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", - #optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", - #optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", - optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", - ), - labels = ["pt_operator_library"], # for pt_operator_query_codegen query - visibility = visibility, - **kwargs - ) diff --git a/pt_ops.bzl b/pt_ops.bzl index 0d06fd69a679..2dd4ce3e2ab2 100644 --- a/pt_ops.bzl +++ b/pt_ops.bzl @@ -1,3 +1,114 @@ +load("//tools/build_defs:expect.bzl", "expect") +load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") +load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") + +# @lint-ignore BUCKRESTRICTEDSYNTAX +IS_OSS = read_config("pt", "is_oss", "0") == "1" # True for OSS BUCK build, and False for internal BUCK build + +USED_PT_BACKENDS = [ + "CPU", + "QuantizedCPU", + "SparseCPU", # brings ~20 kb size regression +] + +def pt_operator_library( + name, + ops = [], + exported_deps = [], + check_decl = True, + train = False, + model = None, + include_all_operators = False, + **kwargs): + (model_name, model_versions, model_assets, model_traced_backends) = validate_and_extract_model_information( + name, + model, + ) + + ops = [op.strip() for op in ops] + + # If ops are specified, then we are in static selective build mode, so we append + # base ops to this list to avoid additional special case logic in subsequent code. + if len(ops) > 0: + ops.extend(PT_BASE_OPS) + + labels = kwargs.pop("labels", []) + visibility = kwargs.pop("visibility", ["PUBLIC"]) + + fb_xplat_genrule( + name = name, + out = "model_operators.yaml", + cmd = ( + "$(exe {root}:gen_operators_yaml) " + + "{optionally_root_ops} " + + "{optionally_training_root_ops} " + + "--rule_name {rule_name} " + + "--output_path \"${{OUT}}\" " + + "--model_name {model_name} " + + "--dep_graph_yaml_path {dep_graph_yaml} " + + "--models_yaml_path {models_yaml} " + + "{optionally_model_versions} " + + "{optionally_model_assets} " + + "{optionally_model_traced_backends} " + + "{optionally_include_all_operators}" + ).format( + root = "//" if IS_OSS else "//xplat/caffe2", + rule_name = name, + model_name = model_name, + dep_graph_yaml = "none" if IS_OSS else "$(location //xplat/caffe2:pytorch_op_deps)/fb/pytorch_op_deps.yaml ", + models_yaml = "none" if IS_OSS else "$(location //xplat/pytorch_models:all_mobile_model_configs)/build/all_mobile_model_configs.yaml ", + optionally_root_ops = "--root_ops " + (",".join(ops)) if len(ops) > 0 else "", + optionally_training_root_ops = "--training_root_ops " + (",".join(ops)) if len(ops) > 0 and train else "", + optionally_model_versions = "--model_versions " + (",".join(model_versions)) if model_versions != None else "", + optionally_model_assets = "--model_assets " + (",".join(model_assets)) if model_assets != None else "", + optionally_model_traced_backends = "--model_traced_backends " + (",".join(model_traced_backends)) if model_traced_backends != None else "", + optionally_include_all_operators = "--include_all_operators " if include_all_operators else "", + ), + labels = labels + [ + "pt_operator_library", + "supermodule:android/default/pytorch", + "supermodule:ios/default/public.pytorch", + ] + (["pt_train_operator_library"] if train else []), + visibility = visibility, + **kwargs + ) + +def validate_and_extract_model_information(name, model): + model_name = name + model_versions = None + model_assets = None + model_traced_backends = None + + if model != None: + model_name = model.get("name") + expect(model_name != None, "Expected Model Name to be present") + model_versions = model.get("versions") + expect(is_list(model_versions), "Expected model versions to be a list of string") + for ver in model_versions or []: + expect(is_string(ver), "Expected version '{}' to be string".format(str(ver))) + model_assets = model.get("assets") + expect( + model_assets == None or is_list(model_assets), + "Expected model assets to be a list of string if specified", + ) + for asset_name in model_assets or []: + expect(is_string(asset_name), "Expected asset_name '{}' to be string".format(str(asset_name))) + model_traced_backends = model.get("traced_backends") + expect( + model_traced_backends == None or is_list(model_traced_backends), + "Expected model traced backends to be a list of string if specified", + ) + + if model_traced_backends != None: + for backend in model_traced_backends: + expect(is_string(backend), "Expected backend name '{}' to be string".format(str(backend))) + expect( + backend in USED_PT_BACKENDS, + "Expected backend name ({}) to be in set: {}".format(backend, ",".join(USED_PT_BACKENDS)), + ) + + return (model_name, model_versions, model_assets, model_traced_backends) + # This file keeps a list of PyTorch operators used by any targets in # @fbsource//xplat/... # The purpose of the list is to avoid generating large number of unused