Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ZTUtils tests pass on Py3k #137

Merged
merged 2 commits into from May 25, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
168 changes: 81 additions & 87 deletions src/ZTUtils/Tree.py
Expand Up @@ -20,11 +20,6 @@
from ComputedAttribute import ComputedAttribute
import six

try:
maketrans = str.maketrans
except AttributeError: # Py2
from string import maketrans


class TreeNode(Explicit):
__allow_access_to_unprotected_subobjects__ = 1
Expand Down Expand Up @@ -76,32 +71,15 @@ class TreeMaker:
__allow_access_to_unprotected_subobjects__ = 1

_id = 'tpId'
_assume_children = False
_expand_root = True
_values = 'tpValues'
_assume_children = 0
_values_filter = None
_values_function = None
_state_function = None
_expand_root = 1

_cached_children = None

def setChildAccess(self, attrname=_marker, filter=_marker,
function=_marker):
'''Set the criteria for fetching child nodes.

Child nodes can be accessed through either an attribute name
or callback function. Children fetched by attribute name can
be filtered through a callback function.
'''
if function is _marker:
self._values_function = None
if attrname is not _marker:
self._values = str(attrname)
if filter is not _marker:
self._values_filter = filter
else:
self._values_function = function

def setIdAttr(self, id):
"""Set the attribute or method name called to get a unique Id.

Expand Down Expand Up @@ -137,6 +115,23 @@ def setAssumeChildren(self, assume):
"""
self._assume_children = assume and True or False

def setChildAccess(self, attrname=_marker, filter=_marker,
function=_marker):
'''Set the criteria for fetching child nodes.

Child nodes can be accessed through either an attribute name
or callback function. Children fetched by attribute name can
be filtered through a callback function.
'''
if function is _marker:
self._values_function = None
if attrname is not _marker:
self._values = str(attrname)
if filter is not _marker:
self._values_filter = filter
else:
self._values_function = function

def setStateFunction(self, function):
"""Set the expansion state function.

Expand All @@ -153,41 +148,6 @@ def setStateFunction(self, function):
"""
self._state_function = function

def tree(self, root, expanded=None, subtree=0):
'''Create a tree from root, with specified nodes expanded.

"expanded" must be false, true, or a mapping.
Each key of the mapping is the id of a top-level expanded
node, and each value is the "expanded" value for the
children of that node.
'''
node = self.node(root)
child_exp = expanded
if not simple_type(expanded):
# Assume a mapping
expanded = node.id in expanded
child_exp = child_exp.get(node.id)

expanded = expanded or (not subtree and self._expand_root)
# Set state to 0 (leaf), 1 (opened), or -1 (closed)
state = self.hasChildren(root) and (expanded or -1)
if self._state_function is not None:
state = self._state_function(node.object, state)
node.state = state
if state > 0:
for child in self.getChildren(root):
node._add_child(self.tree(child, child_exp, 1))

if not subtree:
node.depth = 0
return node

def node(self, object):
node = TreeNode()
node.object = object
node.id = b2a(self.getId(object))
return node

def getId(self, object):
id_attr = self._id
if hasattr(object, id_attr):
Expand All @@ -199,13 +159,25 @@ def getId(self, object):
return str(object._p_oid)
return id(object)

def node(self, object):
node = TreeNode()
node.object = object
obid = self.getId(object)
node.id = b2a(self.getId(object))
return node

def hasChildren(self, object):
if self._assume_children:
return 1
# Cache generated children for a subsequent call to getChildren
self._cached_children = (object, self.getChildren(object))
return not not self._cached_children[1]

def filterChildren(self, children):
if self._values_filter:
return self._values_filter(children)
return children

def getChildren(self, object):
# Check and clear cache first
if self._cached_children is not None:
Expand All @@ -225,40 +197,64 @@ def getChildren(self, object):

return self.filterChildren(children)

def filterChildren(self, children):
if self._values_filter:
return self._values_filter(children)
return children
def tree(self, root, expanded=None, subtree=0):
'''Create a tree from root, with specified nodes expanded.

"expanded" must be false, true, or a mapping.
Each key of the mapping is the id of a top-level expanded
node, and each value is the "expanded" value for the
children of that node.
'''
node = self.node(root)
child_exp = expanded
if not simple_type(expanded):
# Assume a mapping
expanded = node.id in expanded
child_exp = child_exp.get(node.id)

expanded = expanded or (not subtree and self._expand_root)
# Set state to 0 (leaf), 1 (opened), or -1 (closed)
state = self.hasChildren(root) and (expanded or -1)
if self._state_function is not None:
state = self._state_function(node.object, state)
node.state = state
if state > 0:
for child in self.getChildren(root):
node._add_child(self.tree(child, child_exp, 1))

if not subtree:
node.depth = 0
return node

def simple_type(ob,
is_simple={type(''): 1, type(0): 1, type(0.0): 1,
type(None): 1}.__contains__):
return is_simple(type(ob))

_SIMPLE_TYPES = set([type(u''), type(b''), type(0), type(0.0), type(None)])

a2u_map = maketrans('+/=', '-._')
u2a_map = maketrans('-._', '+/=')
def simple_type(ob):
return type(ob) in _SIMPLE_TYPES


def b2a(s):
'''Encode a value as a cookie- and url-safe string.
'''Encode a bytes/string as a cookie- and url-safe string.

Encoded string use only alpahnumeric characters, and "._-".
'''
text = str(s).translate(a2u_map)
if six.PY3:
text = text.encode('utf-8')
return base64.b64encode(text)
if not isinstance(s, bytes):
s = str(s)
if isinstance(s, six.text_type):
s = s.encode('utf-8')
return base64.urlsafe_b64encode(s)


def a2b(s):
'''Decode a b2a-encoded string.'''
return base64.b64decode(s.translate(u2a_map))
'''Decode a b2a-encoded value to bytes.'''
if not isinstance(s, bytes):
if isinstance(s, six.text_type):
s = s.encode('ascii')
return base64.urlsafe_b64decode(s)


def encodeExpansion(nodes, compress=1):
'''Encode the expanded node ids of a tree into a string.
'''Encode the expanded node ids of a tree into bytes.

Accepts a list of nodes, such as that produced by root.flat().
Marks each expanded node with an expansion_number attribute.
Expand All @@ -267,34 +263,32 @@ def encodeExpansion(nodes, compress=1):
'''
steps = []
last_depth = -1
n = 0
for node in nodes:
for n, node in enumerate(nodes):
if node.state <= 0:
continue
dd = last_depth - node.depth + 1
last_depth = node.depth
if dd > 0:
steps.append('_' * dd)
steps.append(node.id)
steps.append(node.id) # id is bytes
node.expansion_number = n
n = n + 1
result = ':'.join(steps)
result = b':'.join(steps)
if compress and len(result) > 2:
zresult = ':' + b2a(zlib.compress(result, 9))
zresult = b':' + b2a(zlib.compress(result, 9))
if len(zresult) < len(result):
result = zresult
return result


def decodeExpansion(s, nth=None, maxsize=8192):
'''Decode an expanded node map from a string.
'''Decode an expanded node map from bytes.

If nth is an integer, also return the (map, key) pair for the nth entry.
'''
if len(s) > maxsize: # Set limit to avoid DoS attacks.
raise ValueError('Encoded node map too large')

if s[0] == ':': # Compressed state
if s.startswith(b':'): # Compressed state
dec = zlib.decompressobj()
s = dec.decompress(a2b(s[1:]), maxsize)
if dec.unconsumed_tail:
Expand All @@ -308,8 +302,8 @@ def decodeExpansion(s, nth=None, maxsize=8192):
if nth is not None:
nth_pair = (None, None)
obid = None
for step in s.split(':'):
if step.startswith('_'):
for step in s.split(b':'):
if step.startswith(b'_'):
pop = len(step) - 1
continue
if pop < 0:
Expand Down