-
Notifications
You must be signed in to change notification settings - Fork 40
/
model.py
367 lines (309 loc) · 12.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable as Var
import Constants
import utils
class BinaryTreeLeafModule(nn.Module):
    """Leaf cell of the binary (constituency) Tree-LSTM.

    Port of the Lua/Torch leaf module with output gating always applied:

        c = W_c x + b_c
        h = sigmoid(W_o x + b_o) * tanh(c)

    :param cuda: if True, both linear layers are moved to the GPU on construction
    :param in_dim: size of the leaf input vector
    :param mem_dim: size of the memory cell / hidden state
    """

    def __init__(self, cuda, in_dim, mem_dim):
        super(BinaryTreeLeafModule, self).__init__()
        self.cudaFlag = cuda
        self.in_dim, self.mem_dim = in_dim, mem_dim
        # The attribute names (cx, ox) become state_dict keys; keep them stable.
        self.cx = nn.Linear(in_dim, mem_dim)
        self.ox = nn.Linear(in_dim, mem_dim)
        if cuda:
            self.cx, self.ox = self.cx.cuda(), self.ox.cuda()

    def forward(self, input):
        """Return the ``(c, h)`` state of a leaf whose input vector is ``input``."""
        cell = self.cx(input)
        output_gate = torch.sigmoid(self.ox(input))
        return cell, output_gate * torch.tanh(cell)
class BinaryTreeComposer(nn.Module):
    """Internal-node cell of the binary (constituency) Tree-LSTM.

    Combines the ``(c, h)`` states of a left and a right child:

        i  = sigmoid(U_i^l h_l  + U_i^r h_r)     input gate
        lf = sigmoid(U_lf^l h_l + U_lf^r h_r)    left forget gate
        rf = sigmoid(U_rf^l h_l + U_rf^r h_r)    right forget gate
        u  = tanh(U_u^l h_l + U_u^r h_r)         candidate update
        c  = i * u + lf * c_l + rf * c_r
        h  = tanh(c)

    Every U is an ``nn.Linear(mem_dim, mem_dim)`` with bias. There is no
    output gate; this corresponds to the Lua original's ungated-output branch.

    :param cuda: if True, all gate layers are moved to the GPU on construction
    :param in_dim: stored but not used by this module
    :param mem_dim: size of the memory cell / hidden state
    """

    def __init__(self, cuda, in_dim, mem_dim):
        super(BinaryTreeComposer, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim

        def new_gate():
            # One projection of the left child's h and one of the right's.
            lh = nn.Linear(self.mem_dim, self.mem_dim)
            rh = nn.Linear(self.mem_dim, self.mem_dim)
            return lh, rh

        self.ilh, self.irh = new_gate()
        self.lflh, self.lfrh = new_gate()
        self.rflh, self.rfrh = new_gate()
        self.ulh, self.urh = new_gate()
        if self.cudaFlag:
            # Move every registered gate layer at once so none (e.g. urh) can be
            # left on the CPU and cause a device mismatch in forward().
            self.cuda()

    def forward(self, lc, lh, rc, rh):
        """Return the ``(c, h)`` state of a node from its children's states.

        :param lc, lh: memory cell and hidden state of the left child
        :param rc, rh: memory cell and hidden state of the right child
        """
        i = torch.sigmoid(self.ilh(lh) + self.irh(rh))
        lf = torch.sigmoid(self.lflh(lh) + self.lfrh(rh))
        rf = torch.sigmoid(self.rflh(lh) + self.rfrh(rh))
        update = torch.tanh(self.ulh(lh) + self.urh(rh))
        c = i * update + lf * lc + rf * rc
        h = torch.tanh(c)
        return c, h
class BinaryTreeLSTM(nn.Module):
    """Binary (constituency) Tree-LSTM.

    Leaves are encoded by a ``BinaryTreeLeafModule`` and internal nodes by a
    ``BinaryTreeComposer`` applied to their first two children. If an output
    module has been attached with ``set_output_module``, it is run on every
    node's hidden state, and in training mode the criterion loss against each
    labelled node's gold label is accumulated.
    """

    def __init__(self, cuda, in_dim, mem_dim, criterion):
        super(BinaryTreeLSTM, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        self.criterion = criterion
        self.leaf_module = BinaryTreeLeafModule(cuda, in_dim, mem_dim)
        self.composer = BinaryTreeComposer(cuda, in_dim, mem_dim)
        self.output_module = None

    def set_output_module(self, output_module):
        """Attach the module that forward() applies to each node's hidden state."""
        self.output_module = output_module

    def getParameters(self):
        """Return the tree-cell parameters flattened into one 1-D tensor.

        Only the leaf module and the composer are included; the output module's
        parameters are excluded, so this is not equivalent to ``parameters()``.

        :return: 1-D tensor
        """
        params = []
        for m in (self.leaf_module, self.composer):
            params.extend(m.parameters())
        return torch.cat([p.view(p.numel()) for p in params])

    def forward(self, tree, embs, training=False):
        """Recursively encode ``tree`` bottom-up.

        Sets ``tree.state`` on every node, and ``tree.output`` as well when an
        output module is attached.

        :param tree: tree node exposing ``num_children``, ``children``, ``idx``
            (1-based position into ``embs``, read for leaves) and, when training,
            ``gold_label``
        :param embs: per-token inputs, indexed by ``tree.idx - 1``
        :param training: forwarded to the output module; when True, the
            criterion loss is added for nodes whose ``gold_label`` is not None
        :return: ``(state, loss)`` where ``state`` is this node's ``(c, h)`` and
            ``loss`` is the summed loss of this subtree
        """
        loss = Var(torch.zeros(1))  # running loss for this subtree
        if self.cudaFlag:
            loss = loss.cuda()
        if tree.num_children == 0:
            tree.state = self.leaf_module(embs[tree.idx - 1])
        else:
            for idx in range(tree.num_children):
                _, child_loss = self.forward(tree.children[idx], embs, training)
                loss = loss + child_loss
            lc, lh, rc, rh = self.get_child_state(tree)
            tree.state = self.composer(lc, lh, rc, rh)

        if self.output_module is not None:
            output = self.output_module(tree.state[1], training)
            tree.output = output
            if training and tree.gold_label is not None:
                target = Var(utils.map_label_to_target_sentiment(tree.gold_label))
                if self.cudaFlag:
                    target = target.cuda()
                loss = loss + self.criterion(output, target)
        return tree.state, loss

    def get_child_state(self, tree):
        """Return ``(lc, lh, rc, rh)`` from the states of children 0 and 1."""
        lc, lh = tree.children[0].state
        rc, rh = tree.children[1].state
        return lc, lh, rc, rh
###################################################################
# module for childsumtreelstm
class ChildSumTreeLSTM(nn.Module):
    """Child-Sum Tree-LSTM (selected for the 'dependency' model).

    Each node combines its own input vector x with the states of any number of
    children k:

        h~  = sum_k h_k
        i   = sigmoid(W_i x + U_i h~)
        o   = sigmoid(W_o x + U_o h~)
        u   = tanh(W_u x + U_u h~)
        f_k = sigmoid(W_f x + U_f h_k)     one forget gate per child
        c   = i * u + sum_k f_k * c_k
        h   = o * tanh(c)

    A leaf is treated as having a single all-zero child.
    """

    def __init__(self, cuda, in_dim, mem_dim, criterion):
        super(ChildSumTreeLSTM, self).__init__()
        self.cudaFlag = cuda
        self.in_dim = in_dim
        self.mem_dim = mem_dim
        # x-projections (*x) and h-projections (*h) for each gate; creation
        # order is kept so seeded initialisation stays reproducible.
        self.ix = nn.Linear(self.in_dim, self.mem_dim)
        self.ih = nn.Linear(self.mem_dim, self.mem_dim)
        self.fh = nn.Linear(self.mem_dim, self.mem_dim)
        self.fx = nn.Linear(self.in_dim, self.mem_dim)
        self.ux = nn.Linear(self.in_dim, self.mem_dim)
        self.uh = nn.Linear(self.mem_dim, self.mem_dim)
        self.ox = nn.Linear(self.in_dim, self.mem_dim)
        self.oh = nn.Linear(self.mem_dim, self.mem_dim)
        if self.cudaFlag:
            # The loss and the zero leaf states are created on the GPU in this
            # mode, so the gate layers must live there too (as in the sibling
            # tree modules, which move their layers at construction time).
            self.cuda()
        self.criterion = criterion
        self.output_module = None

    def set_output_module(self, output_module):
        """Attach the module that forward() applies to each node's hidden state."""
        self.output_module = output_module

    def getParameters(self):
        """Return the tree-cell parameters flattened into one 1-D tensor.

        The output module's parameters are excluded, so this is not equivalent
        to ``parameters()``.

        :return: 1-D tensor
        """
        params = []
        for m in [self.ix, self.ih, self.fx, self.fh, self.ox, self.oh, self.ux, self.uh]:
            params.extend(m.parameters())
        return torch.cat([p.view(p.numel()) for p in params])

    def node_forward(self, inputs, child_c, child_h):
        """Compute one node's state from its input and its children's states.

        :param inputs: (1, in_dim)
        :param child_c: (num_children, 1, mem_dim)
        :param child_h: (num_children, 1, mem_dim)
        :return: ``(c, h)``, each (1, mem_dim)
        """
        child_h_sum = torch.sum(torch.squeeze(child_h, 1), 0)
        i = torch.sigmoid(self.ix(inputs) + self.ih(child_h_sum))
        o = torch.sigmoid(self.ox(inputs) + self.oh(child_h_sum))
        u = torch.tanh(self.ux(inputs) + self.uh(child_h_sum))
        # Per-child forget gates: the shared x term gets an extra singleton
        # dimension so each child's row concatenates to (num_children, 1, mem_dim).
        fx = torch.unsqueeze(self.fx(inputs), 1)
        f = torch.sigmoid(torch.cat([self.fh(child_hi) + fx for child_hi in child_h], 0))
        fc = torch.squeeze(torch.mul(f, child_c), 1)
        c = torch.mul(i, u) + torch.sum(fc, 0)
        h = torch.mul(o, torch.tanh(c))
        return c, h

    def forward(self, tree, embs, training=False):
        """Recursively encode ``tree`` bottom-up.

        Sets ``tree.state`` on every node, and ``tree.output`` as well when an
        output module is attached.

        :param tree: tree node exposing ``num_children``, ``children``, ``idx``
            (1-based position into ``embs``) and, when training, ``gold_label``
        :param embs: (sentence_length, 1, in_dim)
        :param training: forwarded to the output module; when True, the
            criterion loss is added for nodes whose ``gold_label`` is not None
        :return: ``(state, loss)`` where ``state`` is this node's ``(c, h)`` and
            ``loss`` is the summed loss of this subtree
        """
        loss = Var(torch.zeros(1))  # running loss for this subtree
        if self.cudaFlag:
            loss = loss.cuda()
        for idx in range(tree.num_children):
            _, child_loss = self.forward(tree.children[idx], embs, training)
            loss = loss + child_loss
        child_c, child_h = self.get_child_states(tree)
        tree.state = self.node_forward(embs[tree.idx - 1], child_c, child_h)

        if self.output_module is not None:
            output = self.output_module(tree.state[1], training)
            tree.output = output
            if training and tree.gold_label is not None:
                target = Var(utils.map_label_to_target_sentiment(tree.gold_label))
                if self.cudaFlag:
                    target = target.cuda()
                loss = loss + self.criterion(output, target)
        return tree.state, loss

    def get_child_states(self, tree):
        """Collect the ``(c, h)`` states of ``tree``'s children.

        :return: ``(child_c, child_h)``, each (num_children, 1, mem_dim); a leaf
            gets a single all-zero child so node_forward treats every node alike
        """
        if tree.num_children == 0:
            child_c = Var(torch.zeros(1, 1, self.mem_dim))
            child_h = Var(torch.zeros(1, 1, self.mem_dim))
            if self.cudaFlag:
                child_c, child_h = child_c.cuda(), child_h.cuda()
            return child_c, child_h
        children = [tree.children[idx] for idx in range(tree.num_children)]
        # Each child contributes one (1, mem_dim) row; stacking replaces an
        # uninitialised buffer filled by in-place writes.
        child_c = torch.stack([ch.state[0].view(1, self.mem_dim) for ch in children], 0)
        child_h = torch.stack([ch.state[1].view(1, self.mem_dim) for ch in children], 0)
        return child_c, child_h
##############################################################################
# output module
class SentimentModule(nn.Module):
    """Output head: projects a hidden state to class log-probabilities.

    :param cuda: if True, the projection layer is moved to the GPU on construction
    :param mem_dim: size of the incoming hidden state
    :param num_classes: number of output classes
    :param dropout: if True, ``F.dropout`` (default rate) is applied to the
        hidden state before the projection, active only when ``training`` is True
    """

    def __init__(self, cuda, mem_dim, num_classes, dropout=False):
        super(SentimentModule, self).__init__()
        self.cudaFlag = cuda
        self.mem_dim = mem_dim
        self.num_classes = num_classes
        self.dropout = dropout
        self.l1 = nn.Linear(self.mem_dim, self.num_classes)
        # Normalise explicitly over the last (class) axis. Leaving dim implicit
        # is deprecated in PyTorch, warns on every call, and does not pick the
        # class axis for 3-D input.
        self.logsoftmax = nn.LogSoftmax(dim=-1)
        if self.cudaFlag:
            self.l1 = self.l1.cuda()

    def forward(self, vec, training=False):
        """Return log-probabilities for hidden state ``vec``.

        :param vec: (1, mem_dim)
        :param training: enables dropout when the module was built with it
        :return: (1, num_classes)
        """
        if self.dropout:
            vec = F.dropout(vec, training=training)
        return self.logsoftmax(self.l1(vec))
class TreeLSTMSentiment(nn.Module):
    """Tree-LSTM sentiment classifier.

    Builds the tree encoder selected by ``model_name`` ('dependency' ->
    ``ChildSumTreeLSTM``, 'constituency' -> ``BinaryTreeLSTM``) and attaches a
    ``SentimentModule`` with dropout, which the encoder applies at every node.

    :param cuda: passed to the encoder and output module
    :param vocab_size: not used by this module
    :param in_dim: size of each token's input vector
    :param mem_dim: size of the memory cell / hidden state
    :param num_classes: number of sentiment classes
    :param model_name: 'dependency' or 'constituency'
    :param criterion: loss passed to the tree encoder
    :raises ValueError: if ``model_name`` is not one of the supported names
    """

    def __init__(self, cuda, vocab_size, in_dim, mem_dim, num_classes, model_name, criterion):
        super(TreeLSTMSentiment, self).__init__()
        self.cudaFlag = cuda
        self.model_name = model_name
        if self.model_name == 'dependency':
            self.tree_module = ChildSumTreeLSTM(cuda, in_dim, mem_dim, criterion)
        elif self.model_name == 'constituency':
            self.tree_module = BinaryTreeLSTM(cuda, in_dim, mem_dim, criterion)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing tree_module below.
            raise ValueError(
                "unknown model_name %r; expected 'dependency' or 'constituency'"
                % (model_name,))
        self.output_module = SentimentModule(cuda, mem_dim, num_classes, dropout=True)
        self.tree_module.set_output_module(self.output_module)

    def forward(self, tree, inputs, training=False):
        """Encode ``tree`` and return the root's prediction with the total loss.

        :param tree: root node of the parse tree
        :param inputs: (sentence_length, 1, in_dim)
        :param training: enables dropout and loss accumulation in the encoder
        :return: ``(output, loss)`` where ``output`` is the root's
            log-probabilities and ``loss`` is summed over all labelled nodes
        """
        _, loss = self.tree_module(tree, inputs, training)
        output = tree.output
        return output, loss