Permalink
Browse files

auto reconnect if connections break in SSHClient

moved sftp and transport attributes to properties that will test whether
the current connections are still active or not and if not will
attempt to reconnect. if the reconnect is unsuccessful an exception is
thrown. this is useful when you restart a cluster in the development
shell.

add method to retrieve the remote server's public key
  • Loading branch information...
1 parent 19623d8 commit c97ce7627cfa7b7b7705a529b38a9a201354d32c @jtriley jtriley committed Mar 24, 2011
Showing with 66 additions and 57 deletions.
  1. +66 −57 starcluster/ssh.py
View
@@ -43,23 +43,14 @@ def __init__(self,
private_key_pass=None,
port=22,
timeout=30):
+ self._host = host
+ self._port = 22
+ self._pkey = None
+ self._username = username or os.environ['LOGNAME']
+ self._password = password
self._timeout = timeout
- self._sftp_live = False
self._sftp = None
- self._pkey = None
- if not username:
- username = os.environ['LOGNAME']
- # Begin the SSH transport.
- self._transport_live = False
- try:
- sock = self._get_socket(host, port)
- self._transport = paramiko.Transport(sock)
- self._transport.banner_timeout = self._timeout
- except socket.error:
- raise exception.SSHConnectionError(host, port)
- self._transport_live = True
- # Authenticate the transport.
- pkey = None
+ self._transport = None
if private_key:
# Use Private Key.
log.debug('private key specified')
@@ -76,20 +67,45 @@ def __init__(self,
self._pkey = pkey
elif not password:
raise exception.SSHNoCredentialsError()
- try:
- self._transport.connect(username=username, pkey=pkey,
- password=password)
- except paramiko.AuthenticationException:
- raise exception.SSHAuthException(username, host)
- except paramiko.SSHException, e:
- msg = e.args[0]
- raise exception.SSHError(msg)
- except socket.error:
- raise exception.SSHConnectionError(host, port)
- except EOFError:
- raise exception.SSHConnectionError(host, port)
- except Exception, e:
- raise exception.SSHError(str(e))
+ assert self.transport is not None
+
+ @property
+ def transport(self):
+ """
+ This property attempts to return an active SSH transport
+ """
+ if not self._transport or not self._transport.is_active():
+ try:
+ sock = self._get_socket(self._host, self._port)
+ self._transport = paramiko.Transport(sock)
+ self._transport.banner_timeout = self._timeout
+ except socket.error:
+ raise exception.SSHConnectionError(self._host, self._port)
+ # Authenticate the transport.
+ try:
+ self._transport.connect(username=self._username,
+ pkey=self._pkey,
+ password=self._password)
+ except paramiko.AuthenticationException:
+ raise exception.SSHAuthException(self._username, self._host)
+ except paramiko.SSHException, e:
+ msg = e.args[0]
+ raise exception.SSHError(msg)
+ except socket.error:
+ raise exception.SSHConnectionError(self._host, self._port)
+ except EOFError:
+ raise exception.SSHConnectionError(self._host, self._port)
+ except Exception, e:
+ raise exception.SSHError(str(e))
+ return self._transport
+
+ def get_server_public_key(self):
+ return self.transport.get_remote_server_key()
+
+ def is_active(self):
+ if self._transport:
+ return self._transport.is_active()
+ return False
def _get_socket(self, hostname, port):
for (family, socktype, proto, canonname, sockaddr) in \
@@ -126,11 +142,13 @@ def _load_dsa_key(self, private_key, private_key_pass=None):
except paramiko.SSHException:
log.error('invalid dsa key or password specified')
- def _sftp_connect(self):
+ @property
+ def sftp(self):
"""Establish the SFTP connection."""
- if not self._sftp_live:
- self._sftp = paramiko.SFTPClient.from_transport(self._transport)
- self._sftp_live = True
+ if not self._sftp or self._sftp.sock.closed:
+ log.debug("creating sftp connection")
+ self._sftp = paramiko.SFTPClient.from_transport(self.transport)
+ return self._sftp
def generate_rsa_key(self):
return paramiko.RSAKey.generate(2048)
@@ -178,9 +196,8 @@ def mkdir(self, path, mode=0755, ignore_failure=False):
mode specifies unix permissions to apply to the new dir
"""
- self._sftp_connect()
try:
- return self._sftp.mkdir(path, mode)
+ return self.sftp.mkdir(path, mode)
except IOError:
if not ignore_failure:
raise
@@ -224,12 +241,14 @@ def remove_lines_from_file(self, remote_file, regex):
f.writelines(lines)
f.close()
+ def unlink(self, remote_file):
+ return self.sftp.unlink(remote_file)
+
def remote_file(self, file, mode='w'):
"""
Returns a remote file descriptor
"""
- self._sftp_connect()
- rfile = self._sftp.open(file, mode)
+ rfile = self.sftp.open(file, mode)
rfile.name = file
return rfile
@@ -238,7 +257,6 @@ def path_exists(self, path):
Test whether a remote path exists.
Returns False for broken symbolic links
"""
- self._sftp_connect()
try:
self.stat(path)
return True
@@ -265,14 +283,12 @@ def ls(self, path):
"""
Return a list containing the names of the entries in the remote path.
"""
- self._sftp_connect()
- return [os.path.join(path, f) for f in self._sftp.listdir(path)]
+ return [os.path.join(path, f) for f in self.sftp.listdir(path)]
def isdir(self, path):
"""
Return true if the remote path refers to an existing directory.
"""
- self._sftp_connect()
try:
s = self.stat(path)
except IOError:
@@ -283,7 +299,6 @@ def isfile(self, path):
"""
Return true if the remote path refers to an existing file.
"""
- self._sftp_connect()
try:
s = self.stat(path)
except IOError:
@@ -294,26 +309,24 @@ def stat(self, path):
"""
Perform a stat system call on the given remote path.
"""
- self._sftp_connect()
- return self._sftp.stat(path)
+ return self.sftp.stat(path)
def get(self, remotepath, localpath=None):
"""
Copies a file between the remote host and the local host.
"""
if not localpath:
localpath = os.path.split(remotepath)[1]
- self._sftp_connect()
- self._sftp.get(remotepath, localpath)
+ self.sftp_connect()
+ self.sftp.get(remotepath, localpath)
def put(self, localpath, remotepath=None):
"""
Copies a file between the local host and the remote host.
"""
if not remotepath:
remotepath = os.path.split(localpath)[1]
- self._sftp_connect()
- self._sftp.put(localpath, remotepath)
+ self.sftp.put(localpath, remotepath)
def execute_async(self, command):
"""
@@ -324,7 +337,7 @@ def execute_async(self, command):
code exits, it will not persist on the remote machine
"""
- channel = self._transport.open_session()
+ channel = self.transport.open_session()
channel.exec_command(command)
def execute(self, command, silent=True, only_printable=False,
@@ -340,7 +353,7 @@ def execute(self, command, silent=True, only_printable=False,
characters
returns List of output lines
"""
- channel = self._transport.open_session()
+ channel = self.transport.open_session()
channel.exec_command(command)
#stdin = channel.makefile('wb', -1)
stdout = channel.makefile('rb', -1)
@@ -412,18 +425,14 @@ def get_env(self):
def close(self):
"""Closes the connection and cleans up."""
- # Close SFTP Connection.
- if self._sftp_live:
+ if self._sftp:
self._sftp.close()
- self._sftp_live = False
- # Close the SSH Transport.
- if self._transport_live:
+ if self._transport:
self._transport.close()
- self._transport_live = False
def interactive_shell(self):
try:
- chan = self._transport.open_session()
+ chan = self.transport.open_session()
chan.get_pty()
chan.invoke_shell()
log.info('Starting interactive shell...')

0 comments on commit c97ce76

Please sign in to comment.