Skip to content

Commit

Permalink
auto reconnect if connections break in SSHClient
Browse files Browse the repository at this point in the history
moved sftp and transport attributes to properties that will test whether
the current connections are still active or not and if not will
attempt to reconnect. if the reconnect is unsuccessful an exception is
thrown. this is useful when you restart a cluster in the development
shell.

add method to retrieve the remote server's public key
  • Loading branch information
jtriley committed Mar 24, 2011
1 parent 19623d8 commit c97ce76
Showing 1 changed file with 66 additions and 57 deletions.
123 changes: 66 additions & 57 deletions starcluster/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,23 +43,14 @@ def __init__(self,
private_key_pass=None,
port=22,
timeout=30):
self._host = host
self._port = 22
self._pkey = None
self._username = username or os.environ['LOGNAME']
self._password = password
self._timeout = timeout
self._sftp_live = False
self._sftp = None
self._pkey = None
if not username:
username = os.environ['LOGNAME']
# Begin the SSH transport.
self._transport_live = False
try:
sock = self._get_socket(host, port)
self._transport = paramiko.Transport(sock)
self._transport.banner_timeout = self._timeout
except socket.error:
raise exception.SSHConnectionError(host, port)
self._transport_live = True
# Authenticate the transport.
pkey = None
self._transport = None
if private_key:
# Use Private Key.
log.debug('private key specified')
Expand All @@ -76,20 +67,45 @@ def __init__(self,
self._pkey = pkey
elif not password:
raise exception.SSHNoCredentialsError()
try:
self._transport.connect(username=username, pkey=pkey,
password=password)
except paramiko.AuthenticationException:
raise exception.SSHAuthException(username, host)
except paramiko.SSHException, e:
msg = e.args[0]
raise exception.SSHError(msg)
except socket.error:
raise exception.SSHConnectionError(host, port)
except EOFError:
raise exception.SSHConnectionError(host, port)
except Exception, e:
raise exception.SSHError(str(e))
assert self.transport is not None

@property
def transport(self):
"""
This property attempts to return an active SSH transport
"""
if not self._transport or not self._transport.is_active():
try:
sock = self._get_socket(self._host, self._port)
self._transport = paramiko.Transport(sock)
self._transport.banner_timeout = self._timeout
except socket.error:
raise exception.SSHConnectionError(self._host, self._port)
# Authenticate the transport.
try:
self._transport.connect(username=self._username,
pkey=self._pkey,
password=self._password)
except paramiko.AuthenticationException:
raise exception.SSHAuthException(self._username, self._host)
except paramiko.SSHException, e:
msg = e.args[0]
raise exception.SSHError(msg)
except socket.error:
raise exception.SSHConnectionError(self._host, self._port)
except EOFError:
raise exception.SSHConnectionError(self._host, self._port)
except Exception, e:
raise exception.SSHError(str(e))
return self._transport

def get_server_public_key(self):
return self.transport.get_remote_server_key()

def is_active(self):
if self._transport:
return self._transport.is_active()
return False

def _get_socket(self, hostname, port):
for (family, socktype, proto, canonname, sockaddr) in \
Expand Down Expand Up @@ -126,11 +142,13 @@ def _load_dsa_key(self, private_key, private_key_pass=None):
except paramiko.SSHException:
log.error('invalid dsa key or password specified')

def _sftp_connect(self):
@property
def sftp(self):
"""Establish the SFTP connection."""
if not self._sftp_live:
self._sftp = paramiko.SFTPClient.from_transport(self._transport)
self._sftp_live = True
if not self._sftp or self._sftp.sock.closed:
log.debug("creating sftp connection")
self._sftp = paramiko.SFTPClient.from_transport(self.transport)
return self._sftp

def generate_rsa_key(self):
return paramiko.RSAKey.generate(2048)
Expand Down Expand Up @@ -178,9 +196,8 @@ def mkdir(self, path, mode=0755, ignore_failure=False):
mode specifies unix permissions to apply to the new dir
"""
self._sftp_connect()
try:
return self._sftp.mkdir(path, mode)
return self.sftp.mkdir(path, mode)
except IOError:
if not ignore_failure:
raise
Expand Down Expand Up @@ -224,12 +241,14 @@ def remove_lines_from_file(self, remote_file, regex):
f.writelines(lines)
f.close()

def unlink(self, remote_file):
return self.sftp.unlink(remote_file)

def remote_file(self, file, mode='w'):
"""
Returns a remote file descriptor
"""
self._sftp_connect()
rfile = self._sftp.open(file, mode)
rfile = self.sftp.open(file, mode)
rfile.name = file
return rfile

Expand All @@ -238,7 +257,6 @@ def path_exists(self, path):
Test whether a remote path exists.
Returns False for broken symbolic links
"""
self._sftp_connect()
try:
self.stat(path)
return True
Expand All @@ -265,14 +283,12 @@ def ls(self, path):
"""
Return a list containing the names of the entries in the remote path.
"""
self._sftp_connect()
return [os.path.join(path, f) for f in self._sftp.listdir(path)]
return [os.path.join(path, f) for f in self.sftp.listdir(path)]

def isdir(self, path):
"""
Return true if the remote path refers to an existing directory.
"""
self._sftp_connect()
try:
s = self.stat(path)
except IOError:
Expand All @@ -283,7 +299,6 @@ def isfile(self, path):
"""
Return true if the remote path refers to an existing file.
"""
self._sftp_connect()
try:
s = self.stat(path)
except IOError:
Expand All @@ -294,26 +309,24 @@ def stat(self, path):
"""
Perform a stat system call on the given remote path.
"""
self._sftp_connect()
return self._sftp.stat(path)
return self.sftp.stat(path)

def get(self, remotepath, localpath=None):
"""
Copies a file between the remote host and the local host.
"""
if not localpath:
localpath = os.path.split(remotepath)[1]
self._sftp_connect()
self._sftp.get(remotepath, localpath)
self.sftp_connect()
self.sftp.get(remotepath, localpath)

def put(self, localpath, remotepath=None):
"""
Copies a file between the local host and the remote host.
"""
if not remotepath:
remotepath = os.path.split(localpath)[1]
self._sftp_connect()
self._sftp.put(localpath, remotepath)
self.sftp.put(localpath, remotepath)

def execute_async(self, command):
"""
Expand All @@ -324,7 +337,7 @@ def execute_async(self, command):
code exits, it will not persist on the remote machine
"""

channel = self._transport.open_session()
channel = self.transport.open_session()
channel.exec_command(command)

def execute(self, command, silent=True, only_printable=False,
Expand All @@ -340,7 +353,7 @@ def execute(self, command, silent=True, only_printable=False,
characters
returns List of output lines
"""
channel = self._transport.open_session()
channel = self.transport.open_session()
channel.exec_command(command)
#stdin = channel.makefile('wb', -1)
stdout = channel.makefile('rb', -1)
Expand Down Expand Up @@ -412,18 +425,14 @@ def get_env(self):

def close(self):
"""Closes the connection and cleans up."""
# Close SFTP Connection.
if self._sftp_live:
if self._sftp:
self._sftp.close()
self._sftp_live = False
# Close the SSH Transport.
if self._transport_live:
if self._transport:
self._transport.close()
self._transport_live = False

def interactive_shell(self):
try:
chan = self._transport.open_session()
chan = self.transport.open_session()
chan.get_pty()
chan.invoke_shell()
log.info('Starting interactive shell...')
Expand Down

0 comments on commit c97ce76

Please sign in to comment.