Skip to content

Commit

Permalink
Fix bug introduced in handler plus minor fixes in doc
Browse files Browse the repository at this point in the history
  • Loading branch information
fernandezcuesta committed Aug 9, 2016
1 parent dab91a8 commit 87f5adc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 34 deletions.
71 changes: 38 additions & 33 deletions sshtunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
"""

import os
import sys
import socket
import getpass
import logging
import os
import argparse
import warnings
import threading
from binascii import hexlify
from select import select
from binascii import hexlify

import paramiko

Expand Down Expand Up @@ -98,7 +98,6 @@ def check_address(address):
raised when address has an incorrect format
Example:
>>> check_address(('127.0.0.1', 22))
"""
if isinstance(address, tuple):
Expand Down Expand Up @@ -184,7 +183,7 @@ def create_logger(logger=None,
Default: True
Return:
logging.Logger
:class:`logging.Logger`
"""
logger = logger or logging.getLogger(
'{0}.SSHTunnelForwarder'.format(__name__)
Expand Down Expand Up @@ -374,10 +373,8 @@ def __init__(self, *args, **kwargs):
self.logger = create_logger(kwargs.pop('logger', None))
self.tunnel_ok = queue.Queue()
socketserver.TCPServer.__init__(self, *args, **kwargs)
# super(_ForwardServer, self).__init__(*args, **kwargs)

def serve_forever(self, poll_interval=0.5):
# super(_ForwardServer, self).serve_forever(poll_interval)
socketserver.TCPServer.serve_forever(self, poll_interval)

def handle_error(self, request, client_address):
Expand Down Expand Up @@ -559,10 +556,10 @@ class SSHTunnelForwarder(object):
local_bind_address (tuple):
Local tuple in the format (``str``, ``int``) representing the
IP and port of the local side of the tunnel. Both elements in
the tuple are optional so both ('', 8000) and ('10.0.0.1', )
are valid values
the tuple are optional so both ``('', 8000)`` and
``('10.0.0.1', )`` are valid values
Default: ('0.0.0.0', ``RANDOM PORT``)
Default: ``('0.0.0.0', RANDOM_PORT)``
.. versionchanged:: 0.0.8
Added the ability to use a UNIX domain socket as local bind
Expand All @@ -573,7 +570,7 @@ class SSHTunnelForwarder(object):
of tuples (in the same format as ``local_bind_address``)
can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
Default: [``local_bind_address``]
Default: ``[local_bind_address]``
.. versionadded:: 0.0.4
Expand All @@ -586,7 +583,7 @@ class SSHTunnelForwarder(object):
of tuples (in the same format as ``remote_bind_address``)
can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
Default: [``remote_bind_address``]
Default: ``[remote_bind_address]``
.. versionadded:: 0.0.4
Expand All @@ -598,8 +595,8 @@ class SSHTunnelForwarder(object):
.. versionadded:: 0.0.8
compression (boolean):
Turn on/off compression. By default compression is off since it
negatively affects interactive sessions
Turn on/off transport compression. By default compression is
disabled since it may negatively affect interactive sessions
Default: ``False``
Expand All @@ -609,7 +606,8 @@ class SSHTunnelForwarder(object):
logging instance for sshtunnel and paramiko
Default: :class:`logging.Logger` instance with a single
`StreamHandler` handler and :const:`DEFAULT_LOGLEVEL` level
:class:`logging.StreamHandler` handler and
:const:`DEFAULT_LOGLEVEL` level
.. versionadded:: 0.0.3
Expand Down Expand Up @@ -679,8 +677,8 @@ class SSHTunnelForwarder(object):
This attribute should not be modified
.. note::
When ``skip_tunnel_checkup`` is disabled or the local bind is a
UNIX socket, the value will always be ``True``
When :attr:`.skip_tunnel_checkup` is disabled or the local bind
is a UNIX socket, the value will always be ``True``
**Example**::
Expand Down Expand Up @@ -712,7 +710,7 @@ def local_is_up(self, target):
boolean
.. deprecated:: 0.1.0
Replaced by ``check_tunnels()`` and ``.tunnel_is_up``
Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up`
"""
try:
check_address(target)
Expand All @@ -733,10 +731,10 @@ def _make_ssh_forward_handler_class(self, remote_address_):
"""
Make SSH Handler class
"""
Handler = _ForwardHandler
Handler.remote_address = remote_address_
Handler.ssh_transport = self._transport
Handler.logger = self.logger
class Handler(_ForwardHandler):
remote_address = remote_address_
ssh_transport = self._transport
logger = self.logger
return Handler

def _make_ssh_forward_server_class(self, remote_address_):
Expand Down Expand Up @@ -1159,7 +1157,7 @@ def _check_tunnel(self, _srv):
if self.skip_tunnel_checkup:
self.tunnel_is_up[_srv.local_address] = True
return
self.logger.info('Checking tunnel: {0}'.format(_srv.local_address))
self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address))
if isinstance(_srv.local_address, string_types): # UNIX stream
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
else:
Expand Down Expand Up @@ -1190,7 +1188,7 @@ def _check_tunnel(self, _srv):
def check_tunnels(self):
"""
Check that if all tunnels are established and populates
``self.tunnel_is_up``
:attr:`.tunnel_is_up`
"""
for _srv in self._server_list:
self._check_tunnel(_srv)
Expand All @@ -1213,12 +1211,16 @@ def start(self):
thread.daemon = self.daemon_forward_servers
thread.start()
self._check_tunnel(_srv)
self.is_alive = any(self.tunnel_is_up.values())
if not self.is_alive:
self._raise(HandlerSSHTunnelForwarderError,
'An error occurred while opening tunnels.')
self.is_alive = any(self.tunnel_is_up.values())
if not self.is_alive:
self._raise(HandlerSSHTunnelForwarderError,
'An error occurred while opening tunnels.')

def _status_ok(self):
"""
Return whether or not everything (underlying transport + tunnels) are
already set up
"""
try:
self._check_is_started()
except HandlerSSHTunnelForwarderError as e: # tunnels down
Expand All @@ -1232,7 +1234,9 @@ def _status_ok(self):

def stop(self):
"""
Shut the tunnel down. This has to be handled with care:
Shut the tunnel down.
.. note:: This **had** to be handled with care before ``0.1.0``:
- if a port redirection is opened
- the destination is not reachable
Expand Down Expand Up @@ -1324,7 +1328,7 @@ def _stop_transport(self):
tunnel = _srv.local_address
if self.tunnel_is_up[tunnel]:
self.logger.info('Shutting down tunnel {0}'.format(tunnel))
_srv.shutdown()
_srv.shutdown()
_srv.server_close()
# clean up the UNIX domain socket if we're using one
if isinstance(_srv, _UnixStreamForwardServer):
Expand Down Expand Up @@ -1397,10 +1401,11 @@ def local_bind_addresses(self):
@property
def tunnel_bindings(self):
"""
Return a dictionary containing the local<>remote tunnel_bindings
Return a dictionary containing the active local<>remote tunnel_bindings
"""
return dict((_server.remote_address, _server.local_address) for
_server in self._server_list)
_server in self._server_list if
self.tunnel_is_up[_server.local_address])

@property
def is_active(self):
Expand Down Expand Up @@ -1513,7 +1518,7 @@ def open_tunnel(*args, **kwargs):
Enable/disable the local side check and populate
:attr:`~SSHTunnelForwarder.tunnel_is_up`
Default: False
Default: True
.. versionadded:: 0.1.0
Expand Down Expand Up @@ -1560,7 +1565,7 @@ def do_something(port):
)

ssh_port = kwargs.pop('ssh_port', None)
skip_tunnel_checkup = kwargs.pop('skip_tunnel_checkup', False)
skip_tunnel_checkup = kwargs.pop('skip_tunnel_checkup', True)
if not args:
args = ((ssh_address, ssh_port), )
forwarder = SSHTunnelForwarder(*args, **kwargs)
Expand Down
1 change: 0 additions & 1 deletion tests/test_forwarder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import paramiko
import sshtunnel


if sys.version_info[0] == 2:
from cStringIO import StringIO
if sys.version_info < (2, 7):
Expand Down

0 comments on commit 87f5adc

Please sign in to comment.