Permalink
Browse files

update config

1 parent 4acf7ab commit c87a7e2c6ed62c968bf89198380f531bf0edc10d @ry committed Jun 12, 2016
Showing with 33 additions and 2 deletions.
  1. +33 −2 config.py
View
@@ -11,6 +11,20 @@ def __init__(self):
root[k] = v
self.stack = [ root ]
+ def iteritems(self):
+ return self.to_dict().iteritems()
+
+ def to_dict(self):
+ self._pop_stale()
+ out = {}
+ # Work backwards from the flags to top fo the stack
+ # overwriting keys that were found earlier.
+ for i in range(len(self.stack)):
+ cs = self.stack[-i]
+ for name in cs:
+ out[name] = cs[name]
+ return out
+
def _pop_stale(self):
var_scope_name = tf.get_variable_scope().name
top = self.stack[0]
@@ -20,7 +34,7 @@ def _pop_stale(self):
top = self.stack[0]
def __getitem__(self, name):
- self._pop_stale()
+ self._pop_stale()
# Recursively extract value
for i in range(len(self.stack)):
cs = self.stack[i]
@@ -29,8 +43,20 @@ def __getitem__(self, name):
raise KeyError(name)
+ def set_default(self, name, value):
+ if not name in self:
+ self[name] = value
+
+ def __contains__(self, name):
+ self._pop_stale()
+ for i in range(len(self.stack)):
+ cs = self.stack[i]
+ if name in cs:
+ return True
+ return False
+
def __setitem__(self, name, value):
- self._pop_stale()
+ self._pop_stale()
top = self.stack[0]
var_scope_name = tf.get_variable_scope().name
assert top.contains(var_scope_name)
@@ -67,16 +93,21 @@ def assert_raises(exception, fn):
assert c['hello'] == 1
with tf.variable_scope('foo'):
+ c.set_default("bar", 10)
c['bar'] = 2
assert c['bar'] == 2
assert c['hello'] == 1
+ c.set_default("mario", True)
+
with tf.variable_scope('meow'):
c['dog'] = 3
assert c['dog'] == 3
assert c['bar'] == 2
assert c['hello'] == 1
+ assert c['mario'] == True
+
assert_raises(KeyError, lambda: c['dog'])
assert c['bar'] == 2
assert c['hello'] == 1

0 comments on commit c87a7e2

Please sign in to comment.