diff --git a/sshm/_info.py b/sshm/_info.py index a243c1d..d0abda8 100644 --- a/sshm/_info.py +++ b/sshm/_info.py @@ -1,7 +1,7 @@ #! /usr/bin/env python3 # This is the official version of sshm -__version__ = '2.0.1' +__version__ = '2.0.2' __long_description__ = ''' SSH Multi v%s. SSH into multiple machines at once. diff --git a/sshm/lib.py b/sshm/lib.py index a0ee56f..d4f97cc 100644 --- a/sshm/lib.py +++ b/sshm/lib.py @@ -8,6 +8,7 @@ __all__ = ['sshm', 'uri_expansion'] disable_formatting = False +default_workers = 20 # This is used to parse a range string @@ -237,7 +238,7 @@ def ssh(thread_num, context, uri, command, extra_arguments, if_stdin=False): CHUNK_SIZE = 65536 -def sshm(servers, command, extra_arguments=None, stdin=None, disable_formatting_var=False, max_workers=5): +def sshm(servers, command, extra_arguments=None, stdin=None, disable_formatting_var=False, workers=default_workers): """ SSH into multiple servers and execute "command". Pass stdin to these ssh handles. @@ -265,6 +266,9 @@ def sshm(servers, command, extra_arguments=None, stdin=None, disable_formatting_ instance. @type stdin: file + @param workers: The max amount of concurrent SSH connections. + @type workers: int + @returns: A list containing (success, handle, message) from each method call. """ @@ -308,7 +312,7 @@ def sshm(servers, command, extra_arguments=None, stdin=None, disable_formatting_ next_uri = next(uri_gen) while next_uri or threads: # Start a new thread if there are any URIs left - while next_uri and len(threads) < max_workers: + while next_uri and len(threads) < workers: thread = threading.Thread(target=ssh, args=(thread_num, context, next_uri, command, extra_arguments, if_stdin)) thread.start() diff --git a/sshm/main.py b/sshm/main.py index f093fe8..9267a12 100755 --- a/sshm/main.py +++ b/sshm/main.py @@ -44,6 +44,8 @@ def get_argparse_args(args=None): help='Disable command formatting.') parser.add_argument('-u', '--quiet', action='store_true', default=False, help="Hide SSHM's server information on output (this implies sorted).") + parser.add_argument('-w', '--workers', type=int, default=20, + help="Limit the amount of concurrent SSH connections.") parser.add_argument('--version', action='version', version='%(prog)s '+__version__) args, extra_args = parser.parse_known_args(args=args) diff --git a/sshm/test/test_main.py b/sshm/test/test_main.py index fb4e9a0..40ab205 100755 --- a/sshm/test/test_main.py +++ b/sshm/test/test_main.py @@ -86,6 +86,14 @@ def test_get_argparse_args(self): self.assertEqual(command, 'ls') self.assertEqual(extra_args, []) + # You can specify the amount of workers + provided = ['example.com', 'ls', '-w 5'] + args, command, extra_args = get_argparse_args(provided) + self.assertEqual(args.servers, ['example.com',]) + self.assertEqual(command, 'ls') + self.assertEqual(args.workers, 5) + self.assertEqual(extra_args, []) + def test__print_handling_newlines(self): """