diff --git a/sshtunnel.py b/sshtunnel.py index 877adb3a..dbd4ccb7 100644 --- a/sshtunnel.py +++ b/sshtunnel.py @@ -1236,22 +1236,6 @@ def start(self): self._raise(HandlerSSHTunnelForwarderError, 'An error occurred while opening tunnels.') - def _status_ok(self): - """ - Return whether or not everything (underlying transport + tunnels) are - already set up - """ - try: - self._check_is_started() - except HandlerSSHTunnelForwarderError as e: # tunnels down - self._stop_transport() - self.logger.warning(e) - return False - except BaseSSHTunnelForwarderError as e: # underlying transport down - self.logger.warning(e) - return False - return True - def stop(self): """ Shut the tunnel down. @@ -1310,7 +1294,7 @@ def _connect_to_gateway(self): self._stop_transport() if self.ssh_password: # avoid conflict using both pass and pkey - self.logger.debug('Logging in with password {0}' + self.logger.debug('Trying to log in with password: {0}' .format('*' * len(self.ssh_password))) try: self._transport = self._get_transport() @@ -1342,8 +1326,11 @@ def _serve_forever_wrapper(self, _srv, poll_interval=0.1): def _stop_transport(self): """ Close the underlying transport when nothing more is needed """ - if not self._status_ok(): - return + try: + self._check_is_started() + except (BaseSSHTunnelForwarderError, + HandlerSSHTunnelForwarderError) as e: + self.logger.warning(e) for _srv in self._server_list: tunnel = _srv.local_address if self.tunnel_is_up[tunnel]: @@ -1358,8 +1345,9 @@ def _stop_transport(self): self.logger.error('Unable to unlink socket {0}: {1}' .format(self.local_address, repr(e))) self.is_alive = False - self._transport.close() - self._transport.stop_thread() + if self.is_active: + self._transport.close() + self._transport.stop_thread() self.logger.debug('Transport is closed') @property diff --git a/tests/test_forwarder.py b/tests/test_forwarder.py index a7ea57c0..d0774cae 100644 --- a/tests/test_forwarder.py +++ b/tests/test_forwarder.py @@ -732,7 +732,6 @@ def test_gateway_unreachable_raises_exception(self): ): pass - @unittest.expectedFailure # catchall rule on local DNS may make this fail @unittest.skipIf(sys.version_info < (2, 7), reason="Cannot intercept logging messages in py26") def test_gateway_ip_unresolvable_raises_exception(self):