-
Notifications
You must be signed in to change notification settings - Fork 156
/
Copy pathmodel.py
168 lines (143 loc) · 7.21 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch.nn.init import xavier_uniform, xavier_normal, orthogonal
class SubNet(nn.Module):
    '''
    Feed-forward encoder used by TFN for the audio and video modalities in the
    pre-fusion stage: batch norm -> dropout -> three ReLU-activated linear layers.
    '''

    def __init__(self, in_size, hidden_size, dropout):
        '''
        Args:
            in_size: size of the input feature dimension (normalised by BatchNorm1d)
            hidden_size: width of each of the three linear layers, and of the output
            dropout: probability for the single nn.Dropout applied after the batch norm
        Output:
            (return value in forward) a tensor of shape (batch_size, hidden_size)
        '''
        super(SubNet, self).__init__()
        # Submodules keep their names and registration order so that state_dict
        # keys and the parameters() ordering stay the same for saved checkpoints.
        self.norm = nn.BatchNorm1d(in_size)
        self.drop = nn.Dropout(p=dropout)
        self.linear_1 = nn.Linear(in_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, in_size)
        Returns:
            tensor of shape (batch_size, hidden_size); non-negative because every
            linear layer is followed by a ReLU
        '''
        features = self.drop(self.norm(x))
        for layer in (self.linear_1, self.linear_2, self.linear_3):
            features = F.relu(layer(features))
        return features
class TextSubNet(nn.Module):
    '''
    The LSTM-based subnetwork that is used in TFN for text
    '''

    def __init__(self, in_size, hidden_size, out_size, num_layers=1, dropout=0.2, bidirectional=False):
        '''
        Args:
            in_size: input dimension
            hidden_size: hidden dimension of the LSTM (per direction)
            out_size: output dimension of the final linear projection
            num_layers: specify the number of layers of LSTMs.
            dropout: dropout probability. It is applied to the final hidden state
                before the linear projection. It is also used between stacked LSTM
                layers when num_layers > 1.
            bidirectional: specify usage of bidirectional LSTM
        Output:
            (return value in forward) a tensor of shape (batch_size, out_size)
        '''
        super(TextSubNet, self).__init__()
        # nn.LSTM applies dropout only between stacked layers. With a single layer a
        # non-zero value has no effect and only emits a UserWarning.
        rnn_dropout = dropout if num_layers > 1 else 0.0
        self.rnn = nn.LSTM(in_size, hidden_size, num_layers=num_layers, dropout=rnn_dropout,
                           bidirectional=bidirectional, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        # When bidirectional, forward() concatenates the final states of both
        # directions, so the projection's input is twice as wide.
        num_directions = 2 if bidirectional else 1
        self.linear_1 = nn.Linear(hidden_size * num_directions, out_size)

    def forward(self, x):
        '''
        Args:
            x: tensor of shape (batch_size, sequence_len, in_size)
        Returns:
            tensor of shape (batch_size, out_size)
        '''
        _, (h_n, _) = self.rnn(x)
        # h_n is (num_layers * num_directions, batch, hidden_size), even with
        # batch_first=True. Keep only the last layer's final state(s). Indexing,
        # unlike squeeze(), preserves the batch dimension when batch_size == 1 and
        # also works for stacked and bidirectional LSTMs.
        num_directions = 2 if self.rnn.bidirectional else 1
        h = torch.cat(h_n[-num_directions:].unbind(0), dim=-1)
        h = self.dropout(h)
        y_1 = self.linear_1(h)
        return y_1
class TFN(nn.Module):
    '''
    Implements the Tensor Fusion Networks for multimodal sentiment analysis as is described in:
    Zadeh, Amir, et al. "Tensor fusion network for multimodal sentiment analysis." EMNLP 2017 Oral.
    '''

    def __init__(self, input_dims, hidden_dims, text_out, dropouts, post_fusion_dim):
        '''
        Args:
            input_dims - a length-3 tuple, contains (audio_dim, video_dim, text_dim)
            hidden_dims - another length-3 tuple, similar to input_dims; the third entry is
                the hidden size of the text LSTM
            text_out - int, specifying the resulting dimensions of the text subnetwork
            dropouts - a length-4 tuple, contains (audio_dropout, video_dropout, text_dropout, post_fusion_dropout)
            post_fusion_dim - int, specifying the size of the sub-networks after tensor fusion
        Output:
            (return value in forward) a tensor of shape (batch_size, 1) with values in (-3, 3)
        '''
        super(TFN, self).__init__()
        # dimensions are specified in the order of audio, video and text
        self.audio_in = input_dims[0]
        self.video_in = input_dims[1]
        self.text_in = input_dims[2]

        self.audio_hidden = hidden_dims[0]
        self.video_hidden = hidden_dims[1]
        self.text_hidden = hidden_dims[2]
        self.text_out = text_out
        self.post_fusion_dim = post_fusion_dim

        self.audio_prob = dropouts[0]
        self.video_prob = dropouts[1]
        self.text_prob = dropouts[2]
        self.post_fusion_prob = dropouts[3]

        # define the pre-fusion subnetworks
        self.audio_subnet = SubNet(self.audio_in, self.audio_hidden, self.audio_prob)
        self.video_subnet = SubNet(self.video_in, self.video_hidden, self.video_prob)
        self.text_subnet = TextSubNet(self.text_in, self.text_hidden, self.text_out, dropout=self.text_prob)

        # define the post_fusion layers; the fused vector has one entry per
        # (audio, video, text) index triple, each modality padded with a leading 1
        self.post_fusion_dropout = nn.Dropout(p=self.post_fusion_prob)
        self.post_fusion_layer_1 = nn.Linear((self.text_out + 1) * (self.video_hidden + 1) * (self.audio_hidden + 1), self.post_fusion_dim)
        self.post_fusion_layer_2 = nn.Linear(self.post_fusion_dim, self.post_fusion_dim)
        self.post_fusion_layer_3 = nn.Linear(self.post_fusion_dim, 1)

        # in TFN we are doing a regression with constrained output range: (-3, 3), hence we'll apply sigmoid to output
        # shrink it to (0, 1), and scale/shift it back to range (-3, 3).
        # These stay non-trainable Parameters (not buffers), so existing state_dicts and
        # parameters() listings are unchanged; like any Parameter they follow
        # .to()/.cuda()/.double() calls on the module.
        self.output_range = Parameter(torch.FloatTensor([6]), requires_grad=False)
        self.output_shift = Parameter(torch.FloatTensor([-3]), requires_grad=False)

    def forward(self, audio_x, video_x, text_x):
        '''
        Args:
            audio_x: tensor of shape (batch_size, audio_in)
            video_x: tensor of shape (batch_size, video_in)
            text_x: tensor of shape (batch_size, sequence_len, text_in)
        Returns:
            tensor of shape (batch_size, 1) with values in (-3, 3)
        '''
        audio_h = self.audio_subnet(audio_x)
        video_h = self.video_subnet(video_x)
        text_h = self.text_subnet(text_x)
        batch_size = audio_h.size(0)

        # next we perform "tensor fusion", which is essentially appending 1s to the tensors and take Kronecker product.
        # new_ones() matches audio_h's dtype and device, so half/double precision and
        # non-CUDA accelerators work. Hard-coding torch.(cuda.)FloatTensor forced float32.
        ones = audio_h.new_ones(batch_size, 1)
        _audio_h = torch.cat((ones, audio_h), dim=1)
        _video_h = torch.cat((ones, video_h), dim=1)
        _text_h = torch.cat((ones, text_h), dim=1)

        # _audio_h has shape (batch_size, audio_hidden + 1), _video_h has shape (batch_size, video_hidden + 1).
        # Batched outer product: (batch_size, audio_hidden + 1, 1) x (batch_size, 1, video_hidden + 1)
        # -> fusion_tensor of shape (batch_size, audio_hidden + 1, video_hidden + 1)
        fusion_tensor = torch.bmm(_audio_h.unsqueeze(2), _video_h.unsqueeze(1))

        # Flatten the audio-video plane into a column, then take the outer product with _text_h:
        # (batch_size, (audio_hidden + 1) * (video_hidden + 1), 1) x (batch_size, 1, text_out + 1).
        # The result is flattened to (batch_size, (audio_hidden + 1) * (video_hidden + 1) * (text_out + 1)),
        # which is the input size of post_fusion_layer_1.
        fusion_tensor = fusion_tensor.view(-1, (self.audio_hidden + 1) * (self.video_hidden + 1), 1)
        fusion_tensor = torch.bmm(fusion_tensor, _text_h.unsqueeze(1)).view(batch_size, -1)

        post_fusion_dropped = self.post_fusion_dropout(fusion_tensor)
        post_fusion_y_1 = F.relu(self.post_fusion_layer_1(post_fusion_dropped))
        post_fusion_y_2 = F.relu(self.post_fusion_layer_2(post_fusion_y_1))
        # sigmoid squashes to (0, 1); scaling by 6 and shifting by -3 maps it to (-3, 3)
        post_fusion_y_3 = torch.sigmoid(self.post_fusion_layer_3(post_fusion_y_2))
        output = post_fusion_y_3 * self.output_range + self.output_shift
        return output