From 6de243690367e513648390e8996cee71fc198669 Mon Sep 17 00:00:00 2001 From: disconnect3d Date: Mon, 14 Jan 2019 21:07:57 +0100 Subject: [PATCH] Review changes --- manticore/utils/config.py | 46 +++++++++++++++++++++++++++------------ tests/test_config.py | 34 ++++++++++++++++++----------- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/manticore/utils/config.py b/manticore/utils/config.py index 365272478..b3ef1054f 100644 --- a/manticore/utils/config.py +++ b/manticore/utils/config.py @@ -49,7 +49,7 @@ class _Group: Can also be used with a `with-statement context` so it would revert the value, e.g.: group.var = 100 - with group: + with group.temp_vals(): group.var = 123 # group.var is 123 for the time of with statement # group.var is back to 100 @@ -63,10 +63,6 @@ def __init__(self, name: str): object.__setattr__(self, '_name', name) object.__setattr__(self, '_vars', {}) - # Whether we are in a context manager (`with group:`) - object.__setattr__(self, '_entered', False) - object.__setattr__(self, '_saved', {}) - @property def name(self) -> str: return self._name @@ -131,9 +127,6 @@ def __getattr__(self, name): raise AttributeError(f"Group '{self.name}' has no variable '{name}'") def __setattr__(self, name, new_value): - if self._entered and name not in self._saved: - self._saved[name] = self._vars[name].value - self._vars[name].value = new_value def __iter__(self): @@ -142,20 +135,45 @@ def __iter__(self): def __contains__(self, key): return key in self._vars + def temp_vals(self) -> "_TemporaryGroup": + """ + Returns a contextmanager that can be used to set temporary config variables. + E.g.: + group.var = 123 + + with group.temp_vals(): + group.var = 456 + # var is 456 + + # group.var is back to 123 + """ + return _TemporaryGroup(self) + + +class _TemporaryGroup: + def __init__(self, group: _Group): + object.__setattr__(self, '_group', group) + object.__setattr__(self, '_entered', False) + object.__setattr__(self, '_saved', {k: v.value for k, v in group._vars.items()}) + + def __getattr__(self, item): + return getattr(self._grp, item) + + def __setattr__(self, key, value): + if self._entered and key not in self._saved: + self._saved[key] = getattr(self._group, key).value + def __enter__(self): if self._entered is True: - raise ConfigError("Can't use `with group` recursively!") + raise ConfigError("Can't use temporary group recursively!") object.__setattr__(self, '_entered', True) - self._saved.clear() def __exit__(self, *_): - object.__setattr__(self, '_entered', False) - for k in self._saved: - self._vars[k].value = self._saved[k] + setattr(self._group, k, self._saved[k]) - self._saved.clear() + object.__setattr__(self, '_entered', False) def get_group(name: str) -> _Group: diff --git a/tests/test_config.py b/tests/test_config.py index b6fbf214b..9c8988338 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -59,26 +59,34 @@ def test_update(self): self.assertEqual(o.description, 'description') self.assertEqual(g.val2, 56) - def test_group_context_manager(self): + def test_group_temp_vals(self): g = config.get_group('test') - g.add('val1', default=123) + g.add('val', default=123) - self.assertEqual(g.val1, 123) + self.assertEqual(g.val, 123) - with g: - self.assertEqual(g.val1, 123) - g.val1 = 456 - self.assertEqual(g.val1, 456) - g.val1 = 789 - self.assertEqual(g.val1, 789) + with g.temp_vals(): + self.assertEqual(g.val, 123) + g.val = 456 + self.assertEqual(g.val, 456) + g.val = 789 + self.assertEqual(g.val, 789) + with g.temp_vals(): + g.val = 123456 + self.assertEqual(g.val, 123456) + + self.assertEqual(g.val, 789) + + self.assertEqual(g.val, 123) + + t = g.temp_vals() + with t: with self.assertRaises(config.ConfigError) as e: - with g: + with t: pass - self.assertEqual(str(e.exception), "Can't use `with group` recursively!") - - self.assertEqual(g.val1, 123) + self.assertEqual(str(e.exception), "Can't use temporary group recursively!") def test_getattr(self): g = config.get_group('attrs')