Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Unix domain socket (local) forwarding #31

Merged
merged 1 commit into from
Mar 8, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 71 additions & 12 deletions sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def do_something(port):
import warnings
from select import select
from os.path import expanduser
from os import unlink

if sys.version_info.major < 3:
string_types = basestring,
Expand Down Expand Up @@ -149,10 +150,13 @@ def check_address(address):

>>> check_address(("127.0.0.1", 22))
"""
assert isinstance(address, tuple), "ADDRESS is not a tuple ({0})" \
.format(type(address).__name__)
check_host(address[0])
check_port(address[1])
if isinstance(address, tuple):
check_host(address[0])
check_port(address[1])
return
assert isinstance(address, basestring),\
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this python 3.x compatible?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use string_types

"ADDRESS is not a tuple, string, or character buffer ({0})"\
.format(type(address).__name__)


def check_addresses(address_list):
Expand Down Expand Up @@ -192,7 +196,9 @@ def create_logger(logger=None, loglevel=DEFAULT_LOGLEVEL):


def address_to_str(address):
return "{0[0]}:{0[1]}".format(address)
if isinstance(address, tuple):
return "{0[0]}:{0[1]}".format(address)
return str(address)


def get_connection_id():
Expand Down Expand Up @@ -237,9 +243,15 @@ def handle(self):
uid = get_connection_id()
try:
assert isinstance(self.remote_address, tuple)
if isinstance(self.request.getpeername(), tuple):
src_address = self.request.getpeername()
else:
# paramiko wants a tuple, so we'll give it a tuple
src_address = ('dummy', 12345)

chan = self.ssh_transport.open_channel('direct-tcpip',
self.remote_address,
self.request.getpeername())
src_address)
except AssertionError:
msg = 'Remote address MUST be a tuple (ip:port): {0}' \
.format(self.remote_address)
Expand Down Expand Up @@ -329,6 +341,36 @@ class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer):
daemon_threads = True


class _UnixStreamForwardServer(SocketServer.UnixStreamServer, _ForwardServer): # Not Threading
"""
Serve over Unix domain sockets (does not work on Windows)
"""
@property
def local_host(self):
return None

@property
def local_port(self):
return None


class _ThreadingForwardServer(SocketServer.ThreadingMixIn, _ForwardServer):
"""
Allows concurrent connections to each tunnel
"""
# Will cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = True


class _ThreadingUnixStreamForwardServer(SocketServer.ThreadingMixIn, _UnixStreamForwardServer): # Not Threading
"""
Allows concurrent connections to each tunnel
"""
# Will cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = True



class SSHTunnelForwarder(object):
"""
Class for forward remote server port throw SSH tunnel to local port.
Expand Down Expand Up @@ -419,12 +461,20 @@ class Handler(_Handler):
def make_ssh_forward_server_class(self, remote_address_):
return _ThreadingForwardServer if self._threaded else _ForwardServer

def make_unix_ssh_forward_server_class(self, remote_address_):
return _ThreadingUnixStreamForwardServer if self._threaded else _UnixStreamForwardServer

def make_ssh_forward_server(self, remote_address, local_bind_address):
"""
Make SSH forward proxy Server class.
"""
_Handler = self.make_ssh_forward_handler_class(remote_address)
_Server = self.make_ssh_forward_server_class(remote_address)

if isinstance(local_bind_address, (list, tuple)):
_Server = self.make_ssh_forward_server_class(remote_address)
else:
_Server = self.make_unix_ssh_forward_server_class(remote_address)

try:
return _Server(local_bind_address, _Handler)
except IOError:
Expand Down Expand Up @@ -480,7 +530,7 @@ def __init__(
use *remote_bind_addresses* if you want open more than one tunnel else
use *remote_bind_address*

*local_bind_address* - (ip, port) If None uses ("0.0.0.0", RANDOM)
*local_bind_address* - (ip, port) Or string Or character buffer If None uses ("0.0.0.0", RANDOM)
*local_bind_addresses* - [(ip1, port_1), (ip_2, port2), ...] If None
uses [local_bind_address]

Expand Down Expand Up @@ -628,7 +678,7 @@ def __init__(
'Problem with make ssh {0} <> {1} forwarder. You can '
'suppres this exception by using the '
'`raise_exception_if_any_forwarder_have_a_problem` '
'argument'.format(address_to_str(l), address_to_str(l))
'argument'.format(address_to_str(l), address_to_str(r))
)

except paramiko.SSHException:
Expand Down Expand Up @@ -686,7 +736,7 @@ def start(self):

for _srv in self._server_list:
self.tunnel_is_up[_srv.local_address] = \
self.local_is_up(_srv.local_address)
True if isinstance(_srv, _UnixStreamForwardServer) else self.local_is_up(_srv.local_address)

if not any(self.tunnel_is_up.values()):
self.logger.error("An error occurred while opening tunnels.")
Expand Down Expand Up @@ -759,6 +809,13 @@ def stop(self):
self.logger.info('Shutting down tunnel ' + local_address_text)
_srv.shutdown()
_srv.server_close()
# clean up the unix domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
try:
unlink(_srv.local_address)
except Exception as e:
self.logger.error('Unable to unlink {0}: {1}'
.format(self.local_address, repr(e)))

self._transport.close()
self._transport.stop_thread()
Expand Down Expand Up @@ -799,15 +856,15 @@ def local_bind_ports(self):
Returns a list containing the ports of local side of the TCP tunnels
"""
self._check_is_started()
return [_server.local_port for _server in self._server_list]
return [_server.local_port for _server in self._server_list if _server.local_port is not None]

@property
def local_bind_hosts(self):
"""
Returns a list containing the IP addresses listening for the tunnels
"""
self._check_is_started()
return [_server.local_host for _server in self._server_list]
return [_server.local_host for _server in self._server_list if _server.local_host is not None]

@property
def local_bind_addresses(self):
Expand Down Expand Up @@ -973,6 +1030,8 @@ def main():
dest='local_bind_addresses', metavar='IP:PORT',
help='Local bind address sequence: '
'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n'
'or unix domain sockets: '
'/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n'
'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, '
'being the local IP address optional.\n'
'By default it will listen in all interfaces '
Expand Down