Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[networking] Fix various bugs in socks proxy implementation #8065

Merged
merged 12 commits into from
Sep 18, 2023
38 changes: 9 additions & 29 deletions test/test_socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,13 @@ def test_socks4_auth(self, handler, ctx):
rh, proxies={'all': f'socks4://user:@{server_address}'})
assert response['version'] == 4

@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='socks4a implementation currently broken when destination is not a domain name'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks4a_ipv4_target(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
with handler(proxies={'all': f'socks4a://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='127.0.0.1')
assert response['version'] == 4
assert response['ipv4_address'] == '127.0.0.1'
assert response['domain_address'] is None
assert (response['ipv4_address'] == '127.0.0.1') != (response['domain_address'] == '127.0.0.1')

@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks4a_domain_target(self, handler, ctx):
Expand All @@ -302,10 +298,7 @@ def test_socks4a_domain_target(self, handler, ctx):
assert response['ipv4_address'] is None
assert response['domain_address'] == 'localhost'

@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='source_address is not yet supported for socks4 proxies'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
Expand All @@ -327,10 +320,7 @@ def test_socks4_errors(self, handler, ctx, reply_code):
with pytest.raises(ProxyError):
ctx.socks_info_request(rh)

@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='IPv6 socks4 proxies are not yet supported'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv6_socks4_proxy(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}) as rh:
Expand All @@ -342,7 +332,7 @@ def test_ipv6_socks4_proxy(self, handler, ctx):
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_timeout(self, handler, ctx):
with ctx.socks_server(Socks4ProxyHandler, sleep=2) as server_address:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=1) as rh:
with handler(proxies={'all': f'socks4://{server_address}'}, timeout=0.5) as rh:
with pytest.raises(TransportError):
ctx.socks_info_request(rh)

Expand Down Expand Up @@ -383,7 +373,7 @@ def test_socks5_domain_target(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='localhost')
assert response['ipv4_address'] == '127.0.0.1'
assert (response['ipv4_address'] == '127.0.0.1') != (response['ipv6_address'] == '::1')
assert response['version'] == 5

@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
Expand All @@ -404,22 +394,15 @@ def test_socks5h_ip_target(self, handler, ctx):
assert response['domain_address'] is None
assert response['version'] == 5

@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='IPv6 destination addresses are not yet supported'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_socks5_ipv6_destination(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
response = ctx.socks_info_request(rh, target_domain='[::1]')
assert response['ipv6_address'] == '::1'
assert response['port'] == 80
assert response['version'] == 5

@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='IPv6 socks5 proxies are not yet supported'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv6_socks5_proxy(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler, bind_ip='::1') as server_address:
with handler(proxies={'all': f'socks5://{server_address}'}) as rh:
Expand All @@ -430,10 +413,7 @@ def test_ipv6_socks5_proxy(self, handler, ctx):

# XXX: is there any feasible way of testing IPv6 source addresses?
# Same would go for non-proxy source_address test...
@pytest.mark.parametrize('handler,ctx', [
pytest.param('Urllib', 'http', marks=pytest.mark.xfail(
reason='source_address is not yet supported for socks5 proxies'))
], indirect=True)
@pytest.mark.parametrize('handler,ctx', [('Urllib', 'http')], indirect=True)
def test_ipv4_client_source_address(self, handler, ctx):
with ctx.socks_server(Socks5ProxyHandler) as server_address:
source_address = f'127.0.0.{random.randint(5, 255)}'
Expand Down
57 changes: 57 additions & 0 deletions yt_dlp/networking/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import functools
import socket
import ssl
import sys
import typing
Expand Down Expand Up @@ -206,3 +207,59 @@ def wrapper(self, *args, **kwargs):
e.handler = self
raise
return wrapper


def _socket_connect(ip_addr, timeout, source_address):
    """Create a socket for one resolved address and connect it.

    @param ip_addr          A single 5-tuple in the shape returned by
                            socket.getaddrinfo():
                            (family, type, proto, canonname, sockaddr)
    @param timeout          Timeout applied to the socket before connecting,
                            unless it is socket._GLOBAL_DEFAULT_TIMEOUT
    @param source_address   Local address to bind to before connecting;
                            skipped when falsy
    @returns                The connected socket
    @raises                 Whatever socket creation, settimeout(), bind() or
                            connect() raise; the socket is closed first
    """
    af, socktype, proto, _canonname, sa = ip_addr
    sock = socket.socket(af, socktype, proto)
    try:
        if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            sock.settimeout(timeout)
        if source_address:
            sock.bind(source_address)
        sock.connect(sa)
    except BaseException:
        # Close on *any* failure, not only OSError: settimeout() raises
        # ValueError/TypeError for an invalid timeout and connect() may be
        # interrupted, either of which would otherwise leak the open socket.
        sock.close()
        raise
    return sock


def create_connection(
    address,
    timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
    source_address=None,
    *,
    _create_socket_func=_socket_connect
):
    """Connect to `address`, trying only addresses compatible with `source_address`.

    Replacement for socket.create_connection(), which tries every address
    returned by getaddrinfo() including IPv6. Here the candidates are
    restricted to the address family of `source_address` when one is given.
    Based on: https://github.com/python/cpython/blob/main/Lib/socket.py#L810

    @param address              (host, port) pair to connect to
    @param timeout              Passed through to the socket factory
    @param source_address       Optional local address; its first element
                                decides whether IPv4 or IPv6 candidates are kept
    @param _create_socket_func  Factory called as
                                func(addrinfo, timeout, source_address) for
                                each candidate until one succeeds
    @returns                    Whatever the factory returns for the first
                                candidate that does not raise socket.error
    @raises OSError             If there are no candidates, none match the
                                source address family, or every attempt failed
                                (the last error is re-raised)
    """
    host, port = address
    candidates = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
    if not candidates:
        raise socket.error('getaddrinfo returns an empty list')

    if source_address is not None:
        # A colon in the source host means IPv6; anything else is treated as IPv4
        af = socket.AF_INET6 if ':' in source_address[0] else socket.AF_INET
        candidates = [info for info in candidates if info[0] == af]
        if not candidates:
            raise OSError(
                f'No remote IPv{4 if af == socket.AF_INET else 6} addresses available for connect. '
                f'Can\'t use "{source_address[0]}" as source address')

    last_error = None
    for candidate in candidates:
        try:
            connected = _create_socket_func(candidate, timeout, source_address)
        except socket.error as exc:
            last_error = exc
        else:
            # Explicitly break __traceback__ reference cycle
            # https://bugs.python.org/issue36820
            last_error = None
            return connected

    try:
        raise last_error
    finally:
        # Explicitly break __traceback__ reference cycle
        # https://bugs.python.org/issue36820
        last_error = None
68 changes: 25 additions & 43 deletions yt_dlp/networking/_urllib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._helper import (
InstanceStoreMixin,
add_accept_encoding_header,
create_connection,
get_redirect_method,
make_socks_proxy_opts,
select_proxy,
Expand Down Expand Up @@ -54,44 +55,10 @@
def _create_http_connection(http_class, source_address, *args, **kwargs):
hc = http_class(*args, **kwargs)

if hasattr(hc, '_create_connection'):
hc._create_connection = create_connection
Grub4K marked this conversation as resolved.
Show resolved Hide resolved

if source_address is not None:
# This is to workaround _create_connection() from socket where it will try all
# address data from getaddrinfo() including IPv6. This filters the result from
# getaddrinfo() based on the source_address value.
# This is based on the cpython socket.create_connection() function.
# https://github.com/python/cpython/blob/master/Lib/socket.py#L691
def _create_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None):
host, port = address
err = None
addrs = socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM)
af = socket.AF_INET if '.' in source_address[0] else socket.AF_INET6
ip_addrs = [addr for addr in addrs if addr[0] == af]
if addrs and not ip_addrs:
ip_version = 'v4' if af == socket.AF_INET else 'v6'
raise OSError(
"No remote IP%s addresses available for connect, can't use '%s' as source address"
% (ip_version, source_address[0]))
for res in ip_addrs:
af, socktype, proto, canonname, sa = res
sock = None
try:
sock = socket.socket(af, socktype, proto)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
sock.settimeout(timeout)
sock.bind(source_address)
sock.connect(sa)
err = None # Explicitly break reference cycle
return sock
except OSError as _:
err = _
if sock is not None:
sock.close()
if err is not None:
raise err
else:
raise OSError('getaddrinfo returns an empty list')
if hasattr(hc, '_create_connection'):
hc._create_connection = _create_connection
hc.source_address = (source_address, 0)

return hc
Expand Down Expand Up @@ -220,13 +187,28 @@ def make_socks_conn_class(base_class, socks_proxy):
proxy_args = make_socks_proxy_opts(socks_proxy)

class SocksConnection(base_class):
def connect(self):
self.sock = sockssocket()
self.sock.setproxy(**proxy_args)
if type(self.timeout) in (int, float): # noqa: E721
self.sock.settimeout(self.timeout)
self.sock.connect((self.host, self.port))
_create_connection = create_connection

def connect(self):
def sock_socket_connect(ip_addr, timeout, source_address):
af, socktype, proto, canonname, sa = ip_addr
sock = sockssocket(af, socktype, proto)
try:
connect_proxy_args = proxy_args.copy()
connect_proxy_args.update({'addr': sa[0], 'port': sa[1]})
sock.setproxy(**connect_proxy_args)
if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: # noqa: E721
sock.settimeout(timeout)
if source_address:
sock.bind(source_address)
sock.connect((self.host, self.port))
return sock
except socket.error:
sock.close()
raise
Grub4K marked this conversation as resolved.
Show resolved Hide resolved
self.sock = create_connection(
(proxy_args['addr'], proxy_args['port']), timeout=self.timeout,
source_address=self.source_address, _create_socket_func=sock_socket_connect)
if isinstance(self, http.client.HTTPSConnection):
self.sock = self._context.wrap_socket(self.sock, server_hostname=self.host)

Expand Down
31 changes: 19 additions & 12 deletions yt_dlp/socks.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,26 +134,31 @@ def _check_response_version(self, expected_version, got_version):
self.close()
raise InvalidVersionError(expected_version, got_version)

def _resolve_address(self, destaddr, default, use_remote_dns):
try:
return socket.inet_aton(destaddr)
except OSError:
if use_remote_dns and self._proxy.remote_dns:
return default
else:
return socket.inet_aton(socket.gethostbyname(destaddr))
def _resolve_address(self, destaddr, default, use_remote_dns, family=None):
    """Return a (family, packed_address) pair for `destaddr`.

    `destaddr` is first parsed as a literal IP address: for `family` only
    when one is given, otherwise as IPv4 and then IPv6. If it is not a
    literal and both `use_remote_dns` and the proxy's `remote_dns` are set,
    (0, default) is returned without resolving anything. Otherwise the name
    is resolved locally with getaddrinfo() and its first result is packed.
    """
    literal_families = (family,) if family else (socket.AF_INET, socket.AF_INET6)
    for candidate in literal_families:
        try:
            packed = socket.inet_pton(candidate, destaddr)
        except OSError:
            # Not a literal of this family; try the next one
            continue
        return candidate, packed

    if not (use_remote_dns and self._proxy.remote_dns):
        first_result = socket.getaddrinfo(destaddr, None, family=family or 0)[0]
        resolved_family, sockaddr = first_result[0], first_result[4]
        return resolved_family, socket.inet_pton(resolved_family, sockaddr[0])

    return 0, default
coletdjnz marked this conversation as resolved.
Show resolved Hide resolved

def _setup_socks4(self, address, is_4a=False):
destaddr, port = address

ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a)
_, ipaddr = self._resolve_address(destaddr, SOCKS4_DEFAULT_DSTIP, use_remote_dns=is_4a, family=socket.AF_INET)

packet = struct.pack('!BBH', SOCKS4_VERSION, Socks4Command.CMD_CONNECT, port) + ipaddr

username = (self._proxy.username or '').encode()
packet += username + b'\x00'

if is_4a and self._proxy.remote_dns:
if is_4a and self._proxy.remote_dns and ipaddr == SOCKS4_DEFAULT_DSTIP:
packet += destaddr.encode() + b'\x00'

self.sendall(packet)
Expand Down Expand Up @@ -210,7 +215,7 @@ def _socks5_auth(self):
def _setup_socks5(self, address):
destaddr, port = address

ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)
family, ipaddr = self._resolve_address(destaddr, None, use_remote_dns=True)

self._socks5_auth()

Expand All @@ -220,8 +225,10 @@ def _setup_socks5(self, address):
destaddr = destaddr.encode()
packet += struct.pack('!B', Socks5AddressType.ATYP_DOMAINNAME)
packet += self._len_and_data(destaddr)
else:
elif family == socket.AF_INET:
packet += struct.pack('!B', Socks5AddressType.ATYP_IPV4) + ipaddr
elif family == socket.AF_INET6:
packet += struct.pack('!B', Socks5AddressType.ATYP_IPV6) + ipaddr
packet += struct.pack('!H', port)

self.sendall(packet)
Expand Down