diff --git a/pip_audit/_service/interface.py b/pip_audit/_service/interface.py index 92f05507..3978ebb9 100644 --- a/pip_audit/_service/interface.py +++ b/pip_audit/_service/interface.py @@ -114,6 +114,12 @@ def merge_aliases(self, other: VulnerabilityResult) -> VulnerabilityResult: self.id, self.description, self.fix_versions, self.aliases | other.aliases - {self.id} ) + def has_any_id(self, ids: Set[str]) -> bool: + """ + Returns whether ids intersects with {id} | aliases. + """ + return bool(ids & (self.aliases | {self.id})) + class VulnerabilityService(ABC): """ diff --git a/test/service/test_interface.py b/test/service/test_interface.py index b10725e4..828eb401 100644 --- a/test/service/test_interface.py +++ b/test/service/test_interface.py @@ -62,3 +62,14 @@ def test_vulnerability_result_update_aliases(): merged = result1.merge_aliases(result2) assert merged.id == "FOO" assert merged.aliases == {"BAR", "BAZ", "ZAP", "QUUX"} + + +def test_vulnerability_result_has_any_id(): + result = VulnerabilityResult( + id="FOO", description="bar", fix_versions=[Version("1.0.0")], aliases={"BAR", "BAZ", "QUUX"} + ) + + assert result.has_any_id({"FOO"}) + assert result.has_any_id({"ham", "eggs", "BAZ"}) + assert not result.has_any_id({"zilch"}) + assert not result.has_any_id(set())