Skip to content

Commit

Permalink
paths goal: allow finding paths between addresses that expand to mult…
Browse files Browse the repository at this point in the history
…iple targets
  • Loading branch information
AlexTereshenkov committed Jul 19, 2023
1 parent ed38de4 commit afeb676
Showing 1 changed file with 84 additions and 33 deletions.
117 changes: 84 additions & 33 deletions src/python/pants/backend/project_info/paths.py
Expand Up @@ -5,18 +5,20 @@

import json
from collections import deque
from dataclasses import dataclass
from typing import Iterable

from pants.base.specs import Specs
from pants.base.specs_parser import SpecsParser
from pants.engine.addresses import Address
from pants.engine.console import Console
from pants.engine.goal import Goal, GoalSubsystem, Outputting
from pants.engine.rules import Get, MultiGet, collect_rules, goal_rule
from pants.engine.rules import Get, MultiGet, collect_rules, goal_rule, rule
from pants.engine.target import (
AlwaysTraverseDeps,
Dependencies,
DependenciesRequest,
Target,
Targets,
TransitiveTargets,
TransitiveTargetsRequest,
Expand Down Expand Up @@ -83,6 +85,75 @@ def find_paths_breadth_first(
visited_edges.add(current_edge)


@dataclass
class SpecsPaths:
paths: list[list[str]]


@dataclass
class SpecsPathsCollection:
spec_paths: list[SpecsPaths]


@dataclass(frozen=True)
class RootDestinationPair:
root: Target
destination: Target


@dataclass(frozen=True)
class RootDestinationsPair:
root: Target
destinations: Targets


@rule(desc="Get paths between root and destination.")
async def get_paths_between_root_and_destination(pair: RootDestinationPair) -> SpecsPaths:
transitive_targets = await Get(
TransitiveTargets,
TransitiveTargetsRequest(
[pair.root.address], should_traverse_deps_predicate=AlwaysTraverseDeps()
),
)

adjacent_targets_per_target = await MultiGet(
Get(
Targets,
DependenciesRequest(
tgt.get(Dependencies), should_traverse_deps_predicate=AlwaysTraverseDeps()
),
)
for tgt in transitive_targets.closure
)

transitive_targets_closure_addresses = (t.address for t in transitive_targets.closure)
adjacency_lists = dict(zip(transitive_targets_closure_addresses, adjacent_targets_per_target))

spec_paths = []
for path in find_paths_breadth_first(
adjacency_lists, pair.root.address, pair.destination.address
):
spec_path = [address.spec for address in path]
spec_paths.append(spec_path)

return SpecsPaths(paths=spec_paths)


@rule("Get paths between root and multiple destinations.")
async def get_paths_between_root_and_destinations(
pair: RootDestinationsPair,
) -> SpecsPathsCollection:
spec_paths = await MultiGet(
Get(
SpecsPaths,
RootDestinationPair,
RootDestinationPair(destination=destination, root=pair.root),
)
for destination in pair.destinations
)
return SpecsPathsCollection(spec_paths=list(spec_paths))


@goal_rule
async def paths(console: Console, paths_subsystem: PathsSubsystem) -> PathsGoal:
path_from = paths_subsystem.from_
Expand Down Expand Up @@ -116,38 +187,18 @@ async def paths(console: Console, paths_subsystem: PathsSubsystem) -> PathsGoal:
)

all_spec_paths = []
for root in from_tgts:
for destination in to_tgts:
spec_paths = []
transitive_targets = await Get( # noqa: PNT30: ignore
TransitiveTargets,
TransitiveTargetsRequest(
[root.address], should_traverse_deps_predicate=AlwaysTraverseDeps()
),
)

adjacent_targets_per_target = await MultiGet( # noqa: PNT30: ignore
Get(
Targets,
DependenciesRequest(
tgt.get(Dependencies), should_traverse_deps_predicate=AlwaysTraverseDeps()
),
)
for tgt in transitive_targets.closure
)

transitive_targets_closure_addresses = (t.address for t in transitive_targets.closure)
adjacency_lists = dict(
zip(transitive_targets_closure_addresses, adjacent_targets_per_target)
)

for path in find_paths_breadth_first(
adjacency_lists, root.address, destination.address
):
spec_path = [address.spec for address in path]
spec_paths.append(spec_path)

all_spec_paths.extend(spec_paths)
spec_paths = await MultiGet(
Get(
SpecsPathsCollection,
RootDestinationsPair,
RootDestinationsPair(root=root, destinations=to_tgts),
)
for root in from_tgts
)

for spec_path in spec_paths:
for path in (p.paths for p in spec_path.spec_paths):
all_spec_paths.extend(path)

with paths_subsystem.output(console) as write_stdout:
write_stdout(json.dumps(all_spec_paths, indent=2) + "\n")
Expand Down

0 comments on commit afeb676

Please sign in to comment.