Skip to content

Commit

Permalink
Merged with PR #31
Browse files Browse the repository at this point in the history
  • Loading branch information
fernandezcuesta committed Feb 24, 2016
2 parents 6ae5bf4 + ac5ba79 commit 8a83e3a
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 27 deletions.
125 changes: 101 additions & 24 deletions sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,25 @@ def check_address(address):
(``str``, ``int``) representing an IP address and port,
respectively
Raises:
AssertionError:
ValueError:
raised when address has an incorrect format
Example:
>>> check_address(("127.0.0.1", 22))
"""
assert isinstance(address, tuple), "ADDRESS is not a tuple ({0})".format(
type(address).__name__
)
check_host(address[0])
check_port(address[1])
if isinstance(address, tuple):
check_host(address[0])
check_port(address[1])
elif isinstance(address, string_types):
# check if address is a valid UNIX domain socket
if not (os.path.exists(address) or
os.access(os.path.dirname(address), os.W_OK)):
raise ValueError('ADDRESS not a valid socket domain socket ({0})'
.format(address))
else:
raise ValueError('ADDRESS is not a tuple, string, or character buffer '
'({0})'.format(type(address).__name__))


def check_addresses(address_list):
Expand All @@ -108,7 +115,9 @@ def check_addresses(address_list):
>>> check_addresses([("127.0.0.1", 22), ("127.0.0.1", 2222)])
"""
assert isinstance(address_list, (list, tuple))
# TODO: remote addresses should never be UNIX sockets ???
assert (isinstance(address_list, (list, tuple)) or
isinstance(address_list, (list, string_types)))
for address in address_list:
check_address(address)

Expand Down Expand Up @@ -172,7 +181,9 @@ def _check_paramiko_handlers(logger=None):


def address_to_str(address):
return "{0[0]}:{0[1]}".format(address)
if isinstance(address, tuple):
return "{0[0]}:{0[1]}".format(address)
return str(address)


def get_connection_id():
Expand Down Expand Up @@ -220,14 +231,19 @@ class _ForwardHandler(socketserver.BaseRequestHandler):

def handle(self):
uid = get_connection_id()
info = 'In #{0} <-- {1}'.format(uid, self.client_address)
info = 'In #{0} <-- {1}'.format(uid, self.client_address or
self.server.local_address)
try:
assert isinstance(self.remote_address, tuple)
src_address = self.request.getpeername()
self.logger.critical(self.server)
if not isinstance(src_address, tuple):
src_address = ('dummy', 12345)
chan = self.ssh_transport.open_channel('direct-tcpip',
self.remote_address,
self.request.getpeername())
src_address)
except AssertionError:
msg = 'Remote address MUST be a tuple (ip:port): {0}' \
msg = 'Remote address MUST be a tuple (IP:port): {0}' \
.format(self.remote_address)
self.logger.error(msg)
raise HandlerSSHTunnelForwarderError(msg)
Expand All @@ -250,15 +266,15 @@ def handle(self):
if self.request in rqst:
data = self.request.recv(1024)
if TRACE:
self.logger.info('{0} recv: {1}'
self.logger.info('<<< {0} recv: {1} <<<'
.format(info, repr(data)))
chan.send(data)
if len(data) == 0:
break
if chan in rqst:
if chan in rqst: # else
data = chan.recv(1024)
if TRACE:
self.logger.info('{0} recv: {1}'
self.logger.info('>>> {0} recv: {1} >>>'
.format(info, repr(data)))
self.request.send(data)
if len(data) == 0:
Expand Down Expand Up @@ -310,7 +326,29 @@ def remote_port(self):

class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer):
"""
Allows concurrent connections to each tunnel
Allow concurrent connections to each tunnel
"""
# If True, cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = DAEMON


class _UnixStreamForwardServer(socketserver.UnixStreamServer, _ForwardServer):
"""
Serve over UNIX domain sockets (does not work on Windows)
"""
@property
def local_host(self):
return None

@property
def local_port(self):
return None


class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn,
_UnixStreamForwardServer):
"""
Allow concurrent connections to each tunnel
"""
# If True, cleanly stop threads created by ThreadingMixIn when quitting
daemon_threads = DAEMON
Expand Down Expand Up @@ -415,6 +453,10 @@ class SSHTunnelForwarder(object):
are valid values
Default: ("0.0.0.0", ``RANDOM PORT``)
.. versionchanged:: 0.0.8
Added the ability to use a UNIX domain socket as local bind
address
local_bind_addresses (list[tuple]):
In case more than one tunnel is established at once, a list
of tuples (in the same format as ``local_bind_address``)
Expand Down Expand Up @@ -531,12 +573,14 @@ def local_is_up(self, target):
"""
try:
check_address(target)
except AssertionError:
self.logger.warning('Target must be a tuple (ip, port), where ip '
except ValueError:
self.logger.warning('Target must be a tuple (IP, port), where IP '
'is a string (i.e. "192.168.0.1") and port is '
'an integer (i.e. 40000).')
'an integer (i.e. 40000). Alternatively '
'target can be a valid UNIX domain socket.')
return False

if isinstance(target, string_types): # UNIX stream
return True
(host, port) = target
reachable_from = []
self._local_interfaces = self._local_interfaces \
Expand Down Expand Up @@ -580,20 +624,37 @@ class Handler(_Handler):
remote_address = remote_address_
ssh_transport = self._transport
logger = self.logger

# if isinstance(local_bind_address, string_types):
# unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
# else:
# unix_socket = None
return Handler

def _make_ssh_forward_server_class(self, remote_address_):
return _ThreadingForwardServer if self._threaded else _ForwardServer

def _make_unix_ssh_forward_server_class(self, remote_address_):
return _ThreadingUnixStreamForwardServer if \
self._threaded else _UnixStreamForwardServer

def _make_ssh_forward_server(self, remote_address, local_bind_address):
"""
Make SSH forward proxy Server class.
"""
_Handler = self._make_ssh_forward_handler_class(remote_address)
_Server = self._make_ssh_forward_server_class(remote_address)

try:
if isinstance(local_bind_address, string_types):
forward_maker_class = self._make_unix_ssh_forward_server_class
else:
forward_maker_class = self._make_ssh_forward_server_class

_Server = forward_maker_class(remote_address)
ssh_forward_server = _Server(local_bind_address, _Handler)

# if _Handler.unix_socket:
# _Handler.unix_socket.connect(local_bind_address)

if ssh_forward_server:
self._server_list.append(ssh_forward_server)
else:
Expand Down Expand Up @@ -992,6 +1053,11 @@ def start(self):
self.check_local_side_of_tunnels()
self._is_started = True

def restart(self):
""" Restart connection to the gateway and tunnels """
self.stop()
self.start()

def _connect_to_gateway(self):
"""
Open connection to SSH gateway
Expand Down Expand Up @@ -1101,6 +1167,13 @@ def stop(self):
)
_srv.shutdown()
_srv.server_close()
# clean up the UNIX domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
try:
os.unlink(_srv.local_address)
except Exception as e:
self.logger.error('Unable to unlink socket {0}: {1}'
.format(self.local_address, repr(e)))
self._stop_transport()
self._is_started = False

Expand Down Expand Up @@ -1144,15 +1217,17 @@ def local_bind_ports(self):
Return a list containing the ports of local side of the TCP tunnels
"""
self._check_is_started()
return [_server.local_port for _server in self._server_list]
return [_server.local_port for _server in self._server_list if
_server.local_port is not None]

@property
def local_bind_hosts(self):
"""
Return a list containing the IP addresses listening for the tunnels
"""
self._check_is_started()
return [_server.local_host for _server in self._server_list]
return [_server.local_host for _server in self._server_list if
_server.local_host is not None]

@property
def local_bind_addresses(self):
Expand All @@ -1165,7 +1240,7 @@ def _get_local_interfaces():
Return all local network interface's IP addresses
"""
local_if = socket.gethostbyname_ex(socket.gethostname())[-1]
# In Linux, if /etc/hosts is populated with the hostname it will only
# In Linux, if /etc/hosts is populated with the hostname, it will only
# return 127.0.0.1
if '127.0.0.1' not in local_if:
local_if.append('127.0.0.1')
Expand Down Expand Up @@ -1370,6 +1445,8 @@ def _parse_arguments(args=None):
dest='local_bind_addresses', metavar='IP:PORT',
help='Local bind address sequence: '
'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n'
'Elements may also be valid UNIX socket domains: \n'
'/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n'
'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, '
'being the local IP address optional.\n'
'By default it will listen in all interfaces '
Expand Down
7 changes: 4 additions & 3 deletions tests/test_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,10 +951,11 @@ def test_local_is_up(self):
self.assertIn('An error occurred while opening tunnels.',
self.sshtunnel_log_messages['error'])

self.assertFalse(server.local_is_up("not a tuple"))
self.assertIn('Target must be a tuple (ip, port), where ip '
self.assertFalse(server.local_is_up("not a valid address"))
self.assertIn('Target must be a tuple (IP, port), where IP '
'is a string (i.e. "192.168.0.1") and port is '
'an integer (i.e. 40000).',
'an integer (i.e. 40000). Alternatively '
'target can be a valid UNIX domain socket.',
self.sshtunnel_log_messages['warning'])

@mock.patch('sshtunnel.input_', return_value=linesep)
Expand Down
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ deps =
{py27,py33,py34}: readme
{py26}: unittest2
flake8
mccabe
mock
pytest
pytest-cov
Expand All @@ -49,3 +50,4 @@ commands=
[flake8]
exclude = .tox,*.egg,build,data,docs
select = E,W,F
max-complexity = 10

0 comments on commit 8a83e3a

Please sign in to comment.