diff --git a/sshtunnel.py b/sshtunnel.py index 2ac5afed..44d30364 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -78,18 +78,25 @@ def check_address(address): (``str``, ``int``) representing an IP address and port, respectively Raises: - AssertionError: + ValueError: raised when address has an incorrect format Example: >>> check_address(("127.0.0.1", 22)) """ - assert isinstance(address, tuple), "ADDRESS is not a tuple ({0})".format( - type(address).__name__ - ) - check_host(address[0]) - check_port(address[1]) + if isinstance(address, tuple): + check_host(address[0]) + check_port(address[1]) + elif isinstance(address, string_types): + # check if address is a valid UNIX domain socket + if not (os.path.exists(address) or + os.access(os.path.dirname(address), os.W_OK)): + raise ValueError('ADDRESS not a valid socket domain socket ({0})' + .format(address)) + else: + raise ValueError('ADDRESS is not a tuple, string, or character buffer ' + '({0})'.format(type(address).__name__)) def check_addresses(address_list): @@ -108,7 +115,9 @@ def check_addresses(address_list): >>> check_addresses([("127.0.0.1", 22), ("127.0.0.1", 2222)]) """ - assert isinstance(address_list, (list, tuple)) + # TODO: remote addresses should never be UNIX sockets ??? + assert (isinstance(address_list, (list, tuple)) or + isinstance(address_list, (list, string_types))) for address in address_list: check_address(address) @@ -172,7 +181,9 @@ def _check_paramiko_handlers(logger=None): def address_to_str(address): - return "{0[0]}:{0[1]}".format(address) + if isinstance(address, tuple): + return "{0[0]}:{0[1]}".format(address) + return str(address) def get_connection_id(): @@ -220,14 +231,19 @@ class _ForwardHandler(socketserver.BaseRequestHandler): def handle(self): uid = get_connection_id() - info = 'In #{0} <-- {1}'.format(uid, self.client_address) + info = 'In #{0} <-- {1}'.format(uid, self.client_address or + self.server.local_address) try: assert isinstance(self.remote_address, tuple) + src_address = self.request.getpeername() + self.logger.critical(self.server) + if not isinstance(src_address, tuple): + src_address = ('dummy', 12345) chan = self.ssh_transport.open_channel('direct-tcpip', self.remote_address, - self.request.getpeername()) + src_address) except AssertionError: - msg = 'Remote address MUST be a tuple (ip:port): {0}' \ + msg = 'Remote address MUST be a tuple (IP:port): {0}' \ .format(self.remote_address) self.logger.error(msg) raise HandlerSSHTunnelForwarderError(msg) @@ -250,15 +266,15 @@ def handle(self): if self.request in rqst: data = self.request.recv(1024) if TRACE: - self.logger.info('{0} recv: {1}' + self.logger.info('<<< {0} recv: {1} <<<' .format(info, repr(data))) chan.send(data) if len(data) == 0: break - if chan in rqst: + if chan in rqst: # else data = chan.recv(1024) if TRACE: - self.logger.info('{0} recv: {1}' + self.logger.info('>>> {0} recv: {1} >>>' .format(info, repr(data))) self.request.send(data) if len(data) == 0: @@ -310,7 +326,29 @@ def remote_port(self): class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer): """ - Allows concurrent connections to each tunnel + Allow concurrent connections to each tunnel + """ + # If True, cleanly stop threads created by ThreadingMixIn when quitting + daemon_threads = DAEMON + + +class _UnixStreamForwardServer(socketserver.UnixStreamServer, _ForwardServer): + """ + Serve over UNIX domain sockets (does not work on Windows) + """ + @property + def local_host(self): + return None + + @property + def local_port(self): + return None + + +class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn, + _UnixStreamForwardServer): + """ + Allow concurrent connections to each tunnel """ # If True, cleanly stop threads created by ThreadingMixIn when quitting daemon_threads = DAEMON @@ -415,6 +453,10 @@ class SSHTunnelForwarder(object): are valid values Default: ("0.0.0.0", ``RANDOM PORT``) + .. versionchanged:: 0.0.8 + Added the ability to use a UNIX domain socket as local bind + address + local_bind_addresses (list[tuple]): In case more than one tunnel is established at once, a list of tuples (in the same format as ``local_bind_address``) @@ -531,12 +573,14 @@ def local_is_up(self, target): """ try: check_address(target) - except AssertionError: - self.logger.warning('Target must be a tuple (ip, port), where ip ' + except ValueError: + self.logger.warning('Target must be a tuple (IP, port), where IP ' 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000).') + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.') return False - + if isinstance(target, string_types): # UNIX stream + return True (host, port) = target reachable_from = [] self._local_interfaces = self._local_interfaces \ @@ -580,20 +624,37 @@ class Handler(_Handler): remote_address = remote_address_ ssh_transport = self._transport logger = self.logger - + # if isinstance(local_bind_address, string_types): + # unix_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + # else: + # unix_socket = None return Handler def _make_ssh_forward_server_class(self, remote_address_): return _ThreadingForwardServer if self._threaded else _ForwardServer + def _make_unix_ssh_forward_server_class(self, remote_address_): + return _ThreadingUnixStreamForwardServer if \ + self._threaded else _UnixStreamForwardServer + def _make_ssh_forward_server(self, remote_address, local_bind_address): """ Make SSH forward proxy Server class. """ _Handler = self._make_ssh_forward_handler_class(remote_address) - _Server = self._make_ssh_forward_server_class(remote_address) + try: + if isinstance(local_bind_address, string_types): + forward_maker_class = self._make_unix_ssh_forward_server_class + else: + forward_maker_class = self._make_ssh_forward_server_class + + _Server = forward_maker_class(remote_address) ssh_forward_server = _Server(local_bind_address, _Handler) + + # if _Handler.unix_socket: + # _Handler.unix_socket.connect(local_bind_address) + if ssh_forward_server: self._server_list.append(ssh_forward_server) else: @@ -992,6 +1053,11 @@ def start(self): self.check_local_side_of_tunnels() self._is_started = True + def restart(self): + """ Restart connection to the gateway and tunnels """ + self.stop() + self.start() + def _connect_to_gateway(self): """ Open connection to SSH gateway @@ -1101,6 +1167,13 @@ def stop(self): ) _srv.shutdown() _srv.server_close() + # clean up the UNIX domain socket if we're using one + if isinstance(_srv, _UnixStreamForwardServer): + try: + os.unlink(_srv.local_address) + except Exception as e: + self.logger.error('Unable to unlink socket {0}: {1}' + .format(self.local_address, repr(e))) self._stop_transport() self._is_started = False @@ -1144,7 +1217,8 @@ def local_bind_ports(self): Return a list containing the ports of local side of the TCP tunnels """ self._check_is_started() - return [_server.local_port for _server in self._server_list] + return [_server.local_port for _server in self._server_list if + _server.local_port is not None] @property def local_bind_hosts(self): @@ -1152,7 +1226,8 @@ def local_bind_hosts(self): Return a list containing the IP addresses listening for the tunnels """ self._check_is_started() - return [_server.local_host for _server in self._server_list] + return [_server.local_host for _server in self._server_list if + _server.local_host is not None] @property def local_bind_addresses(self): @@ -1165,7 +1240,7 @@ def _get_local_interfaces(): Return all local network interface's IP addresses """ local_if = socket.gethostbyname_ex(socket.gethostname())[-1] - # In Linux, if /etc/hosts is populated with the hostname it will only + # In Linux, if /etc/hosts is populated with the hostname, it will only # return 127.0.0.1 if '127.0.0.1' not in local_if: local_if.append('127.0.0.1') @@ -1370,6 +1445,8 @@ def _parse_arguments(args=None): dest='local_bind_addresses', metavar='IP:PORT', help='Local bind address sequence: ' 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n' + 'Elements may also be valid UNIX socket domains: \n' + '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n' 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, ' 'being the local IP address optional.\n' 'By default it will listen in all interfaces ' diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index 17ed314b..fdbfd0a4 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -951,10 +951,11 @@ def test_local_is_up(self): self.assertIn('An error occurred while opening tunnels.', self.sshtunnel_log_messages['error']) - self.assertFalse(server.local_is_up("not a tuple")) - self.assertIn('Target must be a tuple (ip, port), where ip ' + self.assertFalse(server.local_is_up("not a valid address")) + self.assertIn('Target must be a tuple (IP, port), where IP ' 'is a string (i.e. "192.168.0.1") and port is ' - 'an integer (i.e. 40000).', + 'an integer (i.e. 40000). Alternatively ' + 'target can be a valid UNIX domain socket.', self.sshtunnel_log_messages['warning']) @mock.patch('sshtunnel.input_', return_value=linesep) diff --git a/tox.ini b/tox.ini index 6153f1b6..544e2eed 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,7 @@ deps = {py27,py33,py34}: readme {py26}: unittest2 flake8 + mccabe mock pytest pytest-cov @@ -49,3 +50,4 @@ commands= [flake8] exclude = .tox,*.egg,build,data,docs select = E,W,F +max-complexity = 10