-
Notifications
You must be signed in to change notification settings - Fork 101
/
transformer.py
307 lines (239 loc) · 11.1 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import utils
# pylint: disable=arguments-differ
def initialize_weight(x):
    """Xavier-uniform initialize ``x.weight`` and zero ``x.bias`` when present.

    Args:
        x: a layer exposing ``weight`` and ``bias`` attributes (such as
            ``nn.Linear``); ``bias`` may be ``None`` (e.g. ``bias=False``).
    """
    nn.init.xavier_uniform_(x.weight)
    bias = x.bias
    if bias is None:
        return
    nn.init.constant_(bias, 0)
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward sublayer: Linear -> ReLU -> Dropout -> Linear.

    Expands the last dimension from ``hidden_size`` to ``filter_size`` and
    projects it back to ``hidden_size``.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super().__init__()
        # Submodule names become state_dict keys, and construction order fixes
        # the RNG draw order during init -- both are kept as before.
        self.layer1 = nn.Linear(hidden_size, filter_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer2 = nn.Linear(filter_size, hidden_size)
        for projection in (self.layer1, self.layer2):
            initialize_weight(projection)

    def forward(self, x):
        """Apply the two-layer MLP to the last dimension of ``x``."""
        hidden = self.dropout(self.relu(self.layer1(x)))
        return self.layer2(hidden)
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    Each of ``head_size`` heads works on ``hidden_size // head_size`` channels;
    the concatenated heads are projected back to ``hidden_size``.
    """

    def __init__(self, hidden_size, dropout_rate, head_size=8):
        super().__init__()
        self.head_size = head_size
        self.att_size = att_size = hidden_size // head_size
        self.scale = att_size ** -0.5

        projected = head_size * att_size
        # Creation/initialization order is preserved so seeded runs produce
        # the same weights as before.
        self.linear_q = nn.Linear(hidden_size, projected, bias=False)
        self.linear_k = nn.Linear(hidden_size, projected, bias=False)
        self.linear_v = nn.Linear(hidden_size, projected, bias=False)
        for projection in (self.linear_q, self.linear_k, self.linear_v):
            initialize_weight(projection)

        self.att_dropout = nn.Dropout(dropout_rate)

        self.output_layer = nn.Linear(projected, hidden_size, bias=False)
        initialize_weight(self.output_layer)

    def forward(self, q, k, v, mask, cache=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: inputs whose dim 0 is batch and last dim is hidden_size;
                dim 1 is treated as the sequence length.
            mask: boolean mask; after ``unsqueeze(1)`` it must broadcast to
                ``[batch, heads, q_len, k_len]``. Positions where it is True
                get -1e9 added to their logits, so they receive (near-)zero
                attention weight.
            cache: optional dict. If it already holds ``'encdec_k'``, the
                ``k``/``v`` arguments are ignored and the cached projections
                are reused; otherwise the fresh projections are stored there.

        Returns:
            Tensor with the same size as ``q``.
        """
        orig_q_size = q.size()
        batch_size = q.size(0)
        heads, depth = self.head_size, self.att_size

        def split_heads(t):
            # [b, len, heads * depth] -> [b, len, heads, depth]
            return t.view(batch_size, -1, heads, depth)

        # head_i = Attention(Q(W^Q)_i, K(W^K)_i, V(W^V)_i)
        q = split_heads(self.linear_q(q))
        if cache is not None and 'encdec_k' in cache:
            k, v = cache['encdec_k'], cache['encdec_v']
        else:
            k = split_heads(self.linear_k(k))
            v = split_heads(self.linear_v(v))
            if cache is not None:
                cache['encdec_k'], cache['encdec_v'] = k, v

        # Scaled dot-product attention: softmax((Q K^T) / sqrt(d_k)) V
        q = q.transpose(1, 2) * self.scale      # [b, h, q_len, d_k]
        v = v.transpose(1, 2)                    # [b, h, v_len, d_v]
        k_t = k.transpose(1, 2).transpose(2, 3)  # [b, h, d_k, k_len]

        scores = torch.matmul(q, k_t)            # [b, h, q_len, k_len]
        mask_bias = torch.zeros_like(scores).masked_fill_(mask.unsqueeze(1),
                                                          -1e9)
        scores = scores + mask_bias
        weights = self.att_dropout(torch.softmax(scores, dim=-1))

        context = weights.matmul(v)              # [b, h, q_len, d_v]
        context = context.transpose(1, 2).contiguous()  # [b, q_len, h, d_v]
        context = context.view(batch_size, -1, heads * depth)

        out = self.output_layer(context)
        assert out.size() == orig_q_size
        return out
class EncoderLayer(nn.Module):
    """Pre-norm encoder block: self-attention, then feed-forward.

    Each sublayer is applied as ``x + Dropout(Sublayer(LayerNorm(x)))``.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super().__init__()
        # Order kept: it fixes the state_dict key order and RNG draw order.
        self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, mask):  # pylint: disable=arguments-differ
        """Run both residual sublayers; ``mask`` goes to self-attention."""
        normed = self.self_attention_norm(x)
        attended = self.self_attention(normed, normed, normed, mask)
        x = x + self.self_attention_dropout(attended)

        normed = self.ffn_norm(x)
        return x + self.ffn_dropout(self.ffn(normed))
class DecoderLayer(nn.Module):
    """Pre-norm decoder block.

    Up to three residual sublayers, each ``x + Dropout(Sublayer(LayerNorm(x)))``:
    self-attention, encoder-decoder attention (skipped when ``enc_output`` is
    None), and the feed-forward network.
    """

    def __init__(self, hidden_size, filter_size, dropout_rate):
        super().__init__()
        # Order kept: it fixes the state_dict key order and RNG draw order.
        self.self_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.self_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        self.enc_dec_attention_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.enc_dec_attention = MultiHeadAttention(hidden_size, dropout_rate)
        self.enc_dec_attention_dropout = nn.Dropout(dropout_rate)

        self.ffn_norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.ffn = FeedForwardNetwork(hidden_size, filter_size, dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_output, self_mask, i_mask, cache):
        """Run the decoder sublayers.

        Args:
            x: decoder stream.
            enc_output: encoder output, or None for a decoder-only model.
            self_mask: mask for the self-attention sublayer.
            i_mask: mask for the encoder-decoder attention sublayer.
            cache: dict or None; passed only to the encoder-decoder attention,
                which stores/reuses its projected keys and values in it.
        """
        normed = self.self_attention_norm(x)
        attended = self.self_attention(normed, normed, normed, self_mask)
        x = x + self.self_attention_dropout(attended)

        if enc_output is not None:
            normed = self.enc_dec_attention_norm(x)
            attended = self.enc_dec_attention(normed, enc_output, enc_output,
                                              i_mask, cache)
            x = x + self.enc_dec_attention_dropout(attended)

        normed = self.ffn_norm(x)
        return x + self.ffn_dropout(self.ffn(normed))
class Encoder(nn.Module):
    """Stack of ``n_layers`` EncoderLayers followed by a final LayerNorm."""

    def __init__(self, hidden_size, filter_size, dropout_rate, n_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            EncoderLayer(hidden_size, filter_size, dropout_rate)
            for _ in range(n_layers))
        self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, inputs, mask):
        """Feed ``inputs`` through every layer with the same ``mask``."""
        hidden = inputs
        for layer in self.layers:
            hidden = layer(hidden, mask)
        return self.last_norm(hidden)
class Decoder(nn.Module):
    """Stack of ``n_layers`` DecoderLayers followed by a final LayerNorm."""

    def __init__(self, hidden_size, filter_size, dropout_rate, n_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderLayer(hidden_size, filter_size, dropout_rate)
            for _ in range(n_layers))
        self.last_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def forward(self, targets, enc_output, i_mask, t_self_mask, cache):
        """Run all decoder layers.

        ``cache`` (when not None) is a dict keyed by layer index; each layer
        gets its own sub-dict, created on first use.
        """
        hidden = targets
        for idx, layer in enumerate(self.layers):
            layer_cache = None if cache is None else cache.setdefault(idx, {})
            hidden = layer(hidden, enc_output, t_self_mask, i_mask,
                           layer_cache)
        return self.last_norm(hidden)
class Transformer(nn.Module):
    """Encoder-decoder Transformer with sinusoidal position encodings.

    The output projection reuses the target embedding matrix, so ``decode``
    returns logits whose last dimension is ``t_vocab_size``.

    Args:
        i_vocab_size: source vocabulary size. Ignored when
            ``share_target_embedding`` is True (the target embedding is reused).
        t_vocab_size: target vocabulary size.
        n_layers: number of layers in both encoder and decoder.
        hidden_size: model width (odd values are supported).
        filter_size: inner width of the feed-forward sublayers.
        dropout_rate: dropout probability used throughout.
        share_target_embedding: reuse the target embedding for the inputs.
        has_inputs: build an encoder; when False the model is decoder-only.
        src_pad_idx: padding id passed to ``utils.create_pad_mask`` for inputs.
        trg_pad_idx: padding id passed to ``utils.create_pad_mask`` for targets.
        max_seq_len: accepted for interface compatibility; not used here.
    """

    def __init__(self, i_vocab_size, t_vocab_size,
                 n_layers=6,
                 hidden_size=512,
                 filter_size=2048,
                 dropout_rate=0.1,
                 share_target_embedding=True,
                 has_inputs=True,
                 src_pad_idx=None,
                 trg_pad_idx=None,
                 max_seq_len=256):
        super(Transformer, self).__init__()

        self.hidden_size = hidden_size
        self.emb_scale = hidden_size ** 0.5
        self.has_inputs = has_inputs
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

        self.t_vocab_embedding = nn.Embedding(t_vocab_size, hidden_size)
        nn.init.normal_(self.t_vocab_embedding.weight, mean=0,
                        std=hidden_size**-0.5)
        self.t_emb_dropout = nn.Dropout(dropout_rate)
        self.decoder = Decoder(hidden_size, filter_size,
                               dropout_rate, n_layers)

        if has_inputs:
            if not share_target_embedding:
                self.i_vocab_embedding = nn.Embedding(i_vocab_size,
                                                      hidden_size)
                nn.init.normal_(self.i_vocab_embedding.weight, mean=0,
                                std=hidden_size**-0.5)
            else:
                self.i_vocab_embedding = self.t_vocab_embedding

            self.i_emb_dropout = nn.Dropout(dropout_rate)

            self.encoder = Encoder(hidden_size, filter_size,
                                   dropout_rate, n_layers)

        # Inverse timescales for the sinusoidal position encoding: one per
        # sin/cos channel pair (hidden_size // 2 of them), decreasing
        # geometrically from 1.0 towards 1 / max_timescale. max(..., 1) avoids
        # dividing by zero when there is a single timescale.
        num_timescales = self.hidden_size // 2
        max_timescale = 10000.0
        min_timescale = 1.0
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            max(num_timescales - 1, 1))
        inv_timescales = min_timescale * torch.exp(
            torch.arange(num_timescales, dtype=torch.float32) *
            -log_timescale_increment)
        # Buffer, not parameter: follows .to(device) and is saved, not trained.
        self.register_buffer('inv_timescales', inv_timescales)

    def forward(self, inputs, targets):
        """Teacher-forced pass: encode ``inputs`` (if any), decode ``targets``.

        Returns:
            Logits from ``decode`` (last dimension ``t_vocab_size``).
        """
        enc_output, i_mask = None, None
        if self.has_inputs:
            i_mask = utils.create_pad_mask(inputs, self.src_pad_idx)
            enc_output = self.encode(inputs, i_mask)

        t_mask = utils.create_pad_mask(targets, self.trg_pad_idx)

        target_size = targets.size()[1]
        t_self_mask = utils.create_trg_self_mask(target_size,
                                                 device=targets.device)
        return self.decode(targets, enc_output, i_mask, t_self_mask, t_mask)

    def encode(self, inputs, i_mask):
        """Embed ``inputs``, add position encodings and run the encoder.

        NOTE(review): ``i_mask`` presumably marks padding positions with True
        (it comes from ``utils.create_pad_mask``) -- confirm in utils.
        """
        input_embedded = self.i_vocab_embedding(inputs)
        # Zero the embeddings at masked (padding) positions.
        input_embedded.masked_fill_(i_mask.squeeze(1).unsqueeze(-1), 0)
        input_embedded *= self.emb_scale
        input_embedded += self.get_position_encoding(inputs)
        input_embedded = self.i_emb_dropout(input_embedded)

        return self.encoder(input_embedded, i_mask)

    def decode(self, targets, enc_output, i_mask, t_self_mask, t_mask,
               cache=None):
        """Decode ``targets`` and project onto the target vocabulary.

        Args:
            targets: target token ids; dim 1 is the sequence length.
            enc_output: encoder output, or None for a decoder-only model.
            i_mask: encoder padding mask (may be None without inputs).
            t_self_mask: self-attention mask for the decoder.
            t_mask: target padding mask used to zero padded embeddings.
            cache: optional per-layer dict forwarded to the decoder.

        Returns:
            ``decoder_output @ t_vocab_embedding.weight.T`` (tied output
            projection); last dimension is ``t_vocab_size``.
        """
        target_embedded = self.t_vocab_embedding(targets)
        target_embedded.masked_fill_(t_mask.squeeze(1).unsqueeze(-1), 0)

        # Shift right by one step: drop the last position and prepend a zero
        # vector along the length dimension, so step t sees targets[:, :t].
        target_embedded = target_embedded[:, :-1]
        target_embedded = F.pad(target_embedded, (0, 0, 1, 0))

        target_embedded *= self.emb_scale
        target_embedded += self.get_position_encoding(targets)
        target_embedded = self.t_emb_dropout(target_embedded)

        decoder_output = self.decoder(target_embedded, enc_output, i_mask,
                                      t_self_mask, cache)

        # Output projection tied to the target embedding matrix.
        output = torch.matmul(decoder_output,
                              self.t_vocab_embedding.weight.transpose(0, 1))
        return output

    def get_position_encoding(self, x):
        """Return sinusoidal position encodings for ``x``'s sequence length.

        Only ``x.size(1)`` (the length) and ``x.device`` are used.

        Returns:
            Tensor of shape ``(1, x.size(1), hidden_size)``: the first
            ``hidden_size // 2`` channels are sines, the next ``hidden_size // 2``
            are cosines, and one zero channel is appended when ``hidden_size``
            is odd.
        """
        max_length = x.size()[1]
        position = torch.arange(max_length, dtype=torch.float32,
                                device=x.device)
        scaled_time = position.unsqueeze(1) * self.inv_timescales.unsqueeze(0)
        signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)],
                           dim=1)
        # Pad the *channel* (last) dimension up to hidden_size. F.pad's tuple
        # starts with the last dim, so (0, n) appends n columns. The previous
        # (0, 0, 0, n) padded the length dimension instead, which broke the
        # view below whenever hidden_size was odd.
        signal = F.pad(signal, (0, self.hidden_size % 2))
        signal = signal.view(1, max_length, self.hidden_size)
        return signal