Skip to content

Commit

Permalink
Support Unix domain socket (local) forwarding
Browse files Browse the repository at this point in the history
  • Loading branch information
RasterBurn committed Sep 16, 2015
1 parent b66e1fe commit ac5ba79
Showing 1 changed file with 71 additions and 12 deletions.
83 changes: 71 additions & 12 deletions sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def do_something(port):
import warnings
from select import select
from os.path import expanduser
from os import unlink

if sys.version_info.major < 3:
string_types = basestring,
Expand Down Expand Up @@ -149,10 +150,13 @@ def check_address(address):
>>> check_address(("127.0.0.1", 22))
"""
assert isinstance(address, tuple), "ADDRESS is not a tuple ({0})" \
.format(type(address).__name__)
check_host(address[0])
check_port(address[1])
if isinstance(address, tuple):
check_host(address[0])
check_port(address[1])
return
assert isinstance(address, basestring),\
"ADDRESS is not a tuple, string, or character buffer ({0})"\
.format(type(address).__name__)


def check_addresses(address_list):
Expand Down Expand Up @@ -192,7 +196,9 @@ def create_logger(logger=None, loglevel=DEFAULT_LOGLEVEL):


def address_to_str(address):
return "{0[0]}:{0[1]}".format(address)
if isinstance(address, tuple):
return "{0[0]}:{0[1]}".format(address)
return str(address)


def get_connection_id():
Expand Down Expand Up @@ -237,9 +243,15 @@ def handle(self):
uid = get_connection_id()
try:
assert isinstance(self.remote_address, tuple)
if isinstance(self.request.getpeername(), tuple):
src_address = self.request.getpeername()
else:
# paramiko wants a tuple, so we'll give it a tuple
src_address = ('dummy', 12345)

chan = self.ssh_transport.open_channel('direct-tcpip',
self.remote_address,
self.request.getpeername())
src_address)
except AssertionError:
msg = 'Remote address MUST be a tuple (ip:port): {0}' \
.format(self.remote_address)
Expand Down Expand Up @@ -329,6 +341,36 @@ class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer):
daemon_threads = True


class _UnixStreamForwardServer(SocketServer.UnixStreamServer, _ForwardServer): # Not Threading
"""
Serve over Unix domain sockets (does not work on Windows)
"""
@property
def local_host(self):
return None

@property
def local_port(self):
return None


class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer):
"""
Allows concurrent connections to each tunnel
"""
# Will cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = True


class _ThreadingUnixStreamForwardServer(SocketServer.ThreadingMixIn, _UnixStreamForwardServer): # Not Threading
"""
Allows concurrent connections to each tunnel
"""
# Will cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = True



class SSHTunnelForwarder(object):
"""
Class for forward remote server port throw SSH tunnel to local port.
Expand Down Expand Up @@ -419,12 +461,20 @@ class Handler(_Handler):
def make_ssh_forward_server_class(self, remote_address_):
return _ThreadingForwardServer if self._threaded else _ForwardServer

def make_unix_ssh_forward_server_class(self, remote_address_):
return _ThreadingUnixStreamForwardServer if self._threaded else _UnixStreamForwardServer

def make_ssh_forward_server(self, remote_address, local_bind_address):
"""
Make SSH forward proxy Server class.
"""
_Handler = self.make_ssh_forward_handler_class(remote_address)
_Server = self.make_ssh_forward_server_class(remote_address)

if isinstance(local_bind_address, (list, tuple)):
_Server = self.make_ssh_forward_server_class(remote_address)
else:
_Server = self.make_unix_ssh_forward_server_class(remote_address)

try:
return _Server(local_bind_address, _Handler)
except IOError:
Expand Down Expand Up @@ -480,7 +530,7 @@ def __init__(
use *remote_bind_addresses* if you want open more than one tunnel else
use *remote_bind_address*
*local_bind_address* - (ip, port) If None uses ("0.0.0.0", RANDOM)
*local_bind_address* - (ip, port) Or string Or character buffer If None uses ("0.0.0.0", RANDOM)
*local_bind_addresses* - [(ip1, port_1), (ip_2, port2), ...] If None
uses [local_bind_address]
Expand Down Expand Up @@ -628,7 +678,7 @@ def __init__(
'Problem with make ssh {0} <> {1} forwarder. You can '
'suppres this exception by using the '
'`raise_exception_if_any_forwarder_have_a_problem` '
'argument'.format(address_to_str(l), address_to_str(l))
'argument'.format(address_to_str(l), address_to_str(r))
)

except paramiko.SSHException:
Expand Down Expand Up @@ -686,7 +736,7 @@ def start(self):

for _srv in self._server_list:
self.tunnel_is_up[_srv.local_address] = \
self.local_is_up(_srv.local_address)
True if isinstance(_srv, _UnixStreamForwardServer) else self.local_is_up(_srv.local_address)

if not any(self.tunnel_is_up.values()):
self.logger.error("An error occurred while opening tunnels.")
Expand Down Expand Up @@ -759,6 +809,13 @@ def stop(self):
self.logger.info('Shutting down tunnel ' + local_address_text)
_srv.shutdown()
_srv.server_close()
# clean up the unix domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
try:
unlink(_srv.local_address)
except Exception as e:
self.logger.error('Unable to unlink {0}: {1}'
.format(self.local_address, repr(e)))

self._transport.close()
self._transport.stop_thread()
Expand Down Expand Up @@ -799,15 +856,15 @@ def local_bind_ports(self):
Returns a list containing the ports of local side of the TCP tunnels
"""
self._check_is_started()
return [_server.local_port for _server in self._server_list]
return [_server.local_port for _server in self._server_list if _server.local_port is not None]

@property
def local_bind_hosts(self):
"""
Returns a list containing the IP addresses listening for the tunnels
"""
self._check_is_started()
return [_server.local_host for _server in self._server_list]
return [_server.local_host for _server in self._server_list if _server.local_host is not None]

@property
def local_bind_addresses(self):
Expand Down Expand Up @@ -973,6 +1030,8 @@ def main():
dest='local_bind_addresses', metavar='IP:PORT',
help='Local bind address sequence: '
'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n'
'or unix domain sockets: '
'/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n'
'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, '
'being the local IP address optional.\n'
'By default it will listen in all interfaces '
Expand Down

0 comments on commit ac5ba79

Please sign in to comment.