Skip to content

Commit

Permalink
Improved type annotations in layers
Browse files Browse the repository at this point in the history
  • Loading branch information
seddonym committed Sep 2, 2022
1 parent 16af142 commit 7d0955a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
34 changes: 23 additions & 11 deletions src/importlinter/contracts/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ def _build_layer_chain_data(
return layer_chain_data

@classmethod
def _get_indirect_collapsed_chains(cls, graph, importer_package, imported_package):
def _get_indirect_collapsed_chains(
cls, graph: ImportGraph, importer_package: Module, imported_package: Module
) -> List:
"""
Squashes the two packages.
Gets a list of paths between them, called segments.
Expand Down Expand Up @@ -334,7 +336,7 @@ def _get_indirect_collapsed_chains(cls, graph, importer_package, imported_packag
)

@classmethod
def _find_segments(cls, graph, importer: Module, imported: Module):
def _find_segments(cls, graph: ImportGraph, importer: Module, imported: Module):
"""
Return list of headless and tailless detailed chains.
"""
Expand All @@ -345,18 +347,26 @@ def _find_segments(cls, graph, importer: Module, imported: Module):
if len(chain) == 2:
raise ValueError("Direct chain found - these should have been removed.")
detailed_chain = []
for importer, imported in [(chain[i], chain[i + 1]) for i in range(len(chain) - 1)]:
import_details = graph.get_import_details(importer=importer, imported=imported)
for importer_in_chain, imported_in_chain in [
(chain[i], chain[i + 1]) for i in range(len(chain) - 1)
]:
import_details = graph.get_import_details(
importer=importer_in_chain, imported=imported_in_chain
)
line_numbers = tuple(set(j["line_number"] for j in import_details))
detailed_chain.append(
{"importer": importer, "imported": imported, "line_numbers": line_numbers}
{
"importer": importer_in_chain,
"imported": imported_in_chain,
"line_numbers": line_numbers,
}
)
segments.append(detailed_chain)
return segments

@classmethod
def _pop_shortest_chains(cls, graph, importer, imported):
chain = True
def _pop_shortest_chains(cls, graph: ImportGraph, importer: str, imported: str):
chain: Union[Optional[Tuple[str, ...]], bool] = True
while chain:
chain = graph.find_shortest_chain(importer, imported)
if chain:
Expand All @@ -366,7 +376,9 @@ def _pop_shortest_chains(cls, graph, importer, imported):
yield chain

@classmethod
def _segments_to_collapsed_chains(cls, graph, segments, importer: Module, imported: Module):
def _segments_to_collapsed_chains(
cls, graph: ImportGraph, segments, importer: Module, imported: Module
):
collapsed_chains = []
for segment in segments:
head_imports = []
Expand Down Expand Up @@ -411,19 +423,19 @@ def _segments_to_collapsed_chains(cls, graph, segments, importer: Module, import

return collapsed_chains

def _remove_other_layers(self, graph, container, layers_to_preserve):
def _remove_other_layers(self, graph: ImportGraph, container, layers_to_preserve):
for index, layer in enumerate(self.layers): # type: ignore
candidate_layer = self._module_from_layer(layer, container)
if candidate_layer.name in graph.modules and candidate_layer not in layers_to_preserve:
self._remove_layer(graph, layer_package=candidate_layer)

def _remove_layer(self, graph, layer_package):
def _remove_layer(self, graph: ImportGraph, layer_package):
for module in graph.find_descendants(layer_package.name):
graph.remove_module(module)
graph.remove_module(layer_package.name)

@classmethod
def _pop_direct_imports(cls, higher_layer_package, lower_layer_package, graph):
def _pop_direct_imports(cls, higher_layer_package, lower_layer_package, graph: ImportGraph):
import_details_list = []
lower_layer_modules = {lower_layer_package.name} | graph.find_descendants(
lower_layer_package.name
Expand Down
12 changes: 12 additions & 0 deletions src/importlinter/domain/ports/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,15 @@ def add_import(

def remove_import(self, *, importer: str, imported: str) -> None:
raise NotImplementedError

def find_modules_directly_imported_by(self, module: str) -> Set[str]:
raise NotImplementedError

def find_modules_that_directly_import(self, module: str) -> Set[str]:
raise NotImplementedError

def squash_module(self, module: str) -> None:
raise NotImplementedError

def remove_module(self, module: str) -> None:
raise NotImplementedError

0 comments on commit 7d0955a

Please sign in to comment.