diff --git a/docs/manpage.rst b/docs/manpage.rst index 40d0715e92..269039f9cf 100644 --- a/docs/manpage.rst +++ b/docs/manpage.rst @@ -386,6 +386,60 @@ Options controlling ReFrame execution .. versionchanged:: 3.6.1 Multiple report files are now accepted. +.. option:: -S, --setvar=[TEST.]VAR=VAL + + Set variable ``VAR`` in all tests or optionally only in test ``TEST`` to ``VAL``. + + Multiple variables can be set at the same time by passing this option multiple times. + This option *cannot* change arbitrary test attributes, but only test variables declared with the :attr:`~reframe.core.pipeline.RegressionMixin.variable` built-in. + If an attempt is made to change an inexistent variable or a test parameter, a warning will be issued. + + ReFrame will try to convert ``VAL`` to the type of the variable. + If it does not succeed, a warning will be issued and the variable will not be set. + ``VAL`` can take the special value ``@none`` to denote that the variable must be set to :obj:`None`. + + Sequence and mapping types can also be set from the command line by using the following syntax: + + - Sequence types: ``-S seqvar=1,2,3,4`` + - Mapping types: ``-S mapvar=a:1,b:2,c:3`` + + Conversions to arbitrary objects are also supported. + See :class:`~reframe.utility.typecheck.ConvertibleType` for more details. + + + The optional ``TEST.`` prefix refers to the test class name, *not* the test name. + + Variable assignments passed from the command line happen *before* the test is instantiated and is the exact equivalent of assigning a new value to the variable *at the end* of the test class body. + This has a number of implications that users of this feature should be aware of: + + - In the following test, :attr:`num_tasks` will have always the value ``1`` regardless of any command-line assignment of the variable :attr:`foo`: + + .. code-block:: python + + @rfm.simple_test + class my_test(rfm.RegressionTest): + foo = variable(int, value=1) + num_tasks = foo + + - If the variable is set in any pipeline hook, the command line assignment will have an effect until the variable assignment in the pipeline hook is reached. + The variable will be then overwritten. + - The `test filtering <#test-filtering>`__ happens *after* a test is instantiated, so the only way to scope a variable assignment is to prefix it with the test class name. + However, this has some positive side effects: + + - Passing ``-S valid_systems='*'`` and ``-S valid_prog_environs='*'`` is the equivalent of passing the :option:`--skip-system-check` and :option:`--skip-prgenv-check` options. + - Users could alter the behavior of tests based on tag values that they pass from the command line, by changing the behavior of a test in a post-init hook based on the value of the :attr:`~reframe.core.pipeline.RegressionTest.tags` attribute. + - Users could force a test with required variables to run if they set these variables from the command line. + For example, the following test could only be run if invoked with ``-S num_tasks=``: + + .. code-block:: python + + @rfm.simple_test + class my_test(rfm.RegressionTest): + num_tasks = required + + .. versionadded:: 3.8.0 + + ---------------------------------- Options controlling job submission ---------------------------------- diff --git a/reframe/core/fields.py b/reframe/core/fields.py index 4ec5addcaf..35790b91a7 100644 --- a/reframe/core/fields.py +++ b/reframe/core/fields.py @@ -15,6 +15,26 @@ from reframe.utility import ScopedDict +class _Convertible: + '''Wrapper for values that allowed to be converted implicitly''' + + __slots__ = ('value') + + def __init__(self, value): + self.value = value + + +def make_convertible(value): + return _Convertible(value) + + +def remove_convertible(value): + if isinstance(value, _Convertible): + return value.value + else: + return value + + class Field: '''Base class for attribute validators.''' @@ -34,7 +54,7 @@ def __get__(self, obj, objtype): (objtype.__name__, self._name)) from None def __set__(self, obj, value): - obj.__dict__[self._name] = value + obj.__dict__[self._name] = remove_convertible(value) class TypedField(Field): @@ -46,6 +66,10 @@ def __init__(self, main_type, *other_types): raise TypeError('{0} is not a sequence of types'. format(self._types)) + @property + def valid_types(self): + return self._types + def _check_type(self, value): if not any(isinstance(value, t) for t in self._types): typedescr = '|'.join(t.__name__ for t in self._types) @@ -54,8 +78,33 @@ def _check_type(self, value): (self._name, value, typedescr)) def __set__(self, obj, value): - self._check_type(value) - super().__set__(obj, value) + try: + self._check_type(value) + except TypeError: + raw_value = remove_convertible(value) + if raw_value is value: + # value was not convertible; reraise + raise + + # Try to convert value to any of the supported types + value = raw_value + for t in self._types: + try: + value = t(value) + except TypeError: + continue + else: + super().__set__(obj, value) + return + + # Conversion failed + raise TypeError( + f'failed to set field {self._name!r}: ' + f'could not convert to any of the supported types: ' + f'{self._types}' + ) + else: + super().__set__(obj, value) class ConstantField(Field): @@ -88,6 +137,7 @@ def __init__(self, *other_types): super().__init__(str, int, float, *other_types) def __set__(self, obj, value): + value = remove_convertible(value) self._check_type(value) if isinstance(value, str): time_match = re.match(r'^((?P\d+)d)?' @@ -119,6 +169,7 @@ def __init__(self, valuetype, *other_types): ScopedDict, *other_types) def __set__(self, obj, value): + value = remove_convertible(value) self._check_type(value) if not isinstance(value, ScopedDict): value = ScopedDict(value) if value is not None else value diff --git a/reframe/core/meta.py b/reframe/core/meta.py index d6815f14a0..a1c5fbc729 100644 --- a/reframe/core/meta.py +++ b/reframe/core/meta.py @@ -473,6 +473,41 @@ def __getattr__(cls, name): f'class {cls.__qualname__!r} has no attribute {name!r}' ) from None + def setvar(cls, name, value): + '''Set the value of a variable. + + :param name: The name of the variable. + :param value: The value of the variable. + + :returns: :class:`True` if the variable was set. + A variable will *not* be set, if it does not exist or when an + attempt is made to set it with its underlying descriptor. + This happens during the variable injection time and it should be + delegated to the class' :func:`__setattr__` method. + + :raises ReframeSyntaxError: If an attempt is made to override a + variable with a descriptor other than its underlying one. + + ''' + + try: + var_space = super().__getattribute__('_rfm_var_space') + if name in var_space: + if not hasattr(value, '__get__'): + var_space[name].define(value) + return True + elif var_space[name].field is not value: + desc = '.'.join([cls.__qualname__, name]) + raise ReframeSyntaxError( + f'cannot override variable descriptor {desc!r}' + ) + else: + # Variable is being injected + return False + except AttributeError: + '''Catch early access attempt to the variable space.''' + return False + def __setattr__(cls, name, value): '''Handle the special treatment required for variables and parameters. @@ -489,31 +524,20 @@ class attribute. This behavior does not apply when the assigned value is not allowed. This would break the parameter space internals. ''' - # Set the value of a variable (except when the value is a descriptor). - try: - var_space = super().__getattribute__('_rfm_var_space') - if name in var_space: - if not hasattr(value, '__get__'): - var_space[name].define(value) - return - elif not var_space[name].field is value: - desc = '.'.join([cls.__qualname__, name]) - raise ReframeSyntaxError( - f'cannot override variable descriptor {desc!r}' - ) - - except AttributeError: - pass + # Try to treat `name` as variable + if cls.setvar(name, value): + return - # Catch attempts to override a test parameter + # Try to treat `name` as a parameter try: + # Catch attempts to override a test parameter param_space = super().__getattribute__('_rfm_param_space') if name in param_space.params: raise ReframeSyntaxError(f'cannot override parameter {name!r}') - except AttributeError: - pass + '''Catch early access attempt to the parameter space.''' + # Treat `name` as normal class attribute super().__setattr__(name, value) @property diff --git a/reframe/core/pipeline.py b/reframe/core/pipeline.py index e87659f4a3..87c9caa459 100644 --- a/reframe/core/pipeline.py +++ b/reframe/core/pipeline.py @@ -213,6 +213,12 @@ def pipeline_hooks(cls): #: The name of the test. #: #: :type: string that can contain any character except ``/`` + #: :default: For non-parameterised tests, the default name is the test + #: class name. For parameterised tests, the default name is constructed + #: by concatenating the test class name and the string representations + #: of every test parameter: ``TestClassName__``. + #: Any non-alphanumeric value in a parameter's representation is + #: converted to ``_``. name = variable(typ.Str[r'[^\/]+']) #: List of programming environments supported by this test. diff --git a/reframe/core/variables.py b/reframe/core/variables.py index 2727a290c1..c8777503ce 100644 --- a/reframe/core/variables.py +++ b/reframe/core/variables.py @@ -448,7 +448,6 @@ def join(self, other, cls): :param cls: the target class. ''' for key, var in other.items(): - # Make doubly declared vars illegal. Note that this will be # triggered when inheriting from multiple RegressionTest classes. if key in self.vars: diff --git a/reframe/frontend/cli.py b/reframe/frontend/cli.py index 387f95fff6..144986a3c6 100644 --- a/reframe/frontend/cli.py +++ b/reframe/frontend/cli.py @@ -329,6 +329,12 @@ def main(): '--skip-prgenv-check', action='store_true', help='Skip programming environment check' ) + run_options.add_argument( + '-S', '--setvar', action='append', metavar='[TEST.]VAR=VAL', + dest='vars', default=[], + help=('Set test variable VAR to VAL in all tests ' + 'or optionally in TEST only') + ) run_options.add_argument( '--exec-policy', metavar='POLICY', action='store', choices=['async', 'serial'], default='async', @@ -711,10 +717,21 @@ def main(): ) check_search_path = site_config.get('general/0/check_search_path') - loader = RegressionCheckLoader( - load_path=check_search_path, - recurse=check_search_recursive - ) + # Collect any variables set from the command line + external_vars = {} + for expr in options.vars: + try: + lhs, rhs = expr.split('=', maxsplit=1) + except ValueError: + printer.warning( + f'invalid test variable assignment: {expr!r}; skipping' + ) + else: + external_vars[lhs] = rhs + + loader = RegressionCheckLoader(check_search_path, + check_search_recursive, + external_vars) def print_infoline(param, value): param = param + ':' diff --git a/reframe/frontend/loader.py b/reframe/frontend/loader.py index e928c04e19..179308a36b 100644 --- a/reframe/frontend/loader.py +++ b/reframe/frontend/loader.py @@ -13,6 +13,7 @@ import sys import traceback +import reframe.core.fields as fields import reframe.utility as util import reframe.utility.osext as osext from reframe.core.exceptions import NameConflictError, is_severe, what @@ -39,7 +40,7 @@ def visit_ImportFrom(self, node): class RegressionCheckLoader: - def __init__(self, load_path, recurse=False): + def __init__(self, load_path, recurse=False, external_vars=None): # Expand any environment variables and symlinks load_path = [os.path.realpath(osext.expandvars(p)) for p in load_path] self._load_path = osext.unique_abs_paths(load_path, recurse) @@ -48,6 +49,9 @@ def __init__(self, load_path, recurse=False): # Loaded tests by name; maps test names to the file that were defined self._loaded = {} + # Variables set in the command line + self._external_vars = external_vars or {} + def _module_name(self, filename): '''Figure out a module name from filename. @@ -114,6 +118,37 @@ def prefix(self): def recurse(self): return self._recurse + def _set_defaults(self, test_registry): + if test_registry is None: + return + + unset_vars = {} + for test in test_registry: + for name, val in self._external_vars.items(): + if '.' in name: + testname, varname = name.split('.', maxsplit=1) + else: + testname, varname = test.__name__, name + + if testname == test.__name__: + # Treat special values + if val == '@none': + val = None + else: + val = fields.make_convertible(val) + + if not test.setvar(varname, val): + unset_vars.setdefault(test.__name__, []) + unset_vars[test.__name__].append(varname) + + # Warn for all unset variables + for testname, varlist in unset_vars.items(): + varlist = ', '.join(f'{v!r}' for v in varlist) + getlogger().warning( + f'test {testname!r}: ' + f'the following variables were not set: {varlist}' + ) + def load_from_module(self, module): '''Load user checks from module. @@ -127,13 +162,6 @@ def load_from_module(self, module): ''' from reframe.core.pipeline import RegressionTest - # Warn in case of old syntax - if hasattr(module, '_get_checks'): - getlogger().warning( - f'{module.__file__}: _get_checks() is no more supported ' - f'in test files: please use @reframe.simple_test decorator' - ) - # FIXME: Remove the legacy_registry after dropping parameterized_test registry = getattr(module, '_rfm_test_registry', None) legacy_registry = getattr(module, '_rfm_gettests', None) @@ -141,13 +169,19 @@ def load_from_module(self, module): getlogger().debug('No tests registered') return [] + self._set_defaults(registry) candidates = registry.instantiate_all() if registry else [] legacy_candidates = legacy_registry() if legacy_registry else [] + if self._external_vars and legacy_candidates: + getlogger().warning( + "variables of tests using the deprecated " + "'@parameterized_test' decorator cannot be set externally; " + "please use the 'parameter' builtin in your tests" + ) # Merge registries candidates += legacy_candidates - - ret = [] + tests = [] for c in candidates: if not isinstance(c, RegressionTest): continue @@ -160,15 +194,15 @@ def load_from_module(self, module): conflicted = self._loaded[c.name] except KeyError: self._loaded[c.name] = testfile - ret.append(c) + tests.append(c) else: raise NameConflictError( f'test {c.name!r} from {testfile!r} ' f'is already defined in {conflicted!r}' ) - getlogger().debug(f' > Loaded {len(ret)} test(s)') - return ret + getlogger().debug(f' > Loaded {len(tests)} test(s)') + return tests def load_from_file(self, filename, force=False): if not self._validate_source(filename): @@ -195,9 +229,7 @@ def load_from_dir(self, dirname, recurse=False, force=False): checks = [] for entry in os.scandir(dirname): if recurse and entry.is_dir(): - checks.extend( - self.load_from_dir(entry.path, recurse, force) - ) + checks += self.load_from_dir(entry.path, recurse, force) if (entry.name.startswith('.') or not entry.name.endswith('.py') or diff --git a/reframe/utility/typecheck.py b/reframe/utility/typecheck.py index 20076cb64b..71274dfc2c 100644 --- a/reframe/utility/typecheck.py +++ b/reframe/utility/typecheck.py @@ -76,6 +76,10 @@ .. code-block:: none + type + | + | + | List / | / | @@ -93,16 +97,68 @@ import re -class _TypeFactory(abc.ABCMeta): - def register_subtypes(cls): - for t in cls._subtypes: - cls.register(t) +class ConvertibleType(abc.ABCMeta): + '''A type that support conversions from other types. + This is a metaclass that allows classes that use it to support arbitrary + conversions from other types using a cast-like syntax without having to + change their constructor: -# Metaclasses that implement the isinstance logic for the different aggregate -# types + .. code-block:: python -class _ContainerType(_TypeFactory): + new_obj = convertible_type(another_type) + + For example, a class whose constructor accepts and :class:`int` may need + to support a cast-from-string conversion. This is particular useful if you + want a custom-typed test + :attr:`~reframe.core.pipeline.RegressionMixin.variable` to be able to be + set from the command line using the :option:`-S` option. + + In order to support such conversions, a class must use this metaclass and + define a class method, named as :obj:`__rfm_cast___`, for each of + the type conversion that needs to support . + + The following is an example of a class :class:`X` that its normal + constructor accepts two arguments but it also allows conversions from + string: + + .. code-block:: python + + class X(metaclass=ConvertibleType): + def __init__(self, x, y): + self.data = (x, y) + + @classmethod + def __rfm_cast_str__(cls, s): + return X(*(int(x) for x in s.split(',', maxsplit=1))) + + assert X(2, 3).data == X('2,3').data + + .. versionadded:: 3.8.0 + + ''' + + def __call__(cls, *args, **kwargs): + if len(args) == 1: + cast_fn_name = f'__rfm_cast_{type(args[0]).__name__}__' + if hasattr(cls, cast_fn_name): + cast_fn = getattr(cls, cast_fn_name) + return cast_fn(args[0]) + + return super().__call__(*args, **kwargs) + + +# Metaclasses that implement the isinstance logic for the different builtin +# container types + +class _BuiltinType(ConvertibleType): + def __init__(cls, name, bases, namespace): + # Make sure that the class defines `_type` + assert hasattr(cls, '_type') + cls.register(cls._type) + + +class _SequenceType(_BuiltinType): '''A metaclass for containers with uniformly typed elements.''' def __init__(cls, name, bases, namespace): @@ -110,7 +166,6 @@ def __init__(cls, name, bases, namespace): cls._elem_type = None cls._bases = bases cls._namespace = namespace - cls.register_subtypes() def __instancecheck__(cls, inst): if not issubclass(type(inst), cls): @@ -129,15 +184,19 @@ def __getitem__(cls, elem_type): raise TypeError('invalid type specification for container type: ' 'expected ContainerType[elem_type]') - ret = _ContainerType('%s[%s]' % (cls.__name__, elem_type.__name__), - cls._bases, cls._namespace) + ret = _SequenceType('%s[%s]' % (cls.__name__, elem_type.__name__), + cls._bases, cls._namespace) ret._elem_type = elem_type - ret.register_subtypes() cls.register(ret) return ret + def __rfm_cast_str__(cls, s): + container_type = cls._type + elem_type = cls._elem_type + return container_type(elem_type(e) for e in s.split(',')) -class _TupleType(_ContainerType): + +class _TupleType(_SequenceType): '''A metaclass for tuples. Tuples may contain uniformly-typed elements or non-uniformly typed ones. @@ -174,12 +233,25 @@ def __getitem__(cls, elem_types): ) ret = _TupleType(cls_name, cls._bases, cls._namespace) ret._elem_type = elem_types - ret.register_subtypes() cls.register(ret) return ret - -class _MappingType(_TypeFactory): + def __rfm_cast_str__(cls, s): + container_type = cls._type + elem_types = cls._elem_type + elems = s.split(',') + if len(elem_types) == 1: + elem_t = elem_types[0] + return container_type(elem_t(e) for e in elems) + elif len(elem_types) != len(elems): + raise TypeError(f'cannot convert string {s!r} to {cls.__name__!r}') + else: + return container_type( + elem_t(e) for elem_t, e in zip(elem_types, elems) + ) + + +class _MappingType(_BuiltinType): '''A metaclass for type checking mapping types.''' def __init__(cls, name, bases, namespace): @@ -188,7 +260,6 @@ def __init__(cls, name, bases, namespace): cls._value_type = None cls._bases = bases cls._namespace = namespace - cls.register_subtypes() def __instancecheck__(cls, inst): if not issubclass(type(inst), cls): @@ -221,12 +292,29 @@ def __getitem__(cls, typespec): ret = _MappingType(cls_name, cls._bases, cls._namespace) ret._key_type = key_type ret._value_type = value_type - ret.register_subtypes() cls.register(ret) return ret + def __rfm_cast_str__(cls, s): + mappping_type = cls._type + key_type = cls._key_type + value_type = cls._value_type + seq = [] + for key_datum in s.split(','): + try: + k, v = key_datum.split(':') + except ValueError: + # Re-raise as TypeError + raise TypeError( + f'cannot convert string {s!r} to {cls.__name__!r}' + ) from None + + seq.append((key_type(k), value_type(v))) -class _StrType(_ContainerType): + return mappping_type(seq) + + +class _StrType(_SequenceType): '''A metaclass for type checking string types.''' def __instancecheck__(cls, inst): @@ -247,26 +335,28 @@ def __getitem__(cls, patt): ret = _StrType("%s[r'%s']" % (cls.__name__, patt), cls._bases, cls._namespace) ret._elem_type = patt - ret.register_subtypes() cls.register(ret) return ret + def __rfm_cast_str__(cls, s): + return s + class Dict(metaclass=_MappingType): - _subtypes = (dict,) + _type = dict -class List(metaclass=_ContainerType): - _subtypes = (list,) +class List(metaclass=_SequenceType): + _type = list -class Set(metaclass=_ContainerType): - _subtypes = (set,) +class Set(metaclass=_SequenceType): + _type = set class Str(metaclass=_StrType): - _subtypes = (str,) + _type = str class Tuple(metaclass=_TupleType): - _subtypes = (tuple,) + _type = tuple diff --git a/unittests/resources/checks_unlisted/externalvars.py b/unittests/resources/checks_unlisted/externalvars.py new file mode 100644 index 0000000000..dec5ec5aa9 --- /dev/null +++ b/unittests/resources/checks_unlisted/externalvars.py @@ -0,0 +1,29 @@ +import reframe as rfm +import reframe.utility.sanity as sn +import reframe.utility.typecheck as typ + + +@rfm.simple_test +class external_x(rfm.RunOnlyRegressionTest): + valid_systems = ['*'] + valid_prog_environs = ['*'] + foo = variable(int, value=1) + executable = 'echo' + + @sanity_function + def assert_foo(self): + return sn.assert_eq(self.foo, 3) + + +@rfm.simple_test +class external_y(external_x): + foolist = variable(typ.List[int]) + bar = variable(type(None), str) + + @sanity_function + def assert_foolist(self): + return sn.all([ + sn.assert_eq(self.foo, 2), + sn.assert_eq(self.foolist, [3, 4]), + sn.assert_eq(self.bar, None) + ]) diff --git a/unittests/test_cli.py b/unittests/test_cli.py index 6a1055e8b9..48abab4af3 100644 --- a/unittests/test_cli.py +++ b/unittests/test_cli.py @@ -777,3 +777,24 @@ def test_detect_host_topology_file(run_reframe, tmp_path): assert returncode == 0 with open(topo_file) as fp: assert json.load(fp) == cpuinfo() + + +def test_external_vars(run_reframe): + returncode, stdout, stderr = run_reframe( + checkpath=['unittests/resources/checks_unlisted/externalvars.py'], + more_options=['-S', 'external_x.foo=3', '-S', 'external_y.foo=2', + '-S', 'foolist=3,4', '-S', 'bar=@none'] + ) + assert 'Traceback' not in stdout + assert 'Traceback' not in stderr + assert returncode == 0 + + +def test_external_vars_invalid_expr(run_reframe): + returncode, stdout, stderr = run_reframe( + more_options=['-S', 'foo'] + ) + assert 'Traceback' not in stdout + assert 'Traceback' not in stderr + assert 'invalid test variable assignment' in stdout + assert returncode == 0 diff --git a/unittests/test_fields.py b/unittests/test_fields.py index 6422f58c8f..9e6843b8ad 100644 --- a/unittests/test_fields.py +++ b/unittests/test_fields.py @@ -71,6 +71,22 @@ def __init__(self, value): tester.field_any = 3 +def test_typed_field_convertible(): + class FieldTester: + fieldA = fields.TypedField(int, str) + fieldB = fields.TypedField(str, int) + fieldC = fields.TypedField(int) + + tester = FieldTester() + tester.fieldA = fields.make_convertible('10') + tester.fieldB = fields.make_convertible('10') + assert tester.fieldA == 10 + assert tester.fieldB == '10' + + with pytest.raises(TypeError): + tester.fieldC = fields.make_convertible(None) + + def test_timer_field(): class FieldTester: field = fields.TimerField() diff --git a/unittests/test_modules.py b/unittests/test_modules.py index 696eb00c31..2b85a27e94 100644 --- a/unittests/test_modules.py +++ b/unittests/test_modules.py @@ -243,7 +243,7 @@ def _emit_load_commands_tmod4(modules_system): 'module restore foo', f'module use {test_util.TEST_MODULES}' ] assert emit_cmds('foo/1.2') == ['module load foo/1.2'] - if modules_system.name is 'lmod': + if modules_system.name == 'lmod': assert emit_cmds('foo', path='/path') == ['module use /path', 'module load foo'] else: diff --git a/unittests/test_typecheck.py b/unittests/test_typecheck.py index efd342eda7..3014267131 100644 --- a/unittests/test_typecheck.py +++ b/unittests/test_typecheck.py @@ -9,6 +9,7 @@ def assert_type_hierarchy(builtin_type, ctype): + assert isinstance(ctype, type) assert issubclass(builtin_type, ctype) assert issubclass(ctype[int], ctype) assert issubclass(ctype[ctype[int]], ctype) @@ -35,6 +36,16 @@ def test_list_type(): with pytest.raises(TypeError): types.List[int, float] + # Test type conversions + assert types.List[int]('1,2') == [1, 2] + assert types.List[int]('1') == [1] + + with pytest.raises(ValueError): + types.List[int]('foo') + + with pytest.raises(TypeError): + types.List[int](1) + def test_set_type(): s = {1, 2} @@ -54,6 +65,15 @@ def test_set_type(): with pytest.raises(TypeError): types.Set[int, float] + assert types.Set[int]('1,2') == {1, 2} + assert types.Set[int]('1') == {1} + + with pytest.raises(ValueError): + types.Set[int]('foo') + + with pytest.raises(TypeError): + types.Set[int](1) + def test_uniform_tuple_type(): t = (1, 2) @@ -74,6 +94,15 @@ def test_uniform_tuple_type(): with pytest.raises(TypeError): types.Set[3] + assert types.Tuple[int]('1,2') == (1, 2) + assert types.Tuple[int]('1') == (1,) + + with pytest.raises(ValueError): + types.Tuple[int]('foo') + + with pytest.raises(TypeError): + types.Tuple[int](1) + def test_non_uniform_tuple_type(): t = (1, 2.3, '4', ['a', 'b']) @@ -86,6 +115,14 @@ def test_non_uniform_tuple_type(): with pytest.raises(TypeError): types.Set[int, 3] + assert types.Tuple[int, str]('1,2') == (1, '2') + + with pytest.raises(TypeError): + types.Tuple[int, str]('1') + + with pytest.raises(TypeError): + types.Tuple[int, str](1) + def test_mapping_type(): d = {'one': 1, 'two': 2} @@ -106,6 +143,12 @@ def test_mapping_type(): with pytest.raises(TypeError): types.Dict[int, 3] + # Test conversions + assert types.Dict[str, int]('a:1,b:2') == {'a': 1, 'b': 2} + + with pytest.raises(TypeError): + types.Dict[str, int]('a:1,b') + def test_str_type(): s = '123' @@ -121,6 +164,13 @@ def test_str_type(): with pytest.raises(TypeError): types.Str[int] + # Test conversion + typ = types.Str[r'\d+'] + assert typ('10') == '10' + + with pytest.raises(TypeError): + types.Str[r'\d+'](1) + def test_type_names(): assert 'List' == types.List.__name__ @@ -148,3 +198,34 @@ def __hash__(self): assert isinstance(d, types.Dict[int, C]) assert isinstance(cd, types.Dict[C, int]) assert isinstance(t, types.Tuple[int, C, str]) + + +def test_custom_types_conversion(): + class X(metaclass=types.ConvertibleType): + def __init__(self, x): + self.x = x + + @classmethod + def __rfm_cast_str__(cls, s): + return X(int(s)) + + class Y: + def __init__(self, s): + self.y = int(s) + + class Z: + def __init__(self, x, y): + self.z = x + y + + assert X('3').x == 3 + assert X(3).x == 3 + assert X(x='foo').x == 'foo' + + with pytest.raises(TypeError): + X(3, 4) + + with pytest.raises(TypeError): + X(s=3) + + assert Y('1').y == 1 + assert Z(5, 3).z == 8