diff --git a/src/requests/utils.py b/src/requests/utils.py index ae6c42f6cb..5e7f80be5f 100644 --- a/src/requests/utils.py +++ b/src/requests/utils.py @@ -9,6 +9,7 @@ import codecs import contextlib import io +import ipaddress import os import re import socket @@ -687,11 +688,12 @@ def address_in_network(ip, net): :rtype: bool """ - ipaddr = struct.unpack("=L", socket.inet_aton(ip))[0] - netaddr, bits = net.split("/") - netmask = struct.unpack("=L", socket.inet_aton(dotted_netmask(int(bits))))[0] - network = struct.unpack("=L", socket.inet_aton(netaddr))[0] & netmask - return (ipaddr & netmask) == (network & netmask) + try: + ip_address = ipaddress.ip_address(ip) + network = ipaddress.ip_network(net) + return ip_address in network + except (ipaddress.AddressValueError, ValueError): + return False def dotted_netmask(mask): @@ -710,8 +712,8 @@ def is_ipv4_address(string_ip): :rtype: bool """ try: - socket.inet_aton(string_ip) - except OSError: + ipaddress.IPv4Address(string_ip) + except ipaddress.AddressValueError: return False return True @@ -722,22 +724,11 @@ def is_valid_cidr(string_network): :rtype: bool """ - if string_network.count("/") == 1: - try: - mask = int(string_network.split("/")[1]) - except ValueError: - return False - - if mask < 1 or mask > 32: - return False - - try: - socket.inet_aton(string_network.split("/")[0]) - except OSError: - return False - else: + try: + interface = ipaddress.ip_interface(string_network) + except (ipaddress.AddressValueError, ValueError): return False - return True + return string_network in (interface.compressed, interface.exploded) @contextlib.contextmanager diff --git a/tests/test_utils.py b/tests/test_utils.py index 5e9b56ea64..31f7416441 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -257,7 +257,14 @@ class TestIsIPv4Address: def test_valid(self): assert is_ipv4_address("8.8.8.8") - @pytest.mark.parametrize("value", ("8.8.8.8.8", "localhost.localdomain")) + @pytest.mark.parametrize( + "value", + ( + "8.8.8.8.8", + "1.1.1.1 someone was here...", + "localhost.localdomain", + ), + ) def test_invalid(self, value): assert not is_ipv4_address(value) @@ -274,6 +281,7 @@ def test_valid(self): "192.168.1.0/128", "192.168.1.0/-1", "192.168.1.999/24", + "1.1.1.1 something/24", ), ) def test_invalid(self, value): @@ -284,8 +292,17 @@ class TestAddressInNetwork: def test_valid(self): assert address_in_network("192.168.1.1", "192.168.1.0/24") - def test_invalid(self): - assert not address_in_network("172.16.0.1", "192.168.1.0/24") + @pytest.mark.parametrize( + "ip, net", + ( + ("172.16.0.1", "192.168.1.0/24"), + ("1.1.1.1", "1.1.1.1/24"), + ("1.1.1.1wtf", "1.1.1.1/24"), + ("1.1.1.1 wtf", "1.1.1.1/24"), + ), + ) + def test_invalid(self, ip, net): + assert not address_in_network(ip, net) class TestGuessFilename: