From ac5ba7920a0b4f2bc5e7613c0f95b342d3143c9a Mon Sep 17 00:00:00 2001 From: Dan Harbin Date: Wed, 16 Sep 2015 13:59:28 -0700 Subject: [PATCH] Support Unix domain socket (local) forwarding --- sshtunnel.py | 83 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 71 insertions(+), 12 deletions(-) diff --git a/sshtunnel.py b/sshtunnel.py index 1f8b5bf6..3742723b 100755 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -103,6 +103,7 @@ def do_something(port): import warnings from select import select from os.path import expanduser +from os import unlink if sys.version_info.major < 3: string_types = basestring, @@ -149,10 +150,13 @@ def check_address(address): >>> check_address(("127.0.0.1", 22)) """ - assert isinstance(address, tuple), "ADDRESS is not a tuple ({0})" \ - .format(type(address).__name__) - check_host(address[0]) - check_port(address[1]) + if isinstance(address, tuple): + check_host(address[0]) + check_port(address[1]) + return + assert isinstance(address, basestring),\ + "ADDRESS is not a tuple, string, or character buffer ({0})"\ + .format(type(address).__name__) def check_addresses(address_list): @@ -192,7 +196,9 @@ def create_logger(logger=None, loglevel=DEFAULT_LOGLEVEL): def address_to_str(address): - return "{0[0]}:{0[1]}".format(address) + if isinstance(address, tuple): + return "{0[0]}:{0[1]}".format(address) + return str(address) def get_connection_id(): @@ -237,9 +243,15 @@ def handle(self): uid = get_connection_id() try: assert isinstance(self.remote_address, tuple) + if isinstance(self.request.getpeername(), tuple): + src_address = self.request.getpeername() + else: + # paramiko wants a tuple, so we'll give it a tuple + src_address = ('dummy', 12345) + chan = self.ssh_transport.open_channel('direct-tcpip', self.remote_address, - self.request.getpeername()) + src_address) except AssertionError: msg = 'Remote address MUST be a tuple (ip:port): {0}' \ .format(self.remote_address) @@ -329,6 +341,36 @@ class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer): daemon_threads = True +class _UnixStreamForwardServer(SocketServer.UnixStreamServer, _ForwardServer): # Not Threading + """ + Serve over Unix domain sockets (does not work on Windows) + """ + @property + def local_host(self): + return None + + @property + def local_port(self): + return None + + +class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer): + """ + Allows concurrent connections to each tunnel + """ + # Will cleanly stop threads created by ThreadingMixIn when quitting + daemon_threads = True + + +class _ThreadingUnixStreamForwardServer(SocketServer.ThreadingMixIn, _UnixStreamForwardServer): # Not Threading + """ + Allows concurrent connections to each tunnel + """ + # Will cleanly stop threads created by ThreadingMixIn when quitting + daemon_threads = True + + + class SSHTunnelForwarder(object): """ Class for forward remote server port throw SSH tunnel to local port. @@ -419,12 +461,20 @@ class Handler(_Handler): def make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer + def make_unix_ssh_forward_server_class(self, remote_address_): + return _ThreadingUnixStreamForwardServer if self._threaded else _UnixStreamForwardServer + def make_ssh_forward_server(self, remote_address, local_bind_address): """ Make SSH forward proxy Server class. """ _Handler = self.make_ssh_forward_handler_class(remote_address) - _Server = self.make_ssh_forward_server_class(remote_address) + + if isinstance(local_bind_address, (list, tuple)): + _Server = self.make_ssh_forward_server_class(remote_address) + else: + _Server = self.make_unix_ssh_forward_server_class(remote_address) + try: return _Server(local_bind_address, _Handler) except IOError: @@ -480,7 +530,7 @@ def __init__( use *remote_bind_addresses* if you want open more than one tunnel else use *remote_bind_address* - *local_bind_address* - (ip, port) If None uses ("0.0.0.0", RANDOM) + *local_bind_address* - (ip, port) Or string Or character buffer If None uses ("0.0.0.0", RANDOM) *local_bind_addresses* - [(ip1, port_1), (ip_2, port2), ...] If None uses [local_bind_address] @@ -628,7 +678,7 @@ def __init__( 'Problem with make ssh {0} <> {1} forwarder. You can ' 'suppres this exception by using the ' '`raise_exception_if_any_forwarder_have_a_problem` ' - 'argument'.format(address_to_str(l), address_to_str(l)) + 'argument'.format(address_to_str(l), address_to_str(r)) ) except paramiko.SSHException: @@ -686,7 +736,7 @@ def start(self): for _srv in self._server_list: self.tunnel_is_up[_srv.local_address] = \ - self.local_is_up(_srv.local_address) + True if isinstance(_srv, _UnixStreamForwardServer) else self.local_is_up(_srv.local_address) if not any(self.tunnel_is_up.values()): self.logger.error("An error occurred while opening tunnels.") @@ -759,6 +809,13 @@ def stop(self): self.logger.info('Shutting down tunnel ' + local_address_text) _srv.shutdown() _srv.server_close() + # clean up the unix domain socket if we're using one + if isinstance(_srv, _UnixStreamForwardServer): + try: + unlink(_srv.local_address) + except Exception as e: + self.logger.error('Unable to unlink {0}: {1}' + .format(self.local_address, repr(e))) self._transport.close() self._transport.stop_thread() @@ -799,7 +856,7 @@ def local_bind_ports(self): Returns a list containing the ports of local side of the TCP tunnels """ self._check_is_started() - return [_server.local_port for _server in self._server_list] + return [_server.local_port for _server in self._server_list if _server.local_port is not None] @property def local_bind_hosts(self): @@ -807,7 +864,7 @@ def local_bind_hosts(self): Returns a list containing the IP addresses listening for the tunnels """ self._check_is_started() - return [_server.local_host for _server in self._server_list] + return [_server.local_host for _server in self._server_list if _server.local_host is not None] @property def local_bind_addresses(self): @@ -973,6 +1030,8 @@ def main(): dest='local_bind_addresses', metavar='IP:PORT', help='Local bind address sequence: ' 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' + 'or unix domain sockets: ' + '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n' 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, ' 'being the local IP address optional.\n' 'By default it will listen in all interfaces '