TPR_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from node import Node


class TPR(nn.Module):
    def __init__(self, args, num_fillers, num_roles, d_filler=32, d_role=32) -> None:
        super().__init__()
        self.filler_emb = nn.Embedding(num_fillers, d_filler)
        self.role_emb = nn.Embedding(num_roles, d_role)
        # Freeze both codebooks: the TPR uses fixed, randomly initialized
        # orthogonal embeddings rather than learned ones.
        self.filler_emb.weight.requires_grad = False
        self.role_emb.weight.requires_grad = False
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.orthogonal_(self.filler_emb.weight, gain=1)
        # Filler 0 is the "empty" symbol: a zero vector, so empty tree
        # positions contribute nothing to the TPR.
        self.filler_emb.weight.data[0, :] = 0
        nn.init.orthogonal_(self.role_emb.weight, gain=1)

    def forward(self, tree_tensor):
        '''
        Given a batch of binary trees, each encoded as a level-order tensor of
        filler indices (0 = empty position), construct the batch of TPRs by
        summing filler ⊗ role over all tree positions.
        '''
        x = self.filler_emb(tree_tensor)  # (batch, num_roles, d_filler)
        return torch.einsum('brm,rn->bmn', x, self.role_emb.weight)

    def unbind(self, tpr_tensor, decode=False):
        '''
        Unbind every role from a TPR, recovering one filler vector per role.
        With decode=True, additionally score each recovered filler vector
        against the filler codebook, giving per-symbol scores.
        '''
        unbound = torch.einsum('bmn,rn->brm', tpr_tensor, self.role_emb.weight)
        if not decode:
            return unbound
        return torch.einsum('brm,fm->brf', unbound, self.filler_emb.weight)
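

# A minimal sketch of the bind/unbind round trip. Everything here is
# illustrative: the vocabulary sizes, the 7-position tree, and args=None are
# assumptions for the demo, not settings from the original training code.
def _demo_bind_unbind():
    tpr = TPR(args=None, num_fillers=5, num_roles=7)
    # One tree in level order; filler index 0 marks an empty position.
    tree = torch.tensor([[1, 2, 3, 0, 0, 4, 0]])
    bound = tpr(tree)                         # (1, d_filler, d_role) TPR
    scores = tpr.unbind(bound, decode=True)   # (1, num_roles, num_fillers)
    # DecodedTPR2Tree (defined below) thresholds away numerical noise at
    # empty positions before taking the per-role argmax.
    assert torch.equal(DecodedTPR2Tree(scores), tree)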


@torch.no_grad()
def build_E(role_emb):
    '''
    Build the E (shift-down) matrices from the role embeddings (binary trees
    only). E_l maps each role vector to the role of its left child, E_r to
    the role of its right child.
    '''
    d_role = role_emb.weight.size(1)
    E_l = role_emb.weight.new_zeros(d_role, d_role)
    E_r = role_emb.weight.new_zeros(d_role, d_role)

    def _add_to(mat, ind_from, ind_to):
        # Stop once the target position falls outside the role table.
        if ind_to >= role_emb.weight.size(0):
            return
        # Accumulate the outer product role[ind_to] ⊗ role[ind_from].
        mat += torch.einsum('a,b->ab', role_emb.weight[ind_to], role_emb.weight[ind_from])
        _add_to(mat, ind_from * 2 + 1, ind_to * 2 + 1)
        _add_to(mat, ind_from * 2 + 2, ind_to * 2 + 2)

    _add_to(E_l, 0, 1)  # root -> left child, recursively down the tree
    _add_to(E_r, 0, 2)  # root -> right child, recursively down the tree
    E_l.requires_grad = False
    E_r.requires_grad = False
    return E_l, E_r


@torch.no_grad()
def build_D(role_emb):
    '''
    Build the D (shift-up) matrices from the role embeddings (binary trees
    only). D_l maps a left-child role vector back to its parent's role, D_r
    a right-child role back to its parent's role.
    '''
    d_role = role_emb.weight.size(1)
    D_l = role_emb.weight.new_zeros(d_role, d_role)
    D_r = role_emb.weight.new_zeros(d_role, d_role)

    def _add_to(mat, ind_from, ind_to):
        # Stop once the source position falls outside the role table.
        if ind_from >= role_emb.weight.size(0):
            return
        # Accumulate the outer product role[ind_to] ⊗ role[ind_from].
        mat += torch.einsum('a,b->ab', role_emb.weight[ind_to], role_emb.weight[ind_from])
        _add_to(mat, ind_from * 2 + 1, ind_to * 2 + 1)
        _add_to(mat, ind_from * 2 + 2, ind_to * 2 + 2)

    _add_to(D_l, 1, 0)  # left child -> parent, recursively down the tree
    _add_to(D_r, 2, 0)  # right child -> parent, recursively down the tree
    D_l.requires_grad = False
    D_r.requires_grad = False
    return D_l, D_r
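

# Sketch of what the shift matrices do, assuming orthonormal role vectors
# (num_roles <= d_role, so nn.init.orthogonal_ yields orthonormal rows):
# E_l sends the role of node i to the role of its left child 2i+1, and D_l
# sends it back, so D_l @ E_l acts as the identity on every role whose left
# child is inside the table. The sizes below are illustrative.
def _demo_shift_matrices():
    role_emb = nn.Embedding(7, 32)
    nn.init.orthogonal_(role_emb.weight)
    E_l, E_r = build_E(role_emb)
    D_l, D_r = build_D(role_emb)
    r_root, r_left = role_emb.weight[0], role_emb.weight[1]
    assert torch.allclose(E_l @ r_root, r_left, atol=1e-5)  # parent -> left child
    assert torch.allclose(D_l @ r_left, r_root, atol=1e-5)  # left child -> parent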


def DecodedTPR2Tree(decoded_tpr, eps=1e-2):
    '''
    Snap decoded filler scores back to symbol indices: positions whose score
    vector has L2 norm below eps are treated as empty (index 0); all other
    positions take their argmax symbol.
    '''
    contain_symbols = decoded_tpr.norm(p=2, dim=-1) > eps
    return torch.where(contain_symbols, decoded_tpr.argmax(dim=-1), 0)


# Works for binary trees only.
def Symbols2NodeTree(index_tree, i2v):
    '''
    Convert a level-order tensor of symbol indices back into a Node tree,
    mapping indices to vocabulary items via i2v.
    '''
    def _traverse_and_detensorify(par, ind):
        if not index_tree[ind]:  # index 0 marks an empty position
            return par
        cur = Node(i2v[index_tree[ind]])
        if par:
            par.children.append(cur)
        if len(index_tree) > ind * 2 + 1 and index_tree[ind * 2 + 1]:
            # Recurse into the left child.
            _traverse_and_detensorify(cur, ind * 2 + 1)
        if len(index_tree) > ind * 2 + 2 and index_tree[ind * 2 + 2]:
            # Recurse into the right child.
            _traverse_and_detensorify(cur, ind * 2 + 2)
        return cur

    return _traverse_and_detensorify(None, 0)


# Example usage in main.py: BatchSymbols2NodeTree(fully_decoded, train_data.ind2vocab)
def BatchSymbols2NodeTree(decoded_tpr_batch, i2v):
    def s2nt(index_tree):
        return Symbols2NodeTree(index_tree, i2v)
    return list(map(s2nt, decoded_tpr_batch))
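

# End-to-end usage sketch mirroring the main.py call noted above, but with a
# made-up vocabulary and tree; both are assumptions for illustration only.
if __name__ == '__main__':
    i2v = ['<empty>', 'S', 'NP', 'VP']
    tpr = TPR(args=None, num_fillers=len(i2v), num_roles=7)
    trees = torch.tensor([[1, 2, 3, 0, 0, 0, 0]])  # S with children NP and VP
    decoded = tpr.unbind(tpr(trees), decode=True)
    node_trees = BatchSymbols2NodeTree(DecodedTPR2Tree(decoded), i2v)
    # node_trees[0] is a Node (from the local node module) for 'S' with two
    # children; its printed form depends on Node's repr.
    print(node_trees)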