Skip to content

Commit

Permalink
Allow only mocking methods for specific args
Browse files Browse the repository at this point in the history
  • Loading branch information
themylogin committed Feb 3, 2023
1 parent 33d7576 commit 3600c6e
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 24 deletions.
47 changes: 32 additions & 15 deletions src/middlewared/middlewared/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def send_error(self, message, errno, reason=None, exc_info=None, etype=None, ext
async def call_method(self, message, serviceobj, methodobj):
params = message.get('params') or []

if mock := self.middleware._mock_method(message['method'], params):
methodobj = mock

try:
async with self._softhardsemaphore:
result = await self.middleware._call(message['method'], serviceobj, methodobj, params, app=self)
Expand Down Expand Up @@ -887,7 +890,7 @@ def __init__(
self.__console_io = False if os.path.exists(self.CONSOLE_ONCE_PATH) else None
self.__terminate_task = None
self.jobs = JobsQueue(self)
self.mocks = {}
self.mocks = defaultdict(list)
self.socket_messages_queue = deque(maxlen=200)
self.tasks = set()

Expand Down Expand Up @@ -1330,7 +1333,7 @@ async def _call(
self.logger.trace('Calling %r in current IO loop', name)
return await methodobj(*prepared_call.args)

if name not in self.mocks and serviceobj._config.process_pool:
if not self.mocks.get(name) and serviceobj._config.process_pool:
self.logger.trace('Calling %r in process pool', name)
if isinstance(serviceobj, middlewared.service.CRUDService):
service_name, method_name = name.rsplit('.', 1)
Expand All @@ -1352,6 +1355,9 @@ def dump_args(self, args, method=None, method_name=None):
except Exception:
return args

if mock := self._mock_method(method_name, args):
method = mock

if (not hasattr(method, 'accepts') and
method.__name__ in ['create', 'update', 'delete'] and
hasattr(method, '__self__')):
Expand All @@ -1376,6 +1382,9 @@ def dump_result(self, result, method):
async def call(self, name, *params, pipes=None, job_on_progress_cb=None, app=None, profile=False):
serviceobj, methodobj = self._method_lookup(name)

if mock := self._mock_method(name, params):
methodobj = mock

if profile:
methodobj = profile_wrap(methodobj)

Expand All @@ -1390,6 +1399,9 @@ def call_sync(self, name, *params, job_on_progress_cb=None, background=False):

serviceobj, methodobj = self._method_lookup(name)

if mock := self._mock_method(name, params):
methodobj = mock

prepared_call = self._call_prepare(name, serviceobj, methodobj, params, job_on_progress_cb=job_on_progress_cb,
in_event_loop=False)

Expand Down Expand Up @@ -1570,9 +1582,10 @@ def _tracemalloc_start(self, limit, interval):

time.sleep(interval)

def set_mock(self, name, mock):
if name in self.mocks:
raise ValueError(f'{name!r} is already mocked')
def set_mock(self, name, args, mock):
for _args, _mock in self.mocks[name]:
if args == _args:
raise ValueError(f'{name!r} is already mocked with {args!r}')

serviceobj, methodobj = self._method_lookup(name)

Expand All @@ -1587,18 +1600,22 @@ def f(*args, **kwargs):
f._job = methodobj._job
copy_function_metadata(mock, f)

self.mocks[name] = f
self.mocks[name].append((args, f))

def remove_mock(self, name):
self.mocks.pop(name)

def _method_lookup(self, name):
serviceobj, methodobj = super()._method_lookup(name)

if mock := self.mocks.get(name):
return serviceobj, mock
def remove_mock(self, name, args):
for i, (_args, _mock) in enumerate(self.mocks[name]):
if args == _args:
del self.mocks[name][i]
break

return serviceobj, methodobj
def _mock_method(self, name, params):
if mocks := self.mocks.get(name):
for args, mock in mocks:
if args == list(params):
return mock
for args, mock in mocks:
if args is None:
return mock

async def ws_handler(self, request):
ws = web.WebSocketResponse()
Expand Down
8 changes: 4 additions & 4 deletions src/middlewared/middlewared/plugins/test/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class TestService(Service):
class Config:
private = True

async def set_mock(self, name, description):
async def set_mock(self, name, args, description):
if isinstance(description, str):
exec(description)
try:
Expand All @@ -25,10 +25,10 @@ def method(*args, **kwargs):
else:
raise CallError("Invalid mock declaration")

self.middleware.set_mock(name, method)
self.middleware.set_mock(name, args, method)

async def remove_mock(self, name):
self.middleware.remove_mock(name)
async def remove_mock(self, name, args):
self.middleware.remove_mock(name, args)

# Dummy methods to mock for internal infrastructure testing (i.e. jobs manager)

Expand Down
12 changes: 8 additions & 4 deletions src/middlewared/middlewared/test/integration/utils/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,29 @@ def mock(method, declaration="", **kwargs):
:param method: Method name to replace
:params args: Only use this mock when the method is called with the specified arguments.
:param return_value: The value returned when the mock is called.
:param declaration: A string, containing python function declaration for mock. Function should be named `mock`,
can be normal function or `async` and must accept `self` argument and all other arguments the function being
replaced accepts. No `@accepts`, `@job` or other decorators are required, but if a method being replaced is a
job, then mock signature must also accept `job` argument.
"""
args = kwargs.pop("args", None)

if declaration and kwargs:
raise ValueError("Mock `declaration` is not allowed with kwargs")
elif declaration:
arg = textwrap.dedent(declaration)
description = textwrap.dedent(declaration)
else:
arg = kwargs
description = kwargs

with client() as c:
c.call("test.set_mock", method, arg)
c.call("test.set_mock", method, args, description)

try:
yield
finally:
with client() as c:
c.call("test.remove_mock", method)
c.call("test.remove_mock", method, args)
2 changes: 1 addition & 1 deletion tests/api2/test_user_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_can_set_admin_authorized_key(admin):


def test_admin_user_alert(admin):
with mock("user.get_user_obj", return_value={
with mock("user.get_user_obj", args=[{"uid": 950}], return_value={
"pw_name": "root", "pw_uid": 0, "pw_gid": 0, "pw_gecos": "root", "pw_dir": "/root", "pw_shell": "/usr/bin/zsh"
}):
alerts = call("alert.run_source", "AdminUser")
Expand Down

0 comments on commit 3600c6e

Please sign in to comment.