diff --git a/src/pipdeptree/_models/package.py b/src/pipdeptree/_models/package.py index cdacee3..65edc4b 100644 --- a/src/pipdeptree/_models/package.py +++ b/src/pipdeptree/_models/package.py @@ -16,6 +16,8 @@ class Package(ABC): """Abstract class for wrappers around objects that pip returns.""" + UNKNOWN_LICENSE_STR = "(Unknown license)" + def __init__(self, obj: DistInfoDistribution) -> None: self._obj: DistInfoDistribution = obj @@ -27,6 +29,26 @@ def key(self) -> str: def project_name(self) -> str: return self._obj.project_name # type: ignore[no-any-return] + def licenses(self) -> str: + try: + dist_metadata = metadata(self.key) + except PackageNotFoundError: + return self.UNKNOWN_LICENSE_STR + + license_strs: list[str] = [] + classifiers = dist_metadata.get_all("Classifier", []) + + for classifier in classifiers: + line = str(classifier) + if line.startswith("License"): + license_str = line.split(":: ")[-1] + license_strs.append(license_str) + + if len(license_strs) == 0: + return self.UNKNOWN_LICENSE_STR + + return f'({", ".join(license_strs)})' + @abstractmethod def render_as_root(self, *, frozen: bool) -> str: raise NotImplementedError @@ -94,8 +116,6 @@ class DistPackage(Package): """ - UNKNOWN_LICENSE_STR = "(Unknown license)" - def __init__(self, obj: DistInfoDistribution, req: ReqPackage | None = None) -> None: super().__init__(obj) self.req = req @@ -103,26 +123,6 @@ def __init__(self, obj: DistInfoDistribution, req: ReqPackage | None = None) -> def requires(self) -> list[Requirement]: return self._obj.requires() # type: ignore[no-untyped-call,no-any-return] - def licenses(self) -> str: - try: - dist_metadata = metadata(self.key) - except PackageNotFoundError: - return self.UNKNOWN_LICENSE_STR - - license_strs: list[str] = [] - classifiers = dist_metadata.get_all("Classifier", []) - - for classifier in classifiers: - line = str(classifier) - if line.startswith("License"): - license_str = line.split(":: ")[-1] - license_strs.append(license_str) - - if len(license_strs) == 0: - return self.UNKNOWN_LICENSE_STR - - return f'({", ".join(license_strs)})' - @property def version(self) -> str: return self._obj.version # type: ignore[no-any-return] diff --git a/src/pipdeptree/_render/text.py b/src/pipdeptree/_render/text.py index 15e20d1..efa5a02 100644 --- a/src/pipdeptree/_render/text.py +++ b/src/pipdeptree/_render/text.py @@ -3,10 +3,8 @@ from itertools import chain from typing import TYPE_CHECKING, Any -from pipdeptree._models import DistPackage - if TYPE_CHECKING: - from pipdeptree._models import PackageDAG, ReqPackage + from pipdeptree._models import DistPackage, PackageDAG, ReqPackage def render_text( # noqa: PLR0913 @@ -89,8 +87,7 @@ def aux( # noqa: PLR0913, PLR0917 prefix += " " if use_bullets else "" next_prefix = prefix node_str = prefix + bullet + node_str - - if include_license and isinstance(node, DistPackage): + elif include_license: node_str += " " + node.licenses() result = [node_str] @@ -142,7 +139,7 @@ def aux( if parent: prefix = " " * indent + ("- " if use_bullets else "") node_str = prefix + node_str - if include_license and isinstance(node, DistPackage): + elif include_license: node_str += " " + node.licenses() result = [node_str] children = [ diff --git a/tests/_models/test_package.py b/tests/_models/test_package.py index dafa703..d51969f 100644 --- a/tests/_models/test_package.py +++ b/tests/_models/test_package.py @@ -7,6 +7,7 @@ import pytest from pipdeptree._models import DistPackage, ReqPackage +from pipdeptree._models.package import Package if TYPE_CHECKING: from pytest_mock import MockerFixture @@ -68,7 +69,7 @@ def test_dist_package_as_dict() -> None: [ pytest.param( Mock(get_all=lambda *args, **kwargs: []), # noqa: ARG005 - DistPackage.UNKNOWN_LICENSE_STR, + Package.UNKNOWN_LICENSE_STR, id="no-license", ), pytest.param( @@ -106,7 +107,7 @@ def test_dist_package_licenses_importlib_cant_find_package(monkeypatch: pytest.M dist = DistPackage(Mock(project_name="a")) licenses_str = dist.licenses() - assert licenses_str == DistPackage.UNKNOWN_LICENSE_STR + assert licenses_str == Package.UNKNOWN_LICENSE_STR def test_req_package_render_as_root() -> None: diff --git a/tests/render/test_text.py b/tests/render/test_text.py index 0c8eaeb..0db71a6 100644 --- a/tests/render/test_text.py +++ b/tests/render/test_text.py @@ -5,7 +5,7 @@ import pytest from pipdeptree._models import PackageDAG -from pipdeptree._models.package import DistPackage +from pipdeptree._models.package import Package from pipdeptree._render.text import render_text if TYPE_CHECKING: @@ -506,7 +506,57 @@ def test_render_text_with_license_info( ("c", "1.0.0"): [], } dag = PackageDAG.from_pkgs(list(mock_pkgs(graph))) - monkeypatch.setattr(DistPackage, "licenses", lambda _: "(TEST)") + monkeypatch.setattr(Package, "licenses", lambda _: "(TEST)") + + render_text(dag, max_depth=float("inf"), encoding=encoding, include_license=True) + captured = capsys.readouterr() + assert "\n".join(expected_output).strip() == captured.out.strip() + + +@pytest.mark.parametrize( + ("encoding", "expected_output"), + [ + ( + "utf-8", + [ + "a==3.4.0 (TEST)", + "b==2.3.1 (TEST)", + "└── a==3.4.0 [requires: b==2.3.1]", + "c==1.0.0 (TEST)", + "├── a==3.4.0 [requires: c==1.0.0]", + "└── b==2.3.1 [requires: c==1.0.0]", + " └── a==3.4.0 [requires: b==2.3.1]", + ], + ), + ( + "ascii", + [ + "a==3.4.0 (TEST)", + "b==2.3.1 (TEST)", + " - a==3.4.0 [requires: b==2.3.1]", + "c==1.0.0 (TEST)", + " - a==3.4.0 [requires: c==1.0.0]", + " - b==2.3.1 [requires: c==1.0.0]", + " - a==3.4.0 [requires: b==2.3.1]", + ], + ), + ], +) +def test_render_text_with_license_info_and_reversed_tree( + encoding: str, + expected_output: str, + mock_pkgs: Callable[[MockGraph], Iterator[Mock]], + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + graph: dict[tuple[str, str], list[tuple[str, list[tuple[str, str]]]]] = { + ("a", "3.4.0"): [("b", [("==", "2.3.1")]), ("c", [("==", "1.0.0")])], + ("b", "2.3.1"): [("c", [("==", "1.0.0")])], + ("c", "1.0.0"): [], + } + dag = PackageDAG.from_pkgs(list(mock_pkgs(graph))) + dag = dag.reverse() + monkeypatch.setattr(Package, "licenses", lambda _: "(TEST)") render_text(dag, max_depth=float("inf"), encoding=encoding, include_license=True) captured = capsys.readouterr()