diff --git a/circus/circusctl.py b/circus/circusctl.py index f0f18f345..6674930fe 100644 --- a/circus/circusctl.py +++ b/circus/circusctl.py @@ -1,11 +1,13 @@ # -*- coding: utf-8 - import argparse +import cmd +import collections import getopt import json +import os import sys -import traceback import textwrap -import os +import traceback # import pygments if here try: @@ -17,8 +19,8 @@ from circus import __version__ from circus.client import CircusClient +from circus.commands import get_commands from circus.consumer import CircusConsumer -from circus.commands.base import get_commands, KNOWN_COMMANDS from circus.exc import CallError, ArgumentError from circus.util import DEFAULT_ENDPOINT_SUB, DEFAULT_ENDPOINT_DEALER @@ -86,39 +88,125 @@ def _get_switch_str(opt): class ControllerApp(object): - def __init__(self): - self.commands = get_commands() - _Help.commands = self.commands - self.options = { - 'endpoint': {'default': None, 'help': 'connection endpoint'}, - 'timeout': {'default': 5, 'help': 'connection timeout'}, - - 'help': { - 'default': False, - 'action': 'store_true', - 'help': 'Show help and exit'}, - - 'json': {'default': False, 'action': 'store_true', - 'help': 'output to JSON'}, - - 'prettify': { - 'default': False, - 'action': 'store_true', - 'help': 'prettify output'}, - - 'ssh': { - 'default': None, - 'help': 'SSH Server in the format user@host:port'}, - - 'ssh_keyfile': { - 'default': None, - 'help': 'the path to the keyfile to authorise the user'}, - - 'version': { - 'default': False, - 'action': 'store_true', - 'help': 'display version and exit'} - } + def run(self, args): + try: + return self.dispatch(args) + except getopt.GetoptError as e: + print("Error: %s\n" % str(e)) + self.display_help() + return 2 + except CallError as e: + sys.stderr.write("%s\n" % str(e)) + return 1 + except ArgumentError as e: + sys.stderr.write("%s\n" % str(e)) + return 1 + except KeyboardInterrupt: + return 1 + except Exception, e: + sys.stderr.write(traceback.format_exc()) + return 1 + + def dispatch(self, args): + opts = {} + cmd = self.commands[args.command] + if args.help: + print textwrap.dedent(cmd.__doc__) + return 0 + else: + if hasattr(args, 'start'): + opts['start'] = args.start + + if args.endpoint is None: + if cmd.msg_type == 'sub': + args.endpoint = DEFAULT_ENDPOINT_SUB + else: + args.endpoint = DEFAULT_ENDPOINT_DEALER + msg = cmd.message(*args.args, **opts) + handler = getattr(self, "handle_%s" % cmd.msg_type) + return handler(cmd, self.globalopts, msg, args.endpoint, + int(args.timeout), args.ssh, args.ssh_keyfile) + + def handle_sub(self, cmd, opts, topics, endpoint, timeout, ssh_server, + ssh_keyfile): + consumer = CircusConsumer(topics, endpoint=endpoint) + for topic, msg in consumer: + print("%s: %s" % (topic, msg)) + return 0 + + def _console(self, client, cmd, opts, msg): + if opts['json']: + return prettify(client.call(msg), prettify=opts['prettify']) + else: + return cmd.console_msg(client.call(msg)) + + def handle_dealer(self, cmd, opts, msg, endpoint, timeout, ssh_server, + ssh_keyfile): + client = CircusClient(endpoint=endpoint, timeout=timeout, + ssh_server=ssh_server, ssh_keyfile=ssh_keyfile) + try: + if isinstance(msg, list): + for i, command in enumerate(msg): + clm = self._console(client, command['cmd'], opts, + command['msg']) + print("%s: %s" % (i, clm)) + else: + print(self._console(client, cmd, opts, msg)) + except CallError as e: + sys.stderr.write(str(e) + " Try to raise the --timeout value\n") + return 1 + finally: + client.stop() + return 0 + +class CircusCtl(cmd.Cmd, object): + """CircusCtl tool.""" + prompt = '(circusctl) ' + + def __new__(cls, client, commands, *args, **kw): + """Auto add do and complete methods for all known commands.""" + cls.commands = commands + for name, cmd in commands.iteritems(): + cls._add_do_cmd(name, cmd) + cls._add_complete_cmd(name, cmd) + cls.controller = ControllerApp() + cls.controller.commands = commands + cls.client = client + return super(CircusCtl, cls).__new__(cls, *args, **kw) + + def __init__(self, client, *args, **kwargs): + return super(CircusCtl, self).__init__() + + @classmethod + def _add_do_cmd(cls, cmd_name, cmd): + def inner_do_cmd(cls, line): + arguments = parse_arguments([cmd_name] + line.split(), cls.commands) + cls.controller.run(arguments['args']) + inner_do_cmd.__doc__ = textwrap.dedent(cmd.__doc__) + inner_do_cmd.__name__ = "do_%s" % cmd_name + setattr(cls, inner_do_cmd.__name__, inner_do_cmd) + + @classmethod + def _add_complete_cmd(cls, cmd_name, cmd): + def inner_complete_cmd(cls, *args, **kwargs): + if hasattr(cmd, 'autocomplete'): + try: + return cmd.autocomplete(cls.client, *args, **kwargs) + except Exception, e: + import traceback, sys + sys.stderr.write(e.message+"\n") + traceback.print_exc(file=sys.stderr) + else: + return [] + inner_complete_cmd.__doc__ = "Complete the %s command" % cmd_name + inner_complete_cmd.__name__ = "complete_%s" % cmd_name + setattr(cls, inner_complete_cmd.__name__, inner_complete_cmd) + + def do_EOF(self, line): + return True + + def postloop(self): + print def autocomplete(self, autocomplete=False, words=None, cword=None): """ @@ -155,129 +243,111 @@ def autocomplete(self, autocomplete=False, words=None, cword=None): except IndexError: curr = '' - subcommands = [cmd.name for cmd in KNOWN_COMMANDS] + subcommands = get_commands() if cword == 1: # if completing the command name print(' '.join(sorted(filter(lambda x: x.startswith(curr), subcommands)))) sys.exit(1) - def run(self, args): + def display_version(self, *args, **opts): + print(__version__) + return 0 + + def start(self, globalopts): + self.autocomplete() + self.controller.globalopts = globalopts + + args = globalopts['args'] + parser = globalopts['parser'] + + if hasattr(args, 'command'): + sys.exit(self.controller.run(globalopts['args'])) + + print self.prompt[1:-2], + self.display_version() + try: - sys.exit(self.dispatch(args)) - except getopt.GetoptError as e: - print("Error: %s\n" % str(e)) - self.display_help() - sys.exit(2) - except CallError as e: - sys.stderr.write("%s\n" % str(e)) - sys.exit(1) - except ArgumentError as e: - sys.stderr.write("%s\n" % str(e)) - sys.exit(1) + self.cmdloop() except KeyboardInterrupt: - sys.exit(1) - except Exception, e: - sys.stderr.write(traceback.format_exc()) - sys.exit(1) + print + sys.exit(0) + + +def parse_arguments(args, commands): + _Help.commands = commands + + usage = '%(prog)s [options] command [args]' + options = { + 'endpoint': {'default': None, 'help': 'connection endpoint'}, + 'timeout': {'default': 5, 'help': 'connection timeout'}, + + 'help': { + 'default': False, + 'action': 'store_true', + 'help': 'Show help and exit'}, + + 'json': {'default': False, 'action': 'store_true', + 'help': 'output to JSON'}, + + 'prettify': { + 'default': False, + 'action': 'store_true', + 'help': 'prettify output'}, + + 'ssh': { + 'default': None, + 'help': 'SSH Server in the format user@host:port'}, + + 'ssh_keyfile': { + 'default': None, + 'help': 'the path to the keyfile to authorise the user'}, + + 'version': { + 'default': False, + 'action': 'store_true', + 'help': 'display version and exit'} + } - def get_globalopts(self, args): - globalopts = {} - for option in self.options: - if hasattr(args, option): - globalopts[option] = getattr(args, option) - return globalopts + parser = argparse.ArgumentParser( + description="Controls a Circus daemon", + formatter_class=_Help, usage=usage, add_help=False) + + for option in options: + parser.add_argument('--' + option, **options[option]) - def dispatch(self, args): - self.autocomplete() - usage = '%(prog)s [options] command [args]' - parser = argparse.ArgumentParser( - description="Controls a Circus daemon", - formatter_class=_Help, usage=usage, add_help=False) - - for option in self.options: - parser.add_argument('--' + option, **self.options[option]) - - if any([value in self.commands for value in sys.argv]): - subparsers = parser.add_subparsers(dest='command') - for command in self.commands: - subparser = subparsers.add_parser(command) - subparser.add_argument('args', nargs="*", - help=argparse.SUPPRESS) - if command == 'add': - subparser.add_argument('--start', action='store_true', - default=False) - - args = parser.parse_args() - globalopts = self.get_globalopts(args) - opts = {} + if any([value in commands for value in args]): + subparsers = parser.add_subparsers(dest='command') + for command in commands: + subparser = subparsers.add_parser(command) + subparser.add_argument('args', nargs="*", + help=argparse.SUPPRESS) + if command == 'add': + subparser.add_argument('--start', action='store_true', + default=False) - if not hasattr(args, 'command'): - for command in self.commands: - doc = textwrap.dedent(self.commands[command].__doc__) - help = doc.split('\n')[0] - parser.add_argument(command, help=help) - parser.print_help() - return 0 - else: - cmd = self.commands[args.command] - if args.help: - print textwrap.dedent(cmd.__doc__) - return 0 - else: - if hasattr(args, 'start'): - opts['start'] = args.start - - if args.endpoint is None: - if cmd.msg_type == 'sub': - args.endpoint = DEFAULT_ENDPOINT_SUB - else: - args.endpoint = DEFAULT_ENDPOINT_DEALER - msg = cmd.message(*args.args, **opts) - handler = getattr(self, "handle_%s" % cmd.msg_type) - return handler(cmd, globalopts, msg, args.endpoint, - int(args.timeout), args.ssh, args.ssh_keyfile) + args = parser.parse_args(args) - def display_version(self, *args, **opts): - print(__version__) - return 0 + globalopts = {'args': args, 'parser': parser} + for option in options: + globalopts[option] = getattr(args, option) + return globalopts - def handle_sub(self, cmd, opts, topics, endpoint, timeout, ssh_server, - ssh_keyfile): - consumer = CircusConsumer(topics, endpoint=endpoint) - for topic, msg in consumer: - print("%s: %s" % (topic, msg)) - return 0 - def _console(self, client, cmd, opts, msg): - if opts['json']: - return prettify(client.call(msg), prettify=opts['prettify']) - else: - return cmd.console_msg(client.call(msg)) +def main(): + # TODO, we should ask the server for its command list + commands = get_commands() - def handle_dealer(self, cmd, opts, msg, endpoint, timeout, ssh_server, - ssh_keyfile): - client = CircusClient(endpoint=endpoint, timeout=timeout, - ssh_server=ssh_server, ssh_keyfile=ssh_keyfile) - try: - if isinstance(msg, list): - for i, command in enumerate(msg): - clm = self._console(client, command['cmd'], opts, - command['msg']) - print("%s: %s" % (i, clm)) - else: - print(self._console(client, cmd, opts, msg)) - except CallError as e: - sys.stderr.write(str(e) + " Try to raise the --timeout value\n") - return 1 - finally: - client.stop() - return 0 + globalopts = parse_arguments(sys.argv[1:], commands) + if globalopts['endpoint'] is None: + globalopts['endpoint'] = DEFAULT_ENDPOINT_DEALER + client = CircusClient(endpoint=globalopts['endpoint'], + timeout=globalopts['timeout'], + ssh_server=globalopts['ssh'], + ssh_keyfile=globalopts['ssh_keyfile']) -def main(): - controller = ControllerApp() - controller.run(sys.argv[1:]) + CircusCtl(client, commands).start(globalopts) if __name__ == '__main__': main()