Permalink
Browse files

upgrading to the last version of sauce connect

  • Loading branch information...
1 parent 780e82b commit 5ad1256460430e582cfc90f56eb3d44d682075e5 @santiycr santiycr committed Jan 21, 2011
Showing with 48 additions and 37 deletions.
  1. +48 −37 support/sauce_connect
View
@@ -27,6 +27,7 @@ import time
import platform
import tempfile
import string
+from base64 import b64encode
from collections import defaultdict
from contextlib import closing
from functools import wraps
@@ -37,7 +38,7 @@ except ImportError:
import simplejson as json # Python 2.5 dependency
NAME = "sauce_connect"
-RELEASE = 21
+RELEASE = 25
DISPLAY_VERSION = "%s release %s" % (NAME, RELEASE)
PRODUCT_NAME = u"Sauce Connect"
VERSIONS_URL = "http://saucelabs.com/versions.json"
@@ -59,6 +60,12 @@ is_openbsd = platform.system().lower() == "openbsd"
logger = logging.getLogger(NAME)
+class DeleteRequest(urllib2.Request):
+
+ def get_method(self):
+ return "DELETE"
+
+
class HTTPResponseError(Exception):
def __init__(self, msg):
@@ -84,18 +91,19 @@ class TunnelMachine(object):
_host_search = re.compile("//([^/]+)").search
- def __init__(self, rest_url, user, password, domains, metadata=None):
+ def __init__(self, rest_url, user, password, domains, ssh_port, metadata=None):
self.user = user
self.password = password
self.domains = set(domains)
+ self.ssh_port = ssh_port
self.metadata = metadata or dict()
self.reverse_ssh = None
self.is_shutdown = False
self.base_url = "%(rest_url)s/%(user)s/tunnels" % locals()
self.rest_host = self._host_search(rest_url).group(1)
- self.basic_auth_header = {"Authorization": "Basic %s" %
- ("%s:%s" % (user, password)).encode("base64").strip()}
+ self.basic_auth_header = {"Authorization": "Basic %s"
+ % b64encode("%s:%s" % (user, password))}
self._set_urlopen(user, password)
@@ -111,9 +119,7 @@ class TunnelMachine(object):
"help@saucelabs.com.")
def _set_urlopen(self, user, password):
- # always send Basic Auth header for GET and POST
- # NOTE: we directly construct the header because it is more reliable
- # and more efficient than HTTPBasicAuthHandler and we always need it
+ # always send Basic Auth header (HTTPBasicAuthHandler was unreliable)
opener = urllib2.build_opener()
opener.addheaders = self.basic_auth_header.items()
self.urlopen = opener.open
@@ -155,21 +161,6 @@ class TunnelMachine(object):
raise HTTPResponseError(resp.msg)
return json.loads(resp.read())
- @_retry_rest_api
- def _get_delete_doc(self, url):
- # urllib2 doesn support the DELETE method (lame), so we build our own
- if self.base_url.startswith("https"):
- make_conn = httplib.HTTPSConnection
- else:
- make_conn = httplib.HTTPConnection
- with closing(make_conn(self.rest_host)) as conn:
- conn.request(method="DELETE", url=url,
- headers=self.basic_auth_header)
- resp = conn.getresponse()
- if resp.reason != "OK":
- raise HTTPResponseError(resp.reason)
- return json.loads(resp.read())
-
def _provision_tunnel(self):
# Shutdown any tunnel using a requested domain
kill_list = set()
@@ -186,7 +177,7 @@ class TunnelMachine(object):
logger.debug(
"Shutting down old tunnel host: %s" % tunnel_id)
url = "%s/%s" % (self.base_url, tunnel_id)
- doc = self._get_delete_doc(url)
+ doc = self._get_doc(DeleteRequest(url=url))
if not doc.get('ok'):
logger.warning("Old tunnel host failed to shutdown?")
continue
@@ -201,7 +192,8 @@ class TunnelMachine(object):
# Request a tunnel machine
headers = {"Content-Type": "application/json"}
data = json.dumps(dict(DomainNames=list(self.domains),
- Metadata=self.metadata))
+ Metadata=self.metadata,
+ SSHPort=self.ssh_port))
req = urllib2.Request(url=self.base_url, headers=headers, data=data)
doc = self._get_doc(req)
if doc.get('error'):
@@ -243,9 +235,10 @@ class TunnelMachine(object):
logger.debug("Tunnel host ID: %s" % self.id)
try:
- doc = self._get_delete_doc(self.url)
- except TunnelMachineError:
+ doc = self._get_doc(DeleteRequest(url=self.url))
+ except TunnelMachineError, e:
logger.warning("Unable to shut down tunnel host")
+ logger.debug("Shut down failed because: %s", str(e))
self.is_shutdown = True # fuhgeddaboudit
return
assert doc.get('ok')
@@ -350,13 +343,14 @@ class ReverseSSHError(Exception):
class ReverseSSH(object):
- def __init__(self, tunnel, host, ports, tunnel_ports,
+ def __init__(self, tunnel, host, ports, tunnel_ports, ssh_port,
use_ssh_config=False, debug=False):
self.tunnel = tunnel
self.host = host
self.ports = ports
self.tunnel_ports = tunnel_ports
self.use_ssh_config = use_ssh_config
+ self.ssh_port = ssh_port
self.debug = debug
self.proc = None
@@ -388,8 +382,8 @@ class ReverseSSH(object):
def get_plink_command(self):
"""Return the Windows SSH command."""
verbosity = "-v" if self.debug else ""
- return ("plink\plink %s -l %s -pw %s -N %s %s"
- % (verbosity, self.tunnel.user, self.tunnel.password,
+ return ("plink\plink %s -P %s -l %s -pw %s -N %s %s"
+ % (verbosity, self.ssh_port, self.tunnel.user, self.tunnel.password,
self._dash_Rs, self.tunnel.host))
def get_expect_script(self):
@@ -402,8 +396,8 @@ class ReverseSSH(object):
config_file = "" if self.use_ssh_config else "-F /dev/null"
host_ip = socket.gethostbyname(self.tunnel.host)
script = (
- "spawn ssh %s %s -p 22 -l %s -o ServerAliveInterval=%s -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -N %s %s;"
- % (verbosity, config_file, self.tunnel.user,
+ "spawn ssh %s %s -p %s -l %s -o ServerAliveInterval=%s -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -N %s %s;"
+ % (verbosity, config_file, self.ssh_port, self.tunnel.user,
HEALTH_CHECK_INTERVAL, self._dash_Rs, self.tunnel.host) +
"expect *password:;send -- %s\\r;" % self.tunnel.password +
"expect -timeout -1 timeout")
@@ -433,9 +427,9 @@ class ReverseSSH(object):
# setup recurring healthchecks
forwarded_health = HealthChecker(self.host, self.ports)
- tunnel_health = HealthChecker(host=self.tunnel.host, ports=[22],
+ tunnel_health = HealthChecker(host=self.tunnel.host, ports=[self.ssh_port],
fail_msg="!! Your tests may fail because your network can not get "
- "to the tunnel host (%s:%d)." % (self.tunnel.host, 22))
+ "to the tunnel host (%s:%d)." % (self.tunnel.host, self.ssh_port))
start_time = int(time.time())
while self.proc.poll() is None:
@@ -719,6 +713,8 @@ Performance tip:
help=optparse.SUPPRESS_HELP)
og.add_option("--allow-unclean-exit", action="store_true", default=False,
help=optparse.SUPPRESS_HELP)
+ og.add_option("--ssh-port", default=22, type="int",
+ help=optparse.SUPPRESS_HELP)
op.add_option_group(og)
og = optparse.OptionGroup(op, "Script debugging options")
@@ -730,6 +726,20 @@ Performance tip:
(options, args) = op.parse_args()
+ # check ports are numbers
+ try:
+ map(int, options.ports)
+ map(int, options.tunnel_ports)
+ except ValueError:
+ sys.stderr.write("Error: Ports must be integers\n\n")
+ print "Help with options -t and -p:"
+ print " All ports must be integers. You used:"
+ if options.ports:
+ print " -p", " -p ".join(options.ports)
+ if options.tunnel_ports:
+ print " -t", " -t ".join(options.tunnel_ports)
+ raise SystemExit(1)
+
# default to 80 and default to matching host ports with tunnel ports
if not options.ports and not options.tunnel_ports:
options.ports = ["80"]
@@ -830,6 +840,7 @@ def run(options, dependency_versions=None):
print "| Contact us: http://saucelabs.com/forums |"
print "-----------------------------------------------------"
logger.info("/ Starting \\")
+ logger.info('Please wait for "You may start your tests" to start your tests.')
logger.info("%s" % DISPLAY_VERSION)
check_version()
@@ -861,7 +872,8 @@ def run(options, dependency_versions=None):
for attempt in xrange(1, RETRY_BOOT_MAX + 1):
try:
tunnel = TunnelMachine(options.rest_url, options.user,
- options.api_key, options.domains, metadata)
+ options.api_key, options.domains,
+ options.ssh_port, metadata)
except TunnelMachineError, e:
logger.error(e)
peace_out(returncode=1) # exits
@@ -881,6 +893,7 @@ def run(options, dependency_versions=None):
ssh = ReverseSSH(tunnel=tunnel, host=options.host,
ports=options.ports, tunnel_ports=options.tunnel_ports,
+ ssh_port=options.ssh_port,
use_ssh_config=options.use_ssh_config,
debug=options.debug_ssh)
try:
@@ -891,9 +904,7 @@ def run(options, dependency_versions=None):
def main():
- # more complicated so this works on old Python
- pyver = float("%s.%s" % tuple(platform.python_version().split('.')[:2]))
- if pyver < 2.5:
+ if map(int, platform.python_version_tuple ()) < [2, 5]:
print "%s requires Python 2.5 (2006) or newer." % PRODUCT_NAME
raise SystemExit(1)

0 comments on commit 5ad1256

Please sign in to comment.