Skip to content

Commit

Permalink
Refactored ssh connection handling to only open network connections w…
Browse files Browse the repository at this point in the history
…hen needed and to allow multiple servers in the future.
  • Loading branch information
fschulze committed Apr 20, 2010
1 parent bcf0d16 commit d671f7f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 27 deletions.
5 changes: 5 additions & 0 deletions docs/HISTORY.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ Changelog
0.8 - Unreleased
----------------

* Refactored ssh connection handling to only open network connections when
needed. Any fabric option which doesn't need a connection runs right away
now (like ``-h`` and ``-l``).
[fschulze]

* Fix status output after ``start``.
[fschulze]

Expand Down
51 changes: 24 additions & 27 deletions mr/awsome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@
from mr.awsome import template
import boto.ec2
import datetime
import fabric.main
import fabric.network
import fabric.state
import logging
import argparse
import os
Expand Down Expand Up @@ -245,10 +242,10 @@ def start(self, overrides={}):
return instance

def init_ssh_key(self, user=None):
fabric.state.env.reject_unknown_hosts = True
fabric.state.env.disable_known_hosts = True
#user, host, port = fabric.network.normalize(hoststr)
instance = self.instance
if instance is None:
log.error("Can't establish ssh connection.")
return
if user is None:
user = 'root'
host = str(instance.public_dns_name)
Expand All @@ -267,10 +264,7 @@ def init_ssh_key(self, user=None):
os.remove(known_hosts)
client.get_host_keys().clear()
client.save_host_keys(known_hosts)
# store the connection in the fabric connection cache
real_key = fabric.network.join_host_strings(user, host, port)
fabric.state.connections[real_key] = client
return real_key, known_hosts
return user, host, port, client, known_hosts

def snapshot(self, devs=None):
if devs is None:
Expand Down Expand Up @@ -513,17 +507,24 @@ def cmd_do(self, argv, help):
return
old_sys_argv = sys.argv
old_cwd = os.getcwd()

import fabric_integration
# this needs to be done before any other fabric module import
fabric_integration.patch()

import fabric.state
import fabric.main

hoststr = None
try:
sid = argv[0]
server = self.ec2.servers[sid]
try:
hoststr, known_hosts = server.init_ssh_key()
except paramiko.SSHException, e:
log.error("Couldn't validate fingerprint for ssh connection.")
log.error(e)
log.error("Is the server finished starting up?")
return
fabric.state.connections.set_ec2(self.ec2)
fabric.state.connections.set_log(log)
hoststr = argv[0]
server = self.ec2.servers[hoststr]
# prepare the connection
fabric.state.env.reject_unknown_hosts = True
fabric.state.env.disable_known_hosts = True

fabfile = server.config.get('fabfile')
if fabfile is None:
log.error("No fabfile declared.")
Expand All @@ -539,6 +540,7 @@ def cmd_do(self, argv, help):
os.chdir(os.path.dirname(fabfile))
fabric.state.env.servers = self.ec2.servers
fabric.state.env.server = server
known_hosts = os.path.join(self.ec2.configpath, 'known_hosts')
fabric.state.env.known_hosts = known_hosts

class StdFilter(object):
Expand All @@ -559,7 +561,7 @@ def write(self, msg):

fabric.main.main()
finally:
if hoststr is not None:
if fabric.state.connections.opened(hoststr):
fabric.state.connections[hoststr].close()
sys.argv = old_sys_argv
os.chdir(old_cwd)
Expand Down Expand Up @@ -591,19 +593,14 @@ def cmd_ssh(self, argv, help):
parser.print_help()
return
server = self.ec2.servers[argv[sid_index]]
if server.instance is None:
log.error("Can't establish ssh connection.")
return
try:
hoststr, known_hosts = server.init_ssh_key()
user, host, port, client, known_hosts = server.init_ssh_key()
except paramiko.SSHException, e:
log.error("Couldn't validate fingerprint for ssh connection.")
log.error(e)
log.error("Is the server finished starting up?")
return
fabric.state.connections[hoststr].close()
user, host, port = fabric.network.normalize(hoststr)
known_hosts = os.path.join(self.ec2.configpath, 'known_hosts')
client.close()
argv[sid_index:sid_index+1] = ['-o', 'UserKnownHostsFile=%s' % known_hosts,
'-l', user,
host]
Expand Down
38 changes: 38 additions & 0 deletions mr/awsome/fabric_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import paramiko


class HostConnectionCache(object):
def __init__(self):
self._cache = dict()

def set_ec2(self, ec2):
self._ec2 = ec2

def set_log(self, log):
self._log = log

def keys(self):
return self._cache.keys()

def opened(self, key):
if key in self._cache:
return True

def __getitem__(self, key):
if key not in self._cache and key in self._ec2.servers:
server = self._ec2.servers[key]
try:
user, host, port, client, known_hosts = server.init_ssh_key()
except paramiko.SSHException, e:
self._log.error("Couldn't validate fingerprint for ssh connection.")
self._log.error(e)
self._log.error("Is the server finished starting up?")
return
self._cache[key] = client
return client
return self._cache[key]


def patch():
import fabric.network
fabric.network.HostConnectionCache = HostConnectionCache

0 comments on commit d671f7f

Please sign in to comment.