Skip to content

Commit

Permalink
config.Group: with statement usage
Browse files Browse the repository at this point in the history
This changes introduces possibility to make temporary changes to config groups.
  • Loading branch information
disconnect3d committed Jan 14, 2019
1 parent 7fdc684 commit 2856c09
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
29 changes: 29 additions & 0 deletions manticore/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ class _Group:
And then use their value in the code as:
group.some_var
Can also be used with a `with-statement context` so it would revert the value, e.g.:
group.var = 100
with group:
group.var = 123
# group.var is 123 for the time of with statement
# group.var is back to 100
Note that it is not recommended to use it as a default argument value for a function as it will be evaluated once.
Also don't forget that a given variable can be set through CLI or .yaml file!
(see config.py)
Expand All @@ -56,6 +63,10 @@ def __init__(self, name: str):
object.__setattr__(self, '_name', name)
object.__setattr__(self, '_vars', {})

# Whether we are in a context manager (`with group:`)
object.__setattr__(self, '_entered', False)
object.__setattr__(self, '_saved', {})

@property
def name(self) -> str:
return self._name
Expand Down Expand Up @@ -120,6 +131,9 @@ def __getattr__(self, name):
raise AttributeError(f"Group '{self.name}' has no variable '{name}'")

def __setattr__(self, name, new_value):
if self._entered and name not in self._saved:
self._saved[name] = self._vars[name].value

self._vars[name].value = new_value

def __iter__(self):
Expand All @@ -128,6 +142,21 @@ def __iter__(self):
def __contains__(self, key):
return key in self._vars

def __enter__(self):
if self._entered is True:
raise ConfigError("Can't use `with group` recursively!")

object.__setattr__(self, '_entered', True)
self._saved.clear()

def __exit__(self, *_):
object.__setattr__(self, '_entered', False)

for k in self._saved:
self._vars[k].value = self._saved[k]

self._saved.clear()


def get_group(name: str) -> _Group:
"""
Expand Down
21 changes: 21 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ def test_update(self):
self.assertEqual(o.description, 'description')
self.assertEqual(g.val2, 56)

def test_group_context_manager(self):
g = config.get_group('test')
g.add('val1', default=123)

self.assertEqual(g.val1, 123)

with g:
self.assertEqual(g.val1, 123)
g.val1 = 456
self.assertEqual(g.val1, 456)
g.val1 = 789
self.assertEqual(g.val1, 789)

with self.assertRaises(config.ConfigError) as e:
with g:
pass

self.assertEqual(str(e.exception), "Can't use `with group` recursively!")

self.assertEqual(g.val1, 123)

def test_getattr(self):
g = config.get_group('attrs')
with self.assertRaises(AttributeError):
Expand Down

0 comments on commit 2856c09

Please sign in to comment.