-
Notifications
You must be signed in to change notification settings - Fork 0
/
functions.py
80 lines (70 loc) · 2.26 KB
/
functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import trees
import vocabulary
def binarize_tree(tree):
    """Return a version of ``tree`` in which no node has more than two children.

    Leaves and single-child nodes are returned as-is (the very same object,
    without descending into the single child). For a node with more than two
    children, the first child becomes the left branch and all remaining
    children are grouped under a new node labelled with the empty tuple
    ``()``, which is then binarized recursively -- i.e. the split point is
    always the leftmost one.
    """
    if isinstance(tree, trees.LeafParseNode):
        return tree

    kids = tree.children
    arity = len(kids)
    if arity == 1:
        # Single-child nodes are handed back untouched.
        return tree

    left = binarize_tree(kids[0])
    if arity == 2:
        rest = kids[1]
    else:
        # Group everything after the first child under an empty-label node.
        rest = trees.InternalParseNode((), kids[1:])
    right = binarize_tree(rest)

    return trees.InternalParseNode(tree.label, [left, right])
def debinarize_tree(tree):
    """Rebuild an n-ary tree by splicing out every node whose label is falsy.

    Children of a node with a falsy label (e.g. the empty tuple ``()``) are
    attached directly to the nearest ancestor that has a truthy label. If the
    root itself dissolves into something other than exactly one node, the
    resulting nodes are wrapped in a new node labelled ``('S',)``.
    """
    def flatten(node):
        # Returns the list of nodes that replace ``node`` after flattening.
        if isinstance(node, trees.LeafParseNode):
            return [node]

        spliced = [item for kid in node.children for item in flatten(kid)]
        if not node.label:
            # Falsy label: the node disappears and its children move up.
            return spliced
        return [trees.InternalParseNode(node.label, spliced)]

    result = flatten(tree)
    if len(result) != 1:
        return trees.InternalParseNode(('S',), result)
    return result[0]
def tree_to_distance(root):
    """Linearize a binarized tree into distance and label sequences.

    Returns a 4-tuple ``(d, c, t, h)``:

    * ``d`` -- list with one height per two-child node, in in-order position
      (left subtree, this node, right subtree).
    * ``c`` -- list of the labels of those two-child nodes, aligned with ``d``.
    * ``t`` -- list with one entry per leaf-level node: the label of a
      single-child internal node, or ``()`` for a bare leaf.
    * ``h`` -- height of ``root``: 0 for a leaf-level node, otherwise one more
      than the taller of its two subtrees.

    Any node that is not a two-child internal node must be either a
    single-child internal node or a leaf; this is enforced with assertions.
    """
    is_internal = isinstance(root, trees.InternalParseNode)

    if is_internal and len(root.children) == 2:
        left_d, left_c, left_t, left_h = tree_to_distance(root.children[0])
        right_d, right_c, right_t, right_h = tree_to_distance(root.children[1])
        height = max(left_h, right_h) + 1
        return (
            left_d + [height] + right_d,
            left_c + [root.label] + right_c,
            left_t + right_t,
            height,
        )

    if is_internal:
        # Unary chain: its label is emitted as the leaf-level tag.
        assert len(root.children) == 1
        return [], [], [root.label], 0

    # Bare leaf: emit an empty label as the tag.
    assert isinstance(root, trees.LeafParseNode)
    return [], [], [()], 0
def distance_to_tree(dist, cons, unary, leaves):
    """Build a binary tree over ``leaves`` from distance and label sequences.

    ``dist`` and ``cons`` hold one entry per gap between adjacent leaves;
    ``unary`` holds one entry per leaf. The span is split recursively at the
    index ``np.argmax(dist)`` selects, and the node created for that split is
    labelled ``cons[i]``. Once a single leaf remains, it is wrapped in a node
    labelled ``unary[0]`` unless that label equals ``vocabulary.PAD`` or the
    empty tuple ``()``.

    The length relationships between the four sequences are checked with
    assertions.
    """
    assert len(dist) == len(leaves) - 1
    assert len(cons) == len(dist)
    assert len(unary) == len(leaves)

    if len(dist) == 0:
        # Base case: exactly one leaf, optionally wrapped by its unary label.
        leaf = leaves[0]
        tag = unary[0]
        if tag != vocabulary.PAD and tag != ():
            return trees.InternalParseNode(tag, [leaf])
        return leaf

    split = np.argmax(dist)
    left = distance_to_tree(
        dist[:split], cons[:split], unary[:split + 1], leaves[:split + 1]
    )
    right = distance_to_tree(
        dist[split + 1:], cons[split + 1:], unary[split + 1:], leaves[split + 1:]
    )
    return trees.InternalParseNode(cons[split], [left, right])