diff --git a/manticore/utils/config.py b/manticore/utils/config.py index 9356df2b7..365272478 100644 --- a/manticore/utils/config.py +++ b/manticore/utils/config.py @@ -47,6 +47,13 @@ class _Group: And then use their value in the code as: group.some_var + Can also be used with a `with-statement context` so it would revert the value, e.g.: + group.var = 100 + with group: + group.var = 123 + # group.var is 123 for the time of with statement + # group.var is back to 100 + Note that it is not recommended to use it as a default argument value for a function as it will be evaluated once. Also don't forget that a given variable can be set through CLI or .yaml file! (see config.py) @@ -56,6 +63,10 @@ def __init__(self, name: str): object.__setattr__(self, '_name', name) object.__setattr__(self, '_vars', {}) + # Whether we are in a context manager (`with group:`) + object.__setattr__(self, '_entered', False) + object.__setattr__(self, '_saved', {}) + @property def name(self) -> str: return self._name @@ -120,6 +131,9 @@ def __getattr__(self, name): raise AttributeError(f"Group '{self.name}' has no variable '{name}'") def __setattr__(self, name, new_value): + if self._entered and name not in self._saved: + self._saved[name] = self._vars[name].value + self._vars[name].value = new_value def __iter__(self): @@ -128,6 +142,21 @@ def __iter__(self): def __contains__(self, key): return key in self._vars + def __enter__(self): + if self._entered is True: + raise ConfigError("Can't use `with group` recursively!") + + object.__setattr__(self, '_entered', True) + self._saved.clear() + + def __exit__(self, *_): + object.__setattr__(self, '_entered', False) + + for k in self._saved: + self._vars[k].value = self._saved[k] + + self._saved.clear() + def get_group(name: str) -> _Group: """ diff --git a/tests/test_config.py b/tests/test_config.py index ac05eea9c..b6fbf214b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -59,6 +59,27 @@ def test_update(self): self.assertEqual(o.description, 'description') self.assertEqual(g.val2, 56) + def test_group_context_manager(self): + g = config.get_group('test') + g.add('val1', default=123) + + self.assertEqual(g.val1, 123) + + with g: + self.assertEqual(g.val1, 123) + g.val1 = 456 + self.assertEqual(g.val1, 456) + g.val1 = 789 + self.assertEqual(g.val1, 789) + + with self.assertRaises(config.ConfigError) as e: + with g: + pass + + self.assertEqual(str(e.exception), "Can't use `with group` recursively!") + + self.assertEqual(g.val1, 123) + def test_getattr(self): g = config.get_group('attrs') with self.assertRaises(AttributeError):