Skip to content

Commit

Permalink
Review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
disconnect3d committed Jan 14, 2019
1 parent 2856c09 commit 6de2436
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 27 deletions.
46 changes: 32 additions & 14 deletions manticore/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class _Group:
Can also be used with a `with-statement context` so it would revert the value, e.g.:
group.var = 100
with group:
with group.temp_vals():
group.var = 123
# group.var is 123 for the time of with statement
# group.var is back to 100
Expand All @@ -63,10 +63,6 @@ def __init__(self, name: str):
object.__setattr__(self, '_name', name)
object.__setattr__(self, '_vars', {})

# Whether we are in a context manager (`with group:`)
object.__setattr__(self, '_entered', False)
object.__setattr__(self, '_saved', {})

@property
def name(self) -> str:
return self._name
Expand Down Expand Up @@ -131,9 +127,6 @@ def __getattr__(self, name):
raise AttributeError(f"Group '{self.name}' has no variable '{name}'")

def __setattr__(self, name, new_value):
if self._entered and name not in self._saved:
self._saved[name] = self._vars[name].value

self._vars[name].value = new_value

def __iter__(self):
Expand All @@ -142,20 +135,45 @@ def __iter__(self):
def __contains__(self, key):
return key in self._vars

def temp_vals(self) -> "_TemporaryGroup":
"""
Returns a contextmanager that can be used to set temporary config variables.
E.g.:
group.var = 123
with group.temp_vals():
group.var = 456
# var is 456
# group.var is back to 123
"""
return _TemporaryGroup(self)


class _TemporaryGroup:
def __init__(self, group: _Group):
object.__setattr__(self, '_group', group)
object.__setattr__(self, '_entered', False)
object.__setattr__(self, '_saved', {k: v.value for k, v in group._vars.items()})

def __getattr__(self, item):
return getattr(self._grp, item)

def __setattr__(self, key, value):
if self._entered and key not in self._saved:
self._saved[key] = getattr(self._group, key).value

def __enter__(self):
if self._entered is True:
raise ConfigError("Can't use `with group` recursively!")
raise ConfigError("Can't use temporary group recursively!")

object.__setattr__(self, '_entered', True)
self._saved.clear()

def __exit__(self, *_):
object.__setattr__(self, '_entered', False)

for k in self._saved:
self._vars[k].value = self._saved[k]
setattr(self._group, k, self._saved[k])

self._saved.clear()
object.__setattr__(self, '_entered', False)


def get_group(name: str) -> _Group:
Expand Down
34 changes: 21 additions & 13 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,26 +59,34 @@ def test_update(self):
self.assertEqual(o.description, 'description')
self.assertEqual(g.val2, 56)

def test_group_context_manager(self):
def test_group_temp_vals(self):
g = config.get_group('test')
g.add('val1', default=123)
g.add('val', default=123)

self.assertEqual(g.val1, 123)
self.assertEqual(g.val, 123)

with g:
self.assertEqual(g.val1, 123)
g.val1 = 456
self.assertEqual(g.val1, 456)
g.val1 = 789
self.assertEqual(g.val1, 789)
with g.temp_vals():
self.assertEqual(g.val, 123)
g.val = 456
self.assertEqual(g.val, 456)
g.val = 789
self.assertEqual(g.val, 789)

with g.temp_vals():
g.val = 123456
self.assertEqual(g.val, 123456)

self.assertEqual(g.val, 789)

self.assertEqual(g.val, 123)

t = g.temp_vals()
with t:
with self.assertRaises(config.ConfigError) as e:
with g:
with t:
pass

self.assertEqual(str(e.exception), "Can't use `with group` recursively!")

self.assertEqual(g.val1, 123)
self.assertEqual(str(e.exception), "Can't use temporary group recursively!")

def test_getattr(self):
g = config.get_group('attrs')
Expand Down

0 comments on commit 6de2436

Please sign in to comment.