From ca36ea43bca6f0f52b832dee97c64e970788b122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:34:47 +0100 Subject: [PATCH 1/4] fix IP address/network comparison methods --- Lib/ipaddress.py | 199 ++++++++---------- Lib/test/test_ipaddress.py | 121 ++++++++--- ...-11-22-14-57-08.gh-issue-141647.PowRcQ.rst | 5 + 3 files changed, 192 insertions(+), 133 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2025-11-22-14-57-08.gh-issue-141647.PowRcQ.rst diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index f1062a8cd052a5..2152fe886c9a5c 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -150,6 +150,12 @@ def v6_int_to_packed(address): raise ValueError("Address negative or too large for IPv6") +def _check_ip_version(a, b): + if a.version != b.version: + # does this need to raise a ValueError? + raise TypeError(f"{a} and {b} are not of the same version") + + def _split_optional_netmask(address): """Helper to split the netmask and raise AddressValueError if needed""" addr = str(address).split('/') @@ -213,7 +219,7 @@ def summarize_address_range(first, last): Raise: TypeError: - If the first and last objects are not IP addresses. + If the first or last objects are not IP addresses. If the first and last objects are not the same version. ValueError: If the last object is not greater than the first. @@ -223,9 +229,7 @@ def summarize_address_range(first, last): if (not (isinstance(first, _BaseAddress) and isinstance(last, _BaseAddress))): raise TypeError('first and last must be IP addresses, not networks') - if first.version != last.version: - raise TypeError("%s and %s are not of the same version" % ( - first, last)) + _check_ip_version(first, last) if first > last: raise ValueError('last IP address must be greater than first') @@ -316,40 +320,39 @@ def collapse_addresses(addresses): TypeError: If passed a list of mixed version objects. """ - addrs = [] ips = [] nets = [] - # split IP addresses and networks + # split IP addresses/interfaces and networks for ip in addresses: if isinstance(ip, _BaseAddress): - if ips and ips[-1].version != ip.version: - raise TypeError("%s and %s are not of the same version" % ( - ip, ips[-1])) - ips.append(ip) - elif ip._prefixlen == ip.max_prefixlen: - if ips and ips[-1].version != ip.version: - raise TypeError("%s and %s are not of the same version" % ( - ip, ips[-1])) - try: - ips.append(ip.ip) - except AttributeError: - ips.append(ip.network_address) + if ips: + _check_ip_version(ips[-1], ip) + if hasattr(ip, "ip") and isinstance(ip.ip, _BaseAddress): + ips.append(ip.ip) # interface IP address + else: + ips.append(ip) + elif isinstance(ip, _BaseNetwork): + if ip.prefixlen == ip.max_prefixlen: + if ips: + _check_ip_version(ips[-1], ip) + ips.append(ip.network_address) # network address + else: + if nets: + _check_ip_version(nets[-1], ip) + nets.append(ip) else: - if nets and nets[-1].version != ip.version: - raise TypeError("%s and %s are not of the same version" % ( - ip, nets[-1])) - nets.append(ip) + raise TypeError(f"{ip} is not an IP object") # sort and dedup ips = sorted(set(ips)) - # find consecutive address ranges in the sorted sequence and summarize them + nets_from_range = [] if ips: for first, last in _find_address_range(ips): - addrs.extend(summarize_address_range(first, last)) + nets_from_range.extend(summarize_address_range(first, last)) - return _collapse_addresses_internal(addrs + nets) + return _collapse_addresses_internal(nets_from_range + nets) def get_mixed_type_key(obj): @@ -567,21 +570,15 @@ def __int__(self): return self._ip def __eq__(self, other): - try: - return (self._ip == other._ip - and self.version == other.version) - except AttributeError: + if not isinstance(other, _BaseAddress): return NotImplemented + return self._ip == other._ip and self.version == other.version def __lt__(self, other): if not isinstance(other, _BaseAddress): return NotImplemented - if self.version != other.version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) - if self._ip != other._ip: - return self._ip < other._ip - return False + _check_ip_version(self, other) + return self._ip < other._ip # Shorthand for Integer addition and subtraction. This is not # meant to ever support addition/subtraction of addresses. @@ -708,9 +705,7 @@ def __getitem__(self, n): def __lt__(self, other): if not isinstance(other, _BaseNetwork): return NotImplemented - if self.version != other.version: - raise TypeError('%s and %s are not of the same version' % ( - self, other)) + _check_ip_version(self, other) if self.network_address != other.network_address: return self.network_address < other.network_address if self.netmask != other.netmask: @@ -718,30 +713,31 @@ def __lt__(self, other): return False def __eq__(self, other): - try: - return (self.version == other.version and - self.network_address == other.network_address and - int(self.netmask) == int(other.netmask)) - except AttributeError: + if not isinstance(other, _BaseNetwork): return NotImplemented + return (self.version == other.version + and self.network_address == other.network_address + and int(self.netmask._ip) == int(other.netmask)) def __hash__(self): return hash((int(self.network_address), int(self.netmask))) def __contains__(self, other): - # always false if one is v4 and the other is v6. - if self.version != other.version: - return False - # dealing with another network. if isinstance(other, _BaseNetwork): + # should __contains__ actually implement subnet_of() + # and supernet_of() instead? return False - # dealing with another address - else: - # address - return other._ip & self.netmask._ip == self.network_address._ip + if isinstance(other, _BaseAddress): + return ( + self.version == other.version + and (other._ip & self.netmask._ip) == self.network_address._ip + ) + return NotImplemented def overlaps(self, other): """Tell if self is partly contained in other.""" + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) return self.network_address in other or ( self.broadcast_address in other or ( other.network_address in self or ( @@ -821,13 +817,9 @@ def address_exclude(self, other): ValueError: If other is not completely contained by self. """ - if not self.version == other.version: - raise TypeError("%s and %s are not of the same version" % ( - self, other)) - if not isinstance(other, _BaseNetwork): raise TypeError("%s is not a network object" % other) - + _check_ip_version(self, other) if not other.subnet_of(self): raise ValueError('%s not contained in %s' % (other, self)) if other == self: @@ -870,7 +862,7 @@ def compare_networks(self, other): 'HostA._ip < HostB._ip' Args: - other: An IP object. + other: An IP network object. Returns: If the IP versions of self and other are the same, returns: @@ -892,10 +884,9 @@ def compare_networks(self, other): TypeError if the IP versions are different. """ - # does this need to raise a ValueError? - if self.version != other.version: - raise TypeError('%s and %s are not of the same type' % ( - self, other)) + if not isinstance(other, _BaseNetwork): + raise TypeError("%s is not a network object" % other) + _check_ip_version(self, other) # self.version == other.version below here: if self.network_address < other.network_address: return -1 @@ -1026,15 +1017,13 @@ def is_multicast(self): @staticmethod def _is_subnet_of(a, b): - try: - # Always false if one is v4 and the other is v6. - if a.version != b.version: - raise TypeError(f"{a} and {b} are not of the same version") - return (b.network_address <= a.network_address and - b.broadcast_address >= a.broadcast_address) - except AttributeError: - raise TypeError(f"Unable to test subnet containment " - f"between {a} and {b}") + if not isinstance(a, _BaseNetwork): + raise TypeError(f"{a} is not a network object") + if not isinstance(b, _BaseNetwork): + raise TypeError(f"{b} is not a network object") + _check_ip_version(a, b) + return (b.network_address <= a.network_address and + b.broadcast_address >= a.broadcast_address) def subnet_of(self, other): """Return True if this network is a subnet of other.""" @@ -1429,28 +1418,27 @@ def __str__(self): self._prefixlen) def __eq__(self, other): + if not isinstance(other, IPv4Interface): + if isinstance(other, IPv4Address): + # avoid falling back to IPv4Address.__eq__(other, self) + return False + return NotImplemented + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. address_equal = IPv4Address.__eq__(self, other) - if address_equal is NotImplemented or not address_equal: - return address_equal - try: - return self.network == other.network - except AttributeError: - # An interface with an associated network is NOT the - # same as an unassociated address. That's why the hash - # takes the extra info into account. - return False + return address_equal and self.network == other.network def __lt__(self, other): + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. address_less = IPv4Address.__lt__(self, other) - if address_less is NotImplemented: - return NotImplemented - try: - return (self.network < other.network or - self.network == other.network and address_less) - except AttributeError: - # We *do* allow addresses and interfaces to be sorted. The - # unassociated address is considered less than all interfaces. - return False + if isinstance(other, IPv4Interface): + assert address_less is not NotImplemented + # compare interfaces by their network first + return (self.network < other.network + or (self.network == other.network and address_less)) + return address_less def __hash__(self): return hash((self._ip, self._prefixlen, int(self.network.network_address))) @@ -2219,28 +2207,27 @@ def __str__(self): self._prefixlen) def __eq__(self, other): + if not isinstance(other, IPv6Interface): + if isinstance(other, IPv6Address): + # avoid falling back to IPv6Address.__eq__(other, self) + return False + return NotImplemented + # An interface with an associated network is NOT the + # same as an unassociated address. That's why the hash + # takes the extra info into account. address_equal = IPv6Address.__eq__(self, other) - if address_equal is NotImplemented or not address_equal: - return address_equal - try: - return self.network == other.network - except AttributeError: - # An interface with an associated network is NOT the - # same as an unassociated address. That's why the hash - # takes the extra info into account. - return False + return address_equal and self.network == other.network def __lt__(self, other): + # We *do* allow addresses and interfaces to be sorted. The + # unassociated address is considered less than all interfaces. address_less = IPv6Address.__lt__(self, other) - if address_less is NotImplemented: - return address_less - try: - return (self.network < other.network or - self.network == other.network and address_less) - except AttributeError: - # We *do* allow addresses and interfaces to be sorted. The - # unassociated address is considered less than all interfaces. - return False + if isinstance(other, IPv6Interface): + assert address_less is not NotImplemented + # compare interfaces by their network first + return (self.network < other.network + or (self.network == other.network and address_less)) + return address_less def __hash__(self): return hash((self._ip, self._prefixlen, int(self.network.network_address))) diff --git a/Lib/test/test_ipaddress.py b/Lib/test/test_ipaddress.py index 3f017b97dc28a3..f079af9b4a71cc 100644 --- a/Lib/test/test_ipaddress.py +++ b/Lib/test/test_ipaddress.py @@ -13,6 +13,7 @@ import ipaddress import weakref from collections.abc import Iterator +from functools import total_ordering from test.support import LARGEST, SMALLEST @@ -912,11 +913,22 @@ class ComparisonTests(unittest.TestCase): v6intf_scoped = ipaddress.IPv6Interface('::1%scope') v4_addresses = [v4addr, v4intf] - v4_objects = v4_addresses + [v4net] + v4_networks = [v4net] + v4_objects = v4_addresses + v4_networks + v6_addresses = [v6addr, v6intf] - v6_objects = v6_addresses + [v6net] + v6_networks = [v6net] + v6_objects = v6_addresses + v6_networks + v6_scoped_addresses = [v6addr_scoped, v6intf_scoped] - v6_scoped_objects = v6_scoped_addresses + [v6net_scoped] + v6_scoped_networks = [v6net_scoped] + v6_scoped_objects = v6_scoped_addresses + v6_scoped_networks + + addresses = v4_addresses + v6_addresses + addresses_with_scoped = addresses + v6_scoped_addresses + + networks = v4_networks + v6_networks + networks_with_scoped = networks + v6_scoped_networks objects = v4_objects + v6_objects objects_with_scoped = objects + v6_scoped_objects @@ -935,10 +947,14 @@ def test_foreign_type_equality(self): # __eq__ should never raise TypeError directly other = object() for obj in self.objects_with_scoped: - self.assertNotEqual(obj, other) - self.assertFalse(obj == other) - self.assertEqual(obj.__eq__(other), NotImplemented) - self.assertEqual(obj.__ne__(other), NotImplemented) + with self.subTest(obj=obj): + self.assertNotEqual(obj, other) + + self.assertFalse(obj == other) + self.assertIs(obj.__eq__(other), NotImplemented) + + self.assertTrue(obj != other) + self.assertIs(obj.__ne__(other), NotImplemented) def test_mixed_type_equality(self): # Ensure none of the internal objects accidentally @@ -1006,30 +1022,54 @@ def test_mixed_type_ordering(self): for rhs in self.objects_with_scoped: if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)): continue - self.assertRaises(TypeError, lambda: lhs < rhs) - self.assertRaises(TypeError, lambda: lhs > rhs) - self.assertRaises(TypeError, lambda: lhs <= rhs) - self.assertRaises(TypeError, lambda: lhs >= rhs) + + for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]: + with self.subTest(dunder, lhs=lhs, rhs=rhs): + func = getattr(operator, dunder) + # dunders raise a TypeError or return NotImplemented + lhs_method = getattr(lhs, dunder) + try: + self.assertIs(lhs_method(rhs), NotImplemented) + except TypeError as exc: + self.assertIn("version", str(exc)) + rhs_method = getattr(rhs, dunder) + try: + self.assertIs(rhs_method(lhs), NotImplemented) + except TypeError as exc: + self.assertIn("version", str(exc)) + # Using the comparison operator directly must + # raise a TypeError, either because we returned + # NotImplemented or because of incompatible versions. + self.assertRaises(TypeError, func, lhs, rhs) def test_foreign_type_ordering(self): other = object() for obj in self.objects_with_scoped: - with self.assertRaises(TypeError): - obj < other - with self.assertRaises(TypeError): - obj > other - with self.assertRaises(TypeError): - obj <= other - with self.assertRaises(TypeError): - obj >= other - self.assertTrue(obj < LARGEST) - self.assertFalse(obj > LARGEST) - self.assertTrue(obj <= LARGEST) - self.assertFalse(obj >= LARGEST) - self.assertFalse(obj < SMALLEST) - self.assertTrue(obj > SMALLEST) - self.assertFalse(obj <= SMALLEST) - self.assertTrue(obj >= SMALLEST) + with self.subTest(obj=obj): + for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]: + with self.subTest(dunder): + via_meth = getattr(obj, dunder) + self.assertIs(via_meth(other), NotImplemented) + via_op = getattr(operator, dunder) + self.assertRaises(TypeError, via_op, obj, other) + + self.assertIs(obj.__lt__(LARGEST), NotImplemented) + self.assertTrue(obj < LARGEST) + self.assertIs(obj.__le__(LARGEST), NotImplemented) + self.assertTrue(obj <= LARGEST) + self.assertIs(obj.__ge__(LARGEST), NotImplemented) + self.assertFalse(obj >= LARGEST) + self.assertIs(obj.__gt__(LARGEST), NotImplemented) + self.assertFalse(obj > LARGEST) + + self.assertIs(obj.__lt__(SMALLEST), NotImplemented) + self.assertFalse(obj < SMALLEST) + self.assertIs(obj.__le__(SMALLEST), NotImplemented) + self.assertFalse(obj <= SMALLEST) + self.assertIs(obj.__ge__(SMALLEST), NotImplemented) + self.assertTrue(obj >= SMALLEST) + self.assertIs(obj.__gt__(SMALLEST), NotImplemented) + self.assertTrue(obj > SMALLEST) def test_mixed_type_key(self): # with get_mixed_type_key, you can sort addresses and network. @@ -1079,6 +1119,33 @@ def test_incompatible_versions(self): self.assertRaises(TypeError, v6net_scoped.__lt__, v4net) self.assertRaises(TypeError, v6net_scoped.__gt__, v4net) + def test_object_compare_with_always_equal(self): + # Check that __eq__/__lt__ for IP objects work for non-IP + # objects that share the same attributes as IP objects. + class AlwaysEqual: + version = None + def __eq__(self, other): + return True + + same_object = AlwaysEqual() + for obj in self.objects_with_scoped: + with self.subTest(obj=obj): + self.assertEqual(obj, same_object) + + def test_object_compare_with_always_smallest(self): + @total_ordering + class Smallest: + version = None + def __lt__(self, other): + return True + + smallest = Smallest() + for obj in self.objects_with_scoped: + with self.subTest(obj=obj): + # ensure that we dispatch to Smallest.__lt__ instead. + self.assertIs(obj.__lt__(smallest), NotImplemented) + self.assertLess(smallest, obj) + class IpaddrUnitTest(unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2025-11-22-14-57-08.gh-issue-141647.PowRcQ.rst b/Misc/NEWS.d/next/Library/2025-11-22-14-57-08.gh-issue-141647.PowRcQ.rst new file mode 100644 index 00000000000000..7d2e9fd910c3a1 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-11-22-14-57-08.gh-issue-141647.PowRcQ.rst @@ -0,0 +1,5 @@ +:mod:`ipaddress`: fix comparison operators :class:`~ipaddress.IPv4Address`, +:class:`~ipaddress.IPv6Address`, :class:`~ipaddress.IPv4Network`, +:class:`~ipaddress.IPv6Network`, :class:`~ipaddress.IPv4Interface`, and +:class:`~ipaddress.IPv6Interface` to avoid comparing instances of incorrect +types. Patch by Bénédikt Tran and yihong0618. From b20b15a22342b46f8214ed4bd644f6ba67f3cb1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sat, 29 Nov 2025 10:50:33 +0100 Subject: [PATCH 2/4] revert PEP-8 --- Lib/ipaddress.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index 2152fe886c9a5c..79ecd625ad682c 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -715,9 +715,9 @@ def __lt__(self, other): def __eq__(self, other): if not isinstance(other, _BaseNetwork): return NotImplemented - return (self.version == other.version - and self.network_address == other.network_address - and int(self.netmask._ip) == int(other.netmask)) + return (self.version == other.version and + self.network_address == other.network_address and + int(self.netmask._ip) == int(other.netmask)) def __hash__(self): return hash((int(self.network_address), int(self.netmask))) From a0d718183e515bcb7a4970e14f10a31df586d6a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sat, 29 Nov 2025 10:53:40 +0100 Subject: [PATCH 3/4] remove redundant isinstance() --- Lib/ipaddress.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index 79ecd625ad682c..e15e255c205090 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -1017,20 +1017,21 @@ def is_multicast(self): @staticmethod def _is_subnet_of(a, b): - if not isinstance(a, _BaseNetwork): - raise TypeError(f"{a} is not a network object") - if not isinstance(b, _BaseNetwork): - raise TypeError(f"{b} is not a network object") + # The caller must ensure that 'a' and 'b' are both networks. _check_ip_version(a, b) return (b.network_address <= a.network_address and b.broadcast_address >= a.broadcast_address) def subnet_of(self, other): """Return True if this network is a subnet of other.""" + if not isinstance(other, _BaseNetwork): + raise TypeError(f"{other} is not a network object") return self._is_subnet_of(self, other) def supernet_of(self, other): """Return True if this network is a supernet of other.""" + if not isinstance(other, _BaseNetwork): + raise TypeError(f"{other} is not a network object") return self._is_subnet_of(other, self) @property From 100dad9e2c46247194e2b83205170b679d66f5b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?B=C3=A9n=C3=A9dikt=20Tran?= <10796600+picnixz@users.noreply.github.com> Date: Sat, 29 Nov 2025 10:55:53 +0100 Subject: [PATCH 4/4] improve error messages for invalid objects --- Lib/ipaddress.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/Lib/ipaddress.py b/Lib/ipaddress.py index e15e255c205090..e73aac371b31a8 100644 --- a/Lib/ipaddress.py +++ b/Lib/ipaddress.py @@ -737,7 +737,7 @@ def __contains__(self, other): def overlaps(self, other): """Tell if self is partly contained in other.""" if not isinstance(other, _BaseNetwork): - raise TypeError("%s is not a network object" % other) + raise TypeError(f"expecting a network object, not {type(other)}") return self.network_address in other or ( self.broadcast_address in other or ( other.network_address in self or ( @@ -818,7 +818,7 @@ def address_exclude(self, other): """ if not isinstance(other, _BaseNetwork): - raise TypeError("%s is not a network object" % other) + raise TypeError(f"expecting a network object, not {type(other)}") _check_ip_version(self, other) if not other.subnet_of(self): raise ValueError('%s not contained in %s' % (other, self)) @@ -885,7 +885,7 @@ def compare_networks(self, other): """ if not isinstance(other, _BaseNetwork): - raise TypeError("%s is not a network object" % other) + raise TypeError(f"expecting a network object, not {type(other)}") _check_ip_version(self, other) # self.version == other.version below here: if self.network_address < other.network_address: @@ -1025,13 +1025,13 @@ def _is_subnet_of(a, b): def subnet_of(self, other): """Return True if this network is a subnet of other.""" if not isinstance(other, _BaseNetwork): - raise TypeError(f"{other} is not a network object") + raise TypeError(f"expecting a network object, not {type(other)}") return self._is_subnet_of(self, other) def supernet_of(self, other): """Return True if this network is a supernet of other.""" if not isinstance(other, _BaseNetwork): - raise TypeError(f"{other} is not a network object") + raise TypeError(f"expecting a network object, not {type(other)}") return self._is_subnet_of(other, self) @property