diff --git a/src/pyinfra/api/operations.py b/src/pyinfra/api/operations.py index 5bca320d3..17bdf0e95 100644 --- a/src/pyinfra/api/operations.py +++ b/src/pyinfra/api/operations.py @@ -103,7 +103,12 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> bool: if isinstance(command, FunctionCommand): try: - status = command.execute(state, host, connector_arguments) + with gevent.Timeout(timeout, exception=TimeoutError): + status = command.execute(state, host, connector_arguments) + + except TimeoutError as e: + log_host_command_error(host, e, timeout=timeout) + except NestedOperationError: host.log_styled("Error in nested operation", fg="red", log_func=logger.error) except Exception as e: diff --git a/src/pyinfra/api/util.py b/src/pyinfra/api/util.py index aaf565eef..e10a20b1f 100644 --- a/src/pyinfra/api/util.py +++ b/src/pyinfra/api/util.py @@ -261,7 +261,7 @@ def log_error_or_warning( def log_host_command_error(host: "Host", e: Exception, timeout: int | None = 0) -> None: - if isinstance(e, timeout_error): + if isinstance(e, (TimeoutError, timeout_error)): logger.error( "{0}{1}".format( host.print_prefix, diff --git a/tests/test_api/test_api_operations.py b/tests/test_api/test_api_operations.py index b3edd3a1a..d1cbb645c 100644 --- a/tests/test_api/test_api_operations.py +++ b/tests/test_api/test_api_operations.py @@ -2,6 +2,7 @@ from os import path from unittest import TestCase from unittest.mock import mock_open, patch +import time import pyinfra from pyinfra.api import ( @@ -258,6 +259,23 @@ def mocked_function(*args, **kwargs): assert is_called + def test_function_call_op_timeout(self): + inventory = make_inventory() + state = State(inventory, Config()) + state.current_stage = StateStage.Prepare + connect_all(state) + + timeout = 1 + + def mocked_function(*args, **kwargs): + time.sleep(timeout + 1) + + add_op(state, python.call, mocked_function, _timeout=timeout) + + # Timeout should cause the operation to fail and hosts to be removed + with self.assertRaises(PyinfraError) as context: + run_ops(state) + def test_run_once_serial_op(self): inventory = make_inventory() state = State(inventory, Config())