Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

go: analyze imports paths by module to enable multiple go_mod targets (Cherry pick of #16386) #16799

Merged
merged 2 commits into from Sep 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
180 changes: 85 additions & 95 deletions src/python/pants/backend/codegen/protobuf/go/rules.py
Expand Up @@ -14,10 +14,18 @@
AllProtobufTargets,
ProtobufGrpcToggleField,
ProtobufSourceField,
ProtobufSourcesGeneratorTarget,
ProtobufSourceTarget,
)
from pants.backend.go import target_type_rules
from pants.backend.go.target_type_rules import ImportPathToPackages
from pants.backend.go.target_types import GoPackageSourcesField
from pants.backend.go.dependency_inference import (
GoImportPathsMappingAddressSet,
GoModuleImportPathsMapping,
GoModuleImportPathsMappings,
GoModuleImportPathsMappingsHook,
)
from pants.backend.go.target_type_rules import GoImportPathMappingRequest
from pants.backend.go.target_types import GoOwningGoModAddressField, GoPackageSourcesField
from pants.backend.go.util_rules import (
assembly,
build_pkg,
Expand All @@ -36,10 +44,8 @@
BuildGoPackageTargetRequest,
GoCodegenBuildRequest,
)
from pants.backend.go.util_rules.first_party_pkg import (
FallibleFirstPartyPkgAnalysis,
FirstPartyPkgAnalysisRequest,
)
from pants.backend.go.util_rules.first_party_pkg import FallibleFirstPartyPkgAnalysis
from pants.backend.go.util_rules.go_mod import OwningGoMod, OwningGoModRequest
from pants.backend.go.util_rules.pkg_analyzer import PackageAnalyzerSetup
from pants.backend.go.util_rules.sdk import GoSdkProcess
from pants.backend.python.util_rules import pex
Expand All @@ -65,13 +71,10 @@
from pants.engine.process import FallibleProcessResult, Process, ProcessResult
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
FieldSet,
GeneratedSources,
GenerateSourcesRequest,
HydratedSources,
HydrateSourcesRequest,
InferDependenciesRequest,
InferredDependencies,
SourcesPaths,
SourcesPathsRequest,
TransitiveTargets,
Expand Down Expand Up @@ -118,17 +121,15 @@ def parse_go_package_option(content_raw: bytes) -> str | None:
return None


@dataclass(frozen=True)
class GoProtobufImportPathMapping:
"""Maps import paths of Go Protobuf packages to the addresses."""

mapping: FrozenDict[str, tuple[Address, ...]]
class ProtobufGoModuleImportPathsMappingsHook(GoModuleImportPathsMappingsHook):
pass


@rule(desc="Map import paths for all Go Protobuf targets.", level=LogLevel.DEBUG)
async def map_import_paths_of_all_go_protobuf_targets(
targets: AllProtobufTargets,
) -> GoProtobufImportPathMapping:
_request: ProtobufGoModuleImportPathsMappingsHook,
all_protobuf_targets: AllProtobufTargets,
) -> GoModuleImportPathsMappings:
sources = await MultiGet(
Get(
HydratedSources,
Expand All @@ -138,28 +139,57 @@ async def map_import_paths_of_all_go_protobuf_targets(
enable_codegen=True,
),
)
for tgt in targets
for tgt in all_protobuf_targets
)

all_contents = await MultiGet(
Get(DigestContents, Digest, source.snapshot.digest) for source in sources
)

go_protobuf_targets: dict[str, set[Address]] = defaultdict(set)
for tgt, contents in zip(targets, all_contents):
go_protobuf_mapping_metadata = []
owning_go_mod_gets = []
for tgt, contents in zip(all_protobuf_targets, all_contents):
if not contents:
continue
if len(contents) > 1:
raise AssertionError(
f"Protobuf target `{tgt.address}` mapped to more than one source file."
)

import_path = parse_go_package_option(contents[0].content)
if not import_path:
continue
go_protobuf_targets[import_path].add(tgt.address)

return GoProtobufImportPathMapping(
FrozenDict({ip: tuple(addrs) for ip, addrs in go_protobuf_targets.items()})
owning_go_mod_gets.append(Get(OwningGoMod, OwningGoModRequest(tgt.address)))
go_protobuf_mapping_metadata.append((import_path, tgt.address))

owning_go_mod_targets = await MultiGet(owning_go_mod_gets)

import_paths_by_module: dict[Address, dict[str, set[Address]]] = defaultdict(
lambda: defaultdict(set)
)

for owning_go_mod, (import_path, address) in zip(
owning_go_mod_targets, go_protobuf_mapping_metadata
):
import_paths_by_module[owning_go_mod.address][import_path].add(address)

return GoModuleImportPathsMappings(
FrozenDict(
{
go_mod_addr: GoModuleImportPathsMapping(
mapping=FrozenDict(
{
import_path: GoImportPathsMappingAddressSet(
addresses=tuple(sorted(addresses)), infer_all=True
)
for import_path, addresses in import_path_mapping.items()
}
),
)
for go_mod_addr, import_path_mapping in import_paths_by_module.items()
}
)
)


Expand All @@ -182,8 +212,6 @@ async def setup_full_package_build_request(
request: _SetupGoProtobufPackageBuildRequest,
protoc: Protoc,
go_protoc_plugin: _SetupGoProtocPlugin,
package_mapping: ImportPathToPackages,
go_protobuf_mapping: GoProtobufImportPathMapping,
analyzer: PackageAnalyzerSetup,
) -> FallibleBuildGoPackageRequest:
output_dir = "_generated_files"
Expand All @@ -196,6 +224,11 @@ async def setup_full_package_build_request(
Get(Digest, CreateDigest([Directory(output_dir)])),
)

go_mod_addr = await Get(OwningGoMod, OwningGoModRequest(transitive_targets.roots[0].address))
package_mapping = await Get(
GoModuleImportPathsMapping, GoImportPathMappingRequest(go_mod_addr.address)
)

all_sources = await Get(
SourceFiles,
SourceFilesRequest(
Expand Down Expand Up @@ -317,25 +350,23 @@ async def setup_full_package_build_request(
candidate_addresses = package_mapping.mapping.get(dep_import_path)
if candidate_addresses:
# TODO: Use explicit dependencies to disambiguate? This should never happen with Go backend though.
if len(candidate_addresses) > 1:
return FallibleBuildGoPackageRequest(
request=None,
import_path=request.import_path,
exit_code=result.exit_code,
stderr=textwrap.dedent(
f"""
Multiple addresses match import of `{dep_import_path}`.

addresses: {', '.join(str(a) for a in candidate_addresses)}
"""
).strip(),
)
dep_build_request_addrs.extend(candidate_addresses)

# Infer dependencies on other generated Go sources.
go_protobuf_candidate_addresses = go_protobuf_mapping.mapping.get(dep_import_path)
if go_protobuf_candidate_addresses:
dep_build_request_addrs.extend(go_protobuf_candidate_addresses)
if candidate_addresses.infer_all:
dep_build_request_addrs.extend(candidate_addresses.addresses)
else:
if len(candidate_addresses.addresses) > 1:
return FallibleBuildGoPackageRequest(
request=None,
import_path=request.import_path,
exit_code=result.exit_code,
stderr=textwrap.dedent(
f"""
Multiple addresses match import of `{dep_import_path}`.

addresses: {', '.join(str(a) for a in candidate_addresses.addresses)}
"""
).strip(),
)
dep_build_request_addrs.extend(candidate_addresses.addresses)

dep_build_requests = await MultiGet(
Get(BuildGoPackageRequest, BuildGoPackageTargetRequest(addr))
Expand All @@ -359,7 +390,6 @@ async def setup_full_package_build_request(
@rule
async def setup_build_go_package_request_for_protobuf(
request: GoCodegenBuildProtobufRequest,
protobuf_package_mapping: GoProtobufImportPathMapping,
) -> FallibleBuildGoPackageRequest:
# Hydrate the protobuf source to parse for the Go import path.
sources = await Get(HydratedSources, HydrateSourcesRequest(request.target[ProtobufSourceField]))
Expand All @@ -374,10 +404,15 @@ async def setup_build_go_package_request_for_protobuf(
stderr=f"No import path was set in Protobuf file via `option go_package` directive for {request.target.address}.",
)

go_mod_addr = await Get(OwningGoMod, OwningGoModRequest(request.target.address))
package_mapping = await Get(
GoModuleImportPathsMapping, GoImportPathMappingRequest(go_mod_addr.address)
)

# Request the full build of the package. This indirection is necessary so that requests for two or more
# Protobuf files in the same Go package result in a single cacheable rule invocation.
protobuf_target_addrs_for_import_path = protobuf_package_mapping.mapping.get(import_path)
if not protobuf_target_addrs_for_import_path:
protobuf_target_addrs_set_for_import_path = package_mapping.mapping.get(import_path)
if not protobuf_target_addrs_set_for_import_path:
return FallibleBuildGoPackageRequest(
request=None,
import_path=import_path,
Expand All @@ -393,7 +428,7 @@ async def setup_build_go_package_request_for_protobuf(
return await Get(
FallibleBuildGoPackageRequest,
_SetupGoProtobufPackageBuildRequest(
addresses=protobuf_target_addrs_for_import_path,
addresses=protobuf_target_addrs_set_for_import_path.addresses,
import_path=import_path,
),
)
Expand Down Expand Up @@ -589,59 +624,14 @@ async def setup_go_protoc_plugin(platform: Platform) -> _SetupGoProtocPlugin:
return _SetupGoProtocPlugin(plugin_digest)


@dataclass(frozen=True)
class GoProtobufDependenciesInferenceFieldSet(FieldSet):
required_fields = (GoPackageSourcesField,)

sources: GoPackageSourcesField


class InferGoProtobufDependenciesRequest(InferDependenciesRequest):
infer_from = GoProtobufDependenciesInferenceFieldSet


@rule(
desc="Infer dependencies on Protobuf sources for first-party Go packages", level=LogLevel.DEBUG
)
async def infer_go_dependencies(
request: InferGoProtobufDependenciesRequest,
go_protobuf_mapping: GoProtobufImportPathMapping,
) -> InferredDependencies:
address = request.field_set.address
maybe_pkg_analysis = await Get(
FallibleFirstPartyPkgAnalysis, FirstPartyPkgAnalysisRequest(address)
)
if maybe_pkg_analysis.analysis is None:
_logger.error(
softwrap(
f"""
Failed to analyze {maybe_pkg_analysis.import_path} for dependency inference:

{maybe_pkg_analysis.stderr}
"""
)
)
return InferredDependencies([])
pkg_analysis = maybe_pkg_analysis.analysis

inferred_dependencies: list[Address] = []
for import_path in (
*pkg_analysis.imports,
*pkg_analysis.test_imports,
*pkg_analysis.xtest_imports,
):
candidate_addresses = go_protobuf_mapping.mapping.get(import_path, ())
inferred_dependencies.extend(candidate_addresses)

return InferredDependencies(inferred_dependencies)


def rules():
return (
*collect_rules(),
UnionRule(GenerateSourcesRequest, GenerateGoFromProtobufRequest),
UnionRule(GoCodegenBuildRequest, GoCodegenBuildProtobufRequest),
UnionRule(InferDependenciesRequest, InferGoProtobufDependenciesRequest),
UnionRule(GoModuleImportPathsMappingsHook, ProtobufGoModuleImportPathsMappingsHook),
ProtobufSourcesGeneratorTarget.register_plugin_field(GoOwningGoModAddressField),
ProtobufSourceTarget.register_plugin_field(GoOwningGoModAddressField),
# Rules needed for this to pass src/python/pants/init/load_backends_integration_test.py:
*assembly.rules(),
*build_pkg.rules(),
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/experimental/go/register.py
Expand Up @@ -14,6 +14,7 @@
)
from pants.backend.go.util_rules import (
assembly,
binary,
build_pkg,
build_pkg_target,
coverage,
Expand All @@ -38,6 +39,7 @@ def target_types():
def rules():
return [
*assembly.rules(),
*binary.rules(),
*build_pkg.rules(),
*build_pkg_target.rules(),
*check.rules(),
Expand Down
48 changes: 48 additions & 0 deletions src/python/pants/backend/go/dependency_inference.py
@@ -0,0 +1,48 @@
# Copyright 2022 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from __future__ import annotations

from dataclasses import dataclass

from pants.build_graph.address import Address
from pants.engine.unions import union
from pants.util.frozendict import FrozenDict


@union
@dataclass(frozen=True)
class GoModuleImportPathsMappingsHook:
"""An entry point for a specific implementation of mapping Go import paths to owning targets.

All implementations will be merged together. The core Go dependency inference rules will request
the `GoModuleImportPathsMappings` type using implementations of this union.
"""


@dataclass(frozen=True)
class GoImportPathsMappingAddressSet:
addresses: tuple[Address, ...]
infer_all: bool


@dataclass(frozen=True)
class GoModuleImportPathsMapping:
"""Maps import paths (as strings) to one or more addresses of targets providing those import
path(s) for a single Go module."""

mapping: FrozenDict[str, GoImportPathsMappingAddressSet]


@dataclass(frozen=True)
class GoModuleImportPathsMappings:
"""Import path mappings for all Go modules in the repository.

This type is requested from plugins which provide implementations for the GoCodegenBuildRequest
union and then merged.
"""

modules: FrozenDict[Address, GoModuleImportPathsMapping]


class AllGoModuleImportPathsMappings(GoModuleImportPathsMappings):
pass