diff --git a/tools/cpp/cc_flags_supplier.bzl b/tools/cpp/cc_flags_supplier.bzl index 9fe4dc5ef20d25..1be5a82c40038b 100644 --- a/tools/cpp/cc_flags_supplier.bzl +++ b/tools/cpp/cc_flags_supplier.bzl @@ -15,7 +15,7 @@ load("@bazel_tools//tools/build_defs/cc:action_names.bzl", "CC_FLAGS_MAKE_VARIABLE_ACTION_NAME") load("@bazel_tools//tools/cpp:cc_flags_supplier_lib.bzl", "build_cc_flags") -load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") +load("@bazel_tools//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain") def _cc_flags_supplier_impl(ctx): cc_toolchain = find_cpp_toolchain(ctx) @@ -30,7 +30,7 @@ cc_flags_supplier = rule( attrs = { "_cc_toolchain": attr.label(default = Label("@bazel_tools//tools/cpp:current_cc_toolchain")), }, - toolchains = ["@bazel_tools//tools/cpp:toolchain_type"], + toolchains = use_cpp_toolchain(), incompatible_use_toolchain_transition = True, fragments = ["cpp"], ) diff --git a/tools/cpp/compiler_flag.bzl b/tools/cpp/compiler_flag.bzl index c06f98378c7a5c..10b806869b8d18 100644 --- a/tools/cpp/compiler_flag.bzl +++ b/tools/cpp/compiler_flag.bzl @@ -14,7 +14,7 @@ """Rule that allows select() to differentiate between compilers.""" -load("//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain") +load("//tools/cpp:toolchain_utils.bzl", "find_cpp_toolchain", "use_cpp_toolchain") def _compiler_flag_impl(ctx): toolchain = find_cpp_toolchain(ctx) @@ -25,6 +25,6 @@ compiler_flag = rule( attrs = { "_cc_toolchain": attr.label(default = Label("//tools/cpp:current_cc_toolchain")), }, - toolchains = ["@bazel_tools//tools/cpp:toolchain_type"], + toolchains = use_cpp_toolchain(), incompatible_use_toolchain_transition = True, ) diff --git a/tools/cpp/toolchain_utils.bzl b/tools/cpp/toolchain_utils.bzl index 6b124b91a85415..ee2e3c217a34f0 100644 --- a/tools/cpp/toolchain_utils.bzl +++ b/tools/cpp/toolchain_utils.bzl @@ -14,18 +14,19 @@ # limitations under the License. """ -Finds the c++ toolchain. +Utilities to help work with c++ toolchains. -Returns the toolchain if enabled, and falls back to a toolchain constructed from -the CppConfiguration. """ +CPP_TOOLCHAIN_TYPE = "@bazel_tools//tools/cpp:toolchain_type" + def find_cpp_toolchain(ctx): """ Finds the c++ toolchain. If the c++ toolchain is in use, returns it. Otherwise, returns a c++ - toolchain derived from legacy toolchain selection. + toolchain derived from legacy toolchain selection, constructed from + the CppConfiguration. Args: ctx: The rule context for which to find a toolchain. @@ -36,9 +37,9 @@ def find_cpp_toolchain(ctx): # Check the incompatible flag for toolchain resolution. if hasattr(cc_common, "is_cc_toolchain_resolution_enabled_do_not_use") and cc_common.is_cc_toolchain_resolution_enabled_do_not_use(ctx = ctx): - if not "@bazel_tools//tools/cpp:toolchain_type" in ctx.toolchains: - fail("In order to use find_cpp_toolchain, you must include the '@bazel_tools//tools/cpp:toolchain_type' in the toolchains argument to your rule.") - toolchain_info = ctx.toolchains["@bazel_tools//tools/cpp:toolchain_type"] + if not CPP_TOOLCHAIN_TYPE in ctx.toolchains: + fail("In order to use find_cpp_toolchain, you must include the '%s' in the toolchains argument to your rule." % CPP_TOOLCHAIN_TYPE) + toolchain_info = ctx.toolchains[CPP_TOOLCHAIN_TYPE] if hasattr(toolchain_info, "cc_provider_in_toolchain") and hasattr(toolchain_info, "cc"): return toolchain_info.cc return toolchain_info @@ -49,3 +50,23 @@ def find_cpp_toolchain(ctx): # We didn't find anything. fail("In order to use find_cpp_toolchain, you must define the '_cc_toolchain' attribute on your rule or aspect.") + +def use_cpp_toolchain(mandatory = True): + """ + Helper to depend on the c++ toolchain. + + Usage: + ``` + my_rule = rule( + toolchains = [other toolchain types] + use_cpp_toolchain(), + ) + ``` + + Args: + mandatory: Whether or not it should be an error if the toolchain cannot be resolved. + Currently ignored, this will be enabled when optional toolchain types are added. + + Returns: + A list that can be used as the value for `rule.toolchains`. + """ + return [CPP_TOOLCHAIN_TYPE]