Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/pyinfra/api/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ def _run_host_op(state: "State", host: "Host", op_hash: str) -> bool:

if isinstance(command, FunctionCommand):
try:
status = command.execute(state, host, connector_arguments)
with gevent.Timeout(timeout, exception=TimeoutError):
status = command.execute(state, host, connector_arguments)

except TimeoutError as e:
log_host_command_error(host, e, timeout=timeout)

except NestedOperationError:
host.log_styled("Error in nested operation", fg="red", log_func=logger.error)
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/pyinfra/api/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def log_error_or_warning(


def log_host_command_error(host: "Host", e: Exception, timeout: int | None = 0) -> None:
if isinstance(e, timeout_error):
if isinstance(e, (TimeoutError, timeout_error)):
logger.error(
"{0}{1}".format(
host.print_prefix,
Expand Down
18 changes: 18 additions & 0 deletions tests/test_api/test_api_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import path
from unittest import TestCase
from unittest.mock import mock_open, patch
import time

import pyinfra
from pyinfra.api import (
Expand Down Expand Up @@ -258,6 +259,23 @@ def mocked_function(*args, **kwargs):

assert is_called

def test_function_call_op_timeout(self):
inventory = make_inventory()
state = State(inventory, Config())
state.current_stage = StateStage.Prepare
connect_all(state)

timeout = 1

def mocked_function(*args, **kwargs):
time.sleep(timeout + 1)

add_op(state, python.call, mocked_function, _timeout=timeout)

# Timeout should cause the operation to fail and hosts to be removed
with self.assertRaises(PyinfraError) as context:
run_ops(state)

def test_run_once_serial_op(self):
inventory = make_inventory()
state = State(inventory, Config())
Expand Down
Loading