Skip to content

Commit

Permalink
Make parse_optional_parameter accept multiple params.
Browse files Browse the repository at this point in the history
  • Loading branch information
thangleiter committed Jun 19, 2020
1 parent 2a54499 commit ca21477
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 25 deletions.
6 changes: 3 additions & 3 deletions filter_functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def calculate_control_matrix_periodic(phases: ndarray, R: ndarray,
return R_tot


@util.parse_optional_parameter('which', ('total', 'correlations'))
@util.parse_optional_parameters({'which': ('total', 'correlations')})
def calculate_cumulant_function(
pulse: 'PulseSequence',
S: ndarray,
Expand Down Expand Up @@ -478,7 +478,7 @@ def calculate_cumulant_function(
return K.real


@util.parse_optional_parameter('which', ('total', 'correlations'))
@util.parse_optional_parameters({'which': ('total', 'correlations')})
def calculate_decay_amplitudes(
pulse: 'PulseSequence',
S: ndarray,
Expand Down Expand Up @@ -888,7 +888,7 @@ def error_transfer_matrix(
return U


@util.parse_optional_parameter('which', ('total', 'correlations'))
@util.parse_optional_parameters({'which': ('total', 'correlations')})
def infidelity(pulse: 'PulseSequence',
S: Union[Coefficients, Callable],
omega: Union[Coefficients, Dict[str, Union[int, str]]],
Expand Down
4 changes: 2 additions & 2 deletions filter_functions/pulse_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,8 +836,8 @@ def nbytes(self) -> int:

return sum(_nbytes)

@util.parse_optional_parameter(
'method', ('conservative', 'greedy', 'frequency dependent', 'all'))
@util.parse_optional_parameters(
{'method': ('conservative', 'greedy', 'frequency dependent', 'all')})
def cleanup(self, method: str = 'conservative') -> None:
"""
Delete cached byproducts of the calculation of the filter function that
Expand Down
38 changes: 20 additions & 18 deletions filter_functions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@
import string
import sys
from itertools import zip_longest
from typing import (Callable, Generator, Iterable, List, Optional, Sequence,
Tuple, Union)
from typing import (Callable, Dict, Generator, Iterable, List, Optional,
Sequence, Tuple, Union)

import numpy as np
import qutip as qt
Expand Down Expand Up @@ -1119,34 +1119,36 @@ def progressbar_range(*args, show_progressbar: Optional[bool] = True,
return range(*args)


def parse_optional_parameter(name: str, allowed: Sequence) -> Callable:
def parse_optional_parameters(params_dict: Dict[str, Sequence]) -> Callable:
"""Decorator factory to parse optional parameter with certain legal values.
If the parameter value corresponding to ``name`` (either in args or kwargs)
is not contained in ``allowed`` a ``ValueError`` is raised.
For ``params_dict = {name: allowed, ...}``: If the parameter value
corresponding to ``name`` (either in args or kwargs of the decorated
function) is not contained in ``allowed`` a ``ValueError`` is raised.
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
parameters = inspect.signature(func).parameters
idx = tuple(parameters).index(name)
try:
value = args[idx]
except IndexError:
value = kwargs.get(name, parameters[name].default)

if value not in allowed:
raise ValueError(
"Invalid value for {}: {}. ".format(name, value) +
"Should be one of {}.".format(allowed)
)
for name, allowed in params_dict.items():
idx = tuple(parameters).index(name)
try:
value = args[idx]
except IndexError:
value = kwargs.get(name, parameters[name].default)

if value not in allowed:
raise ValueError(
"Invalid value for {}: {}. ".format(name, value) +
"Should be one of {}.".format(allowed)
)
return func(*args, **kwargs)
return wrapper
return decorator


parse_which_FF_parameter = parse_optional_parameter(
'which', ('fidelity', 'generalized'))
parse_which_FF_parameter = parse_optional_parameters(
{'which': ('fidelity', 'generalized')})


class CalculationError(Exception):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ def test_progressbar_range(self):

self.assertEqual(ii, list(range(523, 123, -32)))

def test_parse_optional_parameter(self):
def test_parse_optional_parameters(self):

@util.parse_optional_parameter('foo', [1, 'bar', (2, 3)])
@util.parse_optional_parameters({'foo': [1, 'bar', (2, 3)]})
def foobar(a, b, foo=None, x=2):
pass

Expand Down

0 comments on commit ca21477

Please sign in to comment.