Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in batch logic #56273 #60256

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/56273.fixed
@@ -0,0 +1 @@
Fix race condition in batch logic.
24 changes: 19 additions & 5 deletions salt/cli/batch.py
Expand Up @@ -19,18 +19,31 @@
class Batch:
"""
Manage the execution of batch runs

"""

def __init__(self, opts, eauth=None, quiet=False, parser=None):
def __init__(self, opts, eauth=None, quiet=False, _parser=None):
"""
:param dict opts: A config options dictionary.

:param dict eauth: An eauth config to use.

The default is an empty dict.

:param bool quiet: Supress printing to stdout

The default is False.
"""
self.opts = opts
self.eauth = eauth if eauth else {}
self.pub_kwargs = eauth if eauth else {}
self.quiet = quiet
self.local = salt.client.get_local_client(opts["conf_file"])
self.minions, self.ping_gen, self.down_minions = self.__gather_minions()
self.options = parser
self.options = _parser
# Passing listen True to local client will prevent it from purging
# cahced events while iterating over the batches.
self.local = salt.client.get_local_client(opts["conf_file"], listen=True)

def __gather_minions(self):
def gather_minions(self):
"""
Return a list of minions to use for the batch run
"""
Expand Down Expand Up @@ -106,6 +119,7 @@ def run(self):
"""
Execute the batch run
"""
self.minions, self.ping_gen, self.down_minions = self.gather_minions()
args = [
[],
self.opts["fun"],
Expand Down
2 changes: 1 addition & 1 deletion salt/cli/salt.py
Expand Up @@ -289,7 +289,7 @@ def _run_batch(self):
try:
self.config["batch"] = self.options.batch
batch = salt.cli.batch.Batch(
self.config, eauth=eauth, parser=self.options
self.config, eauth=eauth, _parser=self.options
)
except SaltClientError:
# We will print errors to the console further down the stack
Expand Down
73 changes: 66 additions & 7 deletions salt/client/__init__.py
Expand Up @@ -71,18 +71,46 @@ def get_local_client(
skip_perm_errors=False,
io_loop=None,
auto_reconnect=False,
listen=False,
):
"""
.. versionadded:: 2014.7.0

Read in the config and return the correct LocalClient object based on
the configured transport

:param str c_path: Path of config file to use for opts.

The default value is None.

:param bool mopts: When provided the local client will use is dictionary of
garethgreenaway marked this conversation as resolved.
Show resolved Hide resolved
options insead of loading a config file from the value
of c_path.

The default value is None.

:param str skip_perm_errors: Ignore permissions errors while loading keys.

The default value is False.

:param IOLoop io_loop: io_loop used for events.
Pass in an io_loop if you want asynchronous
operation for obtaining events. Eg use of
set_event_handler() API. Otherwise, operation
will be synchronous.

:param bool keep_loop: Do not destroy the event loop when closing the event
subsriber.

:param bool auto_reconnect: When True the event subscriber will reconnect
automatically if a disconnect error is raised.

.. versionadded:: 3004
:param bool listen: Listen for events indefinitly. When option is set the
LocalClient object will listen for events until it's
destroy method is called.

The default value is False.
"""
if mopts:
opts = mopts
Expand All @@ -98,6 +126,7 @@ def get_local_client(
skip_perm_errors=skip_perm_errors,
io_loop=io_loop,
auto_reconnect=auto_reconnect,
listen=listen,
)


Expand Down Expand Up @@ -128,6 +157,7 @@ class LocalClient:

local = salt.client.LocalClient()
local.cmd('*', 'test.fib', [10])

"""

def __init__(
Expand All @@ -138,13 +168,41 @@ def __init__(
io_loop=None,
keep_loop=False,
auto_reconnect=False,
listen=False,
):
"""
:param str c_path: Path of config file to use for opts.

The default value is None.

:param bool mopts: When provided the local client will use is dictionary of
garethgreenaway marked this conversation as resolved.
Show resolved Hide resolved
options insead of loading a config file from the value
of c_path.

The default value is None.

:param str skip_perm_errors: Ignore permissions errors while loading keys.

The default value is False.

:param IOLoop io_loop: io_loop used for events.
Pass in an io_loop if you want asynchronous
operation for obtaining events. Eg use of
set_event_handler() API. Otherwise,
operation will be synchronous.
set_event_handler() API. Otherwise, operation
will be synchronous.

:param bool keep_loop: Do not destroy the event loop when closing the event
subsriber.

:param bool auto_reconnect: When True the event subscriber will reconnect
automatically if a disconnect error is raised.

.. versionadded:: 3004
:param bool listen: Listen for events indefinitly. When option is set the
LocalClient object will listen for events until it's
destroy method is called.

The default value is False.
"""
if mopts:
self.opts = mopts
Expand All @@ -162,12 +220,13 @@ def __init__(
self.skip_perm_errors = skip_perm_errors
self.key = self.__read_master_key()
self.auto_reconnect = auto_reconnect
self.listen = listen
self.event = salt.utils.event.get_event(
"master",
self.opts["sock_dir"],
self.opts["transport"],
opts=self.opts,
listen=False,
listen=self.listen,
io_loop=io_loop,
keep_loop=keep_loop,
)
Expand Down Expand Up @@ -1052,6 +1111,9 @@ def get_returns_no_block(self, tag, match_type=None):
)
yield raw

def returns_for_job(self, jid):
return self.returners["{}.get_load".format(self.opts["master_job_cache"])](jid)

def get_iter_returns(
self,
jid,
Expand Down Expand Up @@ -1088,10 +1150,7 @@ def get_iter_returns(
missing = set()
# Check to see if the jid is real, if not return the empty dict
try:
if (
self.returners["{}.get_load".format(self.opts["master_job_cache"])](jid)
== {}
):
if self.returns_for_job(jid) == {}:
dwoz marked this conversation as resolved.
Show resolved Hide resolved
log.warning("jid does not exist")
yield {}
# stop the iteration, since the jid is invalid
Expand Down
185 changes: 185 additions & 0 deletions tests/pytests/functional/cli/test_batch.py
@@ -0,0 +1,185 @@
"""
tests.pytests.functional.cli.test_batch
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
"""
import salt.cli.batch
import salt.config
import salt.utils.jid
from tests.support.mock import Mock, patch


class MockPub:
"""
Mock salt.client.LocalClient.pub method
"""

calls = 0
initial_ping = False
batch1_jid = None
batch1_tgt = None
batch2_jid = None
batch2_tgt = None
batch3_jid = None
batch3_tgt = None

def __call__(self, tgt, fun, *args, **kwargs):
if tgt == "minion*" and fun == "test.ping":
MockPub.calls += 1
MockPub.initial_ping = salt.utils.jid.gen_jid({})
pub_ret = {
"jid": MockPub.initial_ping,
"minions": ["minion0", "minion1", "minion2", "minion3"],
}
elif fun == "state.sls":
if MockPub.calls == 1:
MockPub.calls += 1
MockPub.batch1_tgt = list(tgt)
MockPub.batch1_jid = jid = salt.utils.jid.gen_jid({})
pub_ret = {"jid": jid, "minions": tgt}
elif MockPub.calls == 2:
MockPub.calls += 1
MockPub.batch2_tgt = tgt
MockPub.batch2_jid = jid = salt.utils.jid.gen_jid({})
pub_ret = {"jid": jid, "minions": tgt}
elif MockPub.calls == 3:
MockPub.calls += 1
MockPub.batch3_tgt = tgt
MockPub.batch3_jid = jid = salt.utils.jid.gen_jid({})
pub_ret = {"jid": jid, "minions": tgt}
elif fun == "saltutil.find_job":
jid = salt.utils.jid.gen_jid({})
pub_ret = {"jid": jid, "minions": tgt}
return pub_ret


class MockSubscriber:
"""
Mock salt.transport.ipc IPCMessageSubscriber in order to inject events into
salt.utils.Event
"""

calls = 0
pubret = None

def __init__(self, *args, **kwargs):
return

def read(self, timeout=None):
"""
Mock IPCMessageSubcriber read method.

- Return events for initial ping
- Returns event for a minion in first batch to cause second batch to get sent.
- Returns 5 null events on first iteration of second batch to go back to first batch.
- On second iteration of first batch, send an event from second batch which will get cached.
- Return events for the rest of the batches.
"""
if MockSubscriber.pubret.initial_ping:
# Send ping responses for 4 minions
jid = MockSubscriber.pubret.initial_ping
if MockSubscriber.calls == 0:
MockSubscriber.calls += 1
return self._ret(jid, minion_id="minion0", fun="test.ping")
elif MockSubscriber.calls == 1:
MockSubscriber.calls += 1
return self._ret(jid, minion_id="minion1", fun="test.ping")
elif MockSubscriber.calls == 2:
MockSubscriber.calls += 1
return self._ret(jid, minion_id="minion2", fun="test.ping")
elif MockSubscriber.calls == 3:
MockSubscriber.calls += 1
return self._ret(jid, minion_id="minion3", fun="test.ping")
if MockSubscriber.pubret.batch1_jid:
jid = MockSubscriber.pubret.batch1_jid
tgt = MockSubscriber.pubret.batch1_tgt
if MockSubscriber.calls == 4:
# Send a return for first minion in first batch. This causes the
# second batch to get sent.
MockSubscriber.calls += 1
return self._ret(jid, minion_id=tgt[0], fun="state.sls")
if MockSubscriber.pubret.batch2_jid:
if MockSubscriber.calls <= 10:
# Skip the first iteration of the second batch; this will cause
# batch logic to go back to iterating over the first batch.
MockSubscriber.calls += 1
return
elif MockSubscriber.calls == 11:
# Send the minion from the second batch, This event will get cached.
jid = MockSubscriber.pubret.batch2_jid
tgt = MockSubscriber.pubret.batch2_tgt
MockSubscriber.calls += 1
return self._ret(jid, minion_id=tgt[0], fun="state.sls")
if MockSubscriber.calls == 12:
jid = MockSubscriber.pubret.batch1_jid
tgt = MockSubscriber.pubret.batch1_tgt
MockSubscriber.calls += 1
return self._ret(jid, minion_id=tgt[1], fun="state.sls")
if MockSubscriber.pubret.batch3_jid:
jid = MockSubscriber.pubret.batch3_jid
tgt = MockSubscriber.pubret.batch3_tgt
if MockSubscriber.calls == 13:
MockSubscriber.calls += 1
return self._ret(jid, minion_id=tgt[0], fun="state.sls")
return

def _ret(self, jid, minion_id, fun, _return=True, _retcode=0):
"""
Create a mock return from a jid, minion, and fun
"""
serial = salt.payload.Serial({"serial": "msgpack"})
dumped = serial.dumps(
{
"fun_args": [],
"jid": jid,
"return": _return,
"retcode": 0,
"success": True,
"cmd": "_return",
"fun": fun,
"id": minion_id,
"_stamp": "2021-05-24T01:23:25.373194",
},
use_bin_type=True,
)
tag = "salt/job/{}/ret".format(jid).encode()
return b"".join([tag, b"\n\n", dumped])

def connect(self, timeout=None):
pass


def test_batch_issue_56273():
"""
Regression test for race condition in batch logic.
https://github.com/saltstack/salt/issues/56273
"""

mock_pub = MockPub()
MockSubscriber.pubret = mock_pub

def returns_for_job(jid):
return True

opts = {
"conf_file": "",
"tgt": "minion*",
"fun": "state.sls",
"arg": ["foo"],
"timeout": 1,
"gather_job_timeout": 1,
"batch": 2,
"extension_modules": "",
"failhard": True,
}
with patch("salt.transport.ipc.IPCMessageSubscriber", MockSubscriber):
batch = salt.cli.batch.Batch(opts, quiet=True)
with patch.object(batch.local, "pub", Mock(side_effect=mock_pub)):
with patch.object(
batch.local, "returns_for_job", Mock(side_effect=returns_for_job)
):
ret = list(batch.run())
assert len(ret) == 4
for val in ret:
values = list(val.values())
assert len(values) == 1
assert values[0] is True