Skip to content

Commit

Permalink
Fixed a bug in argparse for Python 2.7.9 + added some new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Feb 3, 2015
1 parent c2f8c23 commit db34510
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 52 deletions.
93 changes: 42 additions & 51 deletions luigi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import configuration
import task
import parameter
import re
import argparse
import sys
import os
Expand Down Expand Up @@ -174,42 +173,34 @@ def run(tasks, worker_scheduler_factory=None, override_defaults=None):
return success


class ErrorWrappedArgumentParser(argparse.ArgumentParser):
# Simple unweighted Levenshtein distance
def _editdistance(a, b):
r0 = range(0, len(b) + 1)
r1 = [0] * (len(b) + 1)

''' Wraps ArgumentParser's error message to suggested similar tasks
'''
for i in range(0, len(a)):
r1[0] = i + 1

# Simple unweighted Levenshtein distance
def _editdistance(self, a, b):
r0 = range(0, len(b) + 1)
r1 = [0] * (len(b) + 1)
for j in range(0, len(b)):
    # Compare characters by value. `is` tests object identity, which only
    # appears to work when the interpreter happens to cache
    # single-character strings; `==` is correct for every input.
    c = 0 if a[i] == b[j] else 1
    # Cheapest of: one extra edit on top of either neighbouring cell, or
    # the diagonal cell plus the substitution cost c (0 on a match).
    r1[j + 1] = min(r1[j] + 1, r0[j + 1] + 1, r0[j] + c)

for i in range(0, len(a)):
r1[0] = i + 1
r0 = r1[:]

for j in range(0, len(b)):
c = 0 if a[i] is b[j] else 1
r1[j + 1] = min(r1[j] + 1, r0[j + 1] + 1, r0[j] + c)
return r1[len(b)]

r0 = r1[:]

return r1[len(b)]
def error_task_names(task_name):
weighted_tasks = [(_editdistance(task_name, task), task) for task in Register.get_reg().keys()]
ordered_tasks = sorted(weighted_tasks, key=lambda pair: pair[0])
candidates = [task for (dist, task) in ordered_tasks if dist <= 5 and dist < len(task)]
displaystring = ""
if candidates:
displaystring = "No task %s. Did you mean:\n%s" % (task_name, '\n'.join(candidates))
else:
displaystring = "No task %s." % task_name

def error(self, message):
result = re.match("argument .+: invalid choice: '(\w+)'.+", message)
if result:
arg = result.group(1)
weightedTasks = [(self._editdistance(arg, task), task) for task in Register.get_reg().keys()]
orderedTasks = sorted(weightedTasks, key=lambda pair: pair[0])
candidates = [task for (dist, task) in orderedTasks if dist <= 5 and dist < len(task)]
displaystring = ""
if candidates:
displaystring = "No task %s. Did you mean:\n%s" % (arg, '\n'.join(candidates))
else:
displaystring = "No task %s." % arg
super(ErrorWrappedArgumentParser, self).error(displaystring)
else:
super(ErrorWrappedArgumentParser, self).error(message)
raise SystemExit(displaystring)


def add_task_parameters(parser, task_cls, optparse=False):
Expand Down Expand Up @@ -246,37 +237,37 @@ class ArgParseInterface(Interface):
'''

def parse_task(self, cmdline_args=None, main_task_cls=None):
parser = ErrorWrappedArgumentParser()
parser = argparse.ArgumentParser()

add_global_parameters(parser)

subparsers_by_name = {}

if main_task_cls:
add_task_parameters(parser, main_task_cls)

else:
orderedtasks = '{%s}' % ','.join(sorted(Register.get_reg().keys()))
subparsers = parser.add_subparsers(dest='command', metavar=orderedtasks)

for name, cls in Register.get_reg().iteritems():
subparsers_by_name[name] = subparsers.add_parser(name)
if cls == Register.AMBIGUOUS_CLASS:
continue
add_task_parameters(subparsers_by_name[name], cls)

if main_task_cls:
args = parser.parse_args(args=cmdline_args)
task_cls = main_task_cls
else:
task_names = sorted(Register.get_reg().keys())
orderedtasks = '{%s}' % ','.join(task_names)

args, unknown = parser.parse_known_args(args=cmdline_args)
task_cls = Register.get_task_cls(args.command)
if len(unknown) == 0:
raise SystemExit('No task specified')
task_name = unknown[0]
if task_name not in task_names:
error_task_names(task_name)

# Add global params here as well so that we can support both:
# test.py --global-param xyz Test --n 42
# test.py Test --n 42 --global-param xyz
add_global_parameters(subparsers_by_name[args.command])
args = parser.parse_args(args=cmdline_args)
subparsers = parser.add_subparsers(dest='command')
task_cls = Register.get_task_cls(task_name)

subparser = subparsers.add_parser(task_name)
add_global_parameters(subparser)
add_task_parameters(subparser, task_cls)
subargs = parser.parse_args(args=cmdline_args)

for key, value in vars(subargs).items():
if value:
setattr(args, key, value)

# Notice that this is not side effect free because it might set global params
set_global_parameters(args)
Expand All @@ -299,7 +290,7 @@ class DynamicArgParseInterface(ArgParseInterface):
'''

def parse(self, cmdline_args=None, main_task_cls=None):
parser = ErrorWrappedArgumentParser()
parser = argparse.ArgumentParser()

add_global_parameters(parser)

Expand Down
4 changes: 4 additions & 0 deletions test/cmdline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,9 @@ def test_bin_luigi(self):
subprocess.check_call(cmd, env=env, stderr=subprocess.STDOUT)
self.assertTrue(t.exists())

@mock.patch('argparse.ArgumentParser.print_usage')
def test_no_task(self, print_usage):
    """Running luigi with only global flags and no task name must exit."""
    # ArgumentParser.print_usage is replaced by a mock via the decorator;
    # presumably this keeps usage text out of the test output.
    cmdline = ['--local-scheduler', '--no-lock']
    self.assertRaises(SystemExit, luigi.run, cmdline)

if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion test/dynamic_import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class CmdlineTest(unittest.TestCase):

def test_dynamic_loading(self):
interface = luigi.interface.ArgParseInterface()
self.assertRaises(Exception, interface.parse, (['FooTask', '--blah', 'xyz', '--x', '123'],)) # should raise since it's not imported
self.assertRaises(SystemExit, interface.parse, (['FooTask', '--blah', 'xyz', '--x', '123'],)) # should raise since it's not imported

interface = luigi.interface.DynamicArgParseInterface()
tasks = interface.parse(['--module', 'foo_module', 'FooTask', '--blah', 'xyz', '--x', '123'])
Expand Down
6 changes: 6 additions & 0 deletions test/parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ def test_global_param_cmdline(self):
self.assertEqual(h.global_param, 124)
self.assertEqual(h.global_bool_param, False)

def test_global_param_cmdline_flipped(self):
    """A global parameter given before the task name is still applied."""
    cmdline = [
        '--local-scheduler', '--no-lock',
        '--global-param', '125',
        'HasGlobalParam', '--x', 'xyz',
    ]
    luigi.run(cmdline)

    # A freshly built task should see the value set on the command line.
    task = HasGlobalParam(x='xyz')
    self.assertEqual(task.global_param, 125)
    self.assertEqual(task.global_bool_param, False)

def test_global_param_override(self):
h1 = HasGlobalParam(x='xyz', global_param=124)
h2 = HasGlobalParam(x='xyz')
Expand Down

0 comments on commit db34510

Please sign in to comment.