Skip to content

Commit

Permalink
Adding command_timeout option for run()
Browse files Browse the repository at this point in the history
  • Loading branch information
fruch committed Jul 1, 2019
1 parent 96771fd commit 23f8169
Show file tree
Hide file tree
Showing 8 changed files with 73 additions and 20 deletions.
1 change: 1 addition & 0 deletions invoke/__init__.py
Expand Up @@ -16,6 +16,7 @@
UnknownFileType,
UnpicklableConfigMember,
WatcherError,
CommandTimeouted,
)
from .executor import Executor # noqa
from .loader import FilesystemLoader # noqa
Expand Down
1 change: 1 addition & 0 deletions invoke/config.py
Expand Up @@ -485,6 +485,7 @@ def global_defaults():
"in_stream": None,
"watchers": [],
"echo_stdin": None,
"command_timeout": None,
},
# This doesn't live inside the 'run' tree; otherwise it'd make it
# somewhat harder to extend/override in Fabric 2 which has a split
Expand Down
20 changes: 14 additions & 6 deletions invoke/exceptions.py
Expand Up @@ -76,18 +76,14 @@ def __str__(self):
if "stdout" not in self.result.hide:
stdout = already_printed
else:
stdout = encode_output(
_tail(self.result.stdout), self.result.encoding
)
stdout = encode_output(_tail(self.result.stdout), self.result.encoding)
if self.result.pty:
stderr = " n/a (PTYs have no stderr)"
else:
if "stderr" not in self.result.hide:
stderr = already_printed
else:
stderr = encode_output(
_tail(self.result.stderr), self.result.encoding
)
stderr = encode_output(_tail(self.result.stderr), self.result.encoding)
command = self.result.command
exited = self.result.exited
template = """Encountered a bad command exit code!
Expand All @@ -111,6 +107,18 @@ def __repr__(self):
)


class CommandTimeouted(UnexpectedExit):
def __init__(self, result, reason=None, timeout=None):
super(CommandTimeouted, self).__init__(result, reason)
self.timeout = timeout

def __str__(self):
return (
super(CommandTimeouted, self).__str__()
+ "\n[TIMEOUT after %ss]" % self.timeout
)


class AuthFailure(Failure):
"""
An authentication failure, e.g. due to an incorrect ``sudo`` password.
Expand Down
49 changes: 37 additions & 12 deletions invoke/runners.py
Expand Up @@ -7,6 +7,7 @@
import sys
import threading
import time
import signal

from .util import six

Expand All @@ -25,7 +26,13 @@
except ImportError:
termios = None

from .exceptions import UnexpectedExit, Failure, ThreadException, WatcherError
from .exceptions import (
UnexpectedExit,
Failure,
ThreadException,
WatcherError,
CommandTimeouted,
)
from .terminals import (
WINDOWS,
pty_size,
Expand Down Expand Up @@ -250,6 +257,9 @@ def run(self, command, **kwargs):
When not ``None``, this parameter will override that auto-detection
and force, or disable, echoing.
:param command_timeout:
time in secoands to abort the command
:returns:
`Result`, or a subclass thereof.
Expand Down Expand Up @@ -282,7 +292,7 @@ def _run_body(self, command, **kwargs):
if opts["echo"]:
print("\033[1;37m{}\033[0m".format(command))
# Start executing the actual command (runs in background)
self.start(command, shell, env)
self.start(command, shell, env, command_timeout=opts["command_timeout"])
# Arrive at final encoding if neither config nor kwargs had one
self.encoding = opts["encoding"] or self.default_encoding()
# Set up IO thread parameters (format - body_func: {kwargs})
Expand Down Expand Up @@ -400,6 +410,12 @@ def _run_body(self, command, **kwargs):
# TODO: ambiguity exists if we somehow get WatcherError in *both*
# threads...as unlikely as that would normally be.
raise Failure(result, reason=watcher_errors[0])
if (
opts["command_timeout"]
and result.return_code == -int(signal.SIGUSR1)
and not opts["warn"]
):
raise CommandTimeouted(result, timeout=opts["command_timeout"])
if not (result or opts["warn"]):
raise UnexpectedExit(result)
return result
Expand Down Expand Up @@ -560,9 +576,7 @@ def handle_stdout(self, buffer_, hide, output):
.. versionadded:: 1.0
"""
self._handle_output(
buffer_, hide, output, reader=self.read_proc_stdout
)
self._handle_output(buffer_, hide, output, reader=self.read_proc_stdout)

def handle_stderr(self, buffer_, hide, output):
"""
Expand All @@ -573,9 +587,7 @@ def handle_stderr(self, buffer_, hide, output):
.. versionadded:: 1.0
"""
self._handle_output(
buffer_, hide, output, reader=self.read_proc_stderr
)
self._handle_output(buffer_, hide, output, reader=self.read_proc_stderr)

def read_our_stdin(self, input_):
"""
Expand Down Expand Up @@ -810,7 +822,7 @@ def process_is_finished(self):
"""
raise NotImplementedError

def start(self, command, shell, env):
def start(self, command, shell, env, command_timeout):
"""
Initiate execution of ``command`` (via ``shell``, with ``env``).
Expand Down Expand Up @@ -949,6 +961,7 @@ def __init__(self, context):
super(Local, self).__init__(context)
# Bookkeeping var for pty use case
self.status = None
self.timer = None

def should_use_pty(self, pty=False, fallback=True):
use_pty = False
Expand Down Expand Up @@ -1006,7 +1019,7 @@ def _write_proc_stdin(self, data):
if "Broken pipe" not in str(e):
raise

def start(self, command, shell, env):
def start(self, command, shell, env, command_timeout=None):
if self.using_pty:
if pty is None: # Encountered ImportError
err = "You indicated pty=True, but your platform doesn't support the 'pty' module!" # noqa
Expand All @@ -1033,6 +1046,12 @@ def start(self, command, shell, env):
# for now.
# TODO: see if subprocess is using equivalent of execvp...
os.execve(shell, [shell, "-c", command], env)
else:
if command_timeout:
self.timer = threading.Timer(
command_timeout, os.kill, args=(self.pid, signal.SIGUSR1)
)
self.timer.start()
else:
self.process = Popen(
command,
Expand All @@ -1043,6 +1062,11 @@ def start(self, command, shell, env):
stderr=PIPE,
stdin=PIPE,
)
if command_timeout:
self.timer = threading.Timer(
command_timeout, os.kill, args=(self.process.pid, signal.SIGUSR1)
)
self.timer.start()

@property
def process_is_finished(self):
Expand Down Expand Up @@ -1081,8 +1105,9 @@ def returncode(self):
return self.process.returncode

def stop(self):
# No explicit close-out required (so far).
pass
# explicit close-out required
if self.timer:
self.timer.cancel()


class Result(object):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Expand Up @@ -8,4 +8,4 @@ license_file = LICENSE

[tool:pytest]
testpaths = tests
python_files = *
python_files = *
2 changes: 1 addition & 1 deletion tests/_util.py
Expand Up @@ -265,7 +265,7 @@ class _Dummy(Runner):
# which isn't a problem for testing).
input_sleep = 0

def start(self, command, shell, env):
def start(self, command, shell, env, command_timeout):
pass

def read_proc_stdout(self, num_bytes):
Expand Down
1 change: 1 addition & 0 deletions tests/config.py
Expand Up @@ -105,6 +105,7 @@ def basic_settings(self):
"replace_env": False,
"shell": "/bin/bash",
"warn": False,
"command_timeout": None,
"watchers": [],
},
"runners": {"local": Local},
Expand Down
17 changes: 17 additions & 0 deletions tests/runners.py
Expand Up @@ -1240,6 +1240,23 @@ def sends_escape_byte_sequence(self):
runner.run(_, pty=pty)
mock_stdin.assert_called_once_with(u"\x03")

class command_timeout:
def pass_timeout_in_run(self):
import time
import pytest
from invoke.exceptions import CommandTimeouted

runner = Context()
before = time.time()

with pytest.raises(CommandTimeouted) as exc:
runner.run("sleep 5", command_timeout=0.1)
after = time.time()
passed = after - before
assert passed < 0.2
assert exc.value.timeout == 0.1
assert "TIMEOUT" in str(exc.value)

class stop:
def always_runs_no_matter_what(self):
class _ExceptingRunner(_Dummy):
Expand Down

0 comments on commit 23f8169

Please sign in to comment.