/
transformer.py
410 lines (371 loc) · 16.2 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
from math import pi, sqrt
import torch as th
from torch import nn
# Negative infinity constant
NEG_INF = -100000
def GlorotLinear(input_dim, output_dim):
"""Returns a Glorot initialized linear layer for optimal gradient flow"""
linear = nn.Linear(input_dim, output_dim)
nn.init.xavier_uniform_(linear.weight)
nn.init.constant_(linear.bias, 0)
return linear
class MultiHeadAttention(nn.Module):
"""Multi head attention"""
def __init__(self, embed_dim, n_heads):
super(MultiHeadAttention, self).__init__()
# Hyper-parameters
self.embed_dim = embed_dim
self.n_heads = n_heads
self.head_dim = embed_dim // n_heads
if embed_dim % n_heads != 0:
raise ValueError("embed_dim must be a multiple of n_heads")
# Input projection layers
self.query = GlorotLinear(self.embed_dim, self.embed_dim)
self.key = GlorotLinear(self.embed_dim, self.embed_dim)
self.value = GlorotLinear(self.embed_dim, self.embed_dim)
# Output projection layers
self.output = GlorotLinear(self.embed_dim, self.embed_dim)
def forward(
self,
queries,
keys,
values,
in_mask=None,
causal_masking=False,
return_weights=False,
):
"""
:param queries: Tensor of shape m x b x embed_dim where m is the length
dimension and b the batch dimension
:param keys: Tensor of shape n x b x embed_dim where n is the length
dimension and b the batch dimension
:param values: Tensor of shape n x b x embed_dim where n is the length
dimension and b the batch dimension
:param in_mask: n x b mask with 1 at positions that shouldn't be
attended to (typically padding tokens)
:param causal_masking: For each query position i, set the attention to
all key positions j >i to 0, thus preventing the model from
attending "to the future" (typically in unidirectional
language models)
:param return_weights: Return attention weights
"""
m, bsz, _ = queries.size()
n, _, _ = keys.size()
# Project keys, queries and values (all of shape m/n x b x embed_dim)
# Reshape the last dim as n_heads x head_dims
q = self.query(queries).view(m, bsz, self.n_heads, self.head_dim)
k = self.key(keys).view(n, bsz, self.n_heads, self.head_dim)
v = self.value(values).view(n, bsz, self.n_heads, self.head_dim)
# Compute attention potentials
potentials = th.einsum("mbhd,nbhd->mnbh", [q, k])
# Rescale by inverse sqrt of the dimension for well behaved softmax
potentials /= sqrt(self.embed_dim)
# Mask certain input positions
if in_mask is not None:
in_mask = in_mask.view(1, n, bsz, 1)
potentials = potentials.masked_fill(in_mask, NEG_INF)
# Causal masking: make it impossible to "attend to the future"
if causal_masking:
# We want causal_mask[i, j] = 1 if j > i
causal_mask = th.triu(th.ones(m, n), diagonal=1).view(m, n, 1, 1)
causal_mask = causal_mask.eq(1).to(potentials.device)
potentials = potentials.masked_fill(causal_mask, NEG_INF)
# Softmax over the input length n, differently for each head
weights = nn.functional.softmax(potentials, dim=1)
# Compute the pooled values
pooled_v = th.einsum("mnbh,nbhd->mbhd", [weights, v]).contiguous()
# Output projection
output = self.output(pooled_v.view(m, bsz, -1))
if return_weights:
return output, weights
else:
return output
class FeedForwardTransducer(nn.Module):
"""Applies a 2-layer MLP to each position in a sequence"""
def __init__(self, embed_dim, hidden_dim, dropout=0.0):
super(FeedForwardTransducer, self).__init__()
# Hyper parameters
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.dropout = dropout
# Layers
self.layers = nn.Sequential(
GlorotLinear(self.embed_dim, self.hidden_dim), # Input projection
nn.ReLU(), # Activation
nn.Dropout(p=self.dropout), # Dropout
GlorotLinear(self.hidden_dim, self.embed_dim), # Output projection
)
def forward(self, x):
"""
:param x: Tensor of shape n x b x embed_dim where n is the length
dimension and b the batch dimension
"""
return self.layers(x)
class EncoderLayer(nn.Module):
"""Transformer encoder layer"""
def __init__(self, embed_dim, n_heads, hidden_dim, dropout=0.0):
super(EncoderLayer, self).__init__()
# Hyper parameters
self.embed_dim = embed_dim
self.n_heads = n_heads
self.hidden_dim = hidden_dim
self.dropout = dropout
# Sub-layers
# Self attention
self.layer_norm_self_att = nn.LayerNorm(embed_dim)
self.self_att = MultiHeadAttention(embed_dim, n_heads)
self.drop_self_att = nn.Dropout(p=dropout)
# Feed forward
self.layer_norm_ff = nn.LayerNorm(embed_dim)
self.ff = FeedForwardTransducer(embed_dim, hidden_dim, dropout)
self.drop_ff = nn.Dropout(p=dropout)
def forward(self, x, src_mask=None):
"""
:param x: Tensor of shape n x b x embed_dim where n is the length
dimension and b the batch dimension
:param src_mask: Mask of shape n x b indicating padding tokens in
the source sentences (for masking in self-attention)
"""
# TODO 1: Implement the forward pass of a transformer encoder layer
# Remember, there are 2 modules: self-attention and position-wise
# feed forward
# Don't forget layer normalization and residual connections!
raise NotImplementedError("TODO 1")
class DecoderLayer(nn.Module):
"""Transformer decoder layer"""
def __init__(self, embed_dim, n_heads, hidden_dim, dropout=0.0):
super(DecoderLayer, self).__init__()
# Hyper parameters
self.embed_dim = embed_dim
self.n_heads = n_heads
self.hidden_dim = hidden_dim
self.dropout = dropout
# Sub-layers
# Self attention
self.layer_norm_self_att = nn.LayerNorm(embed_dim)
self.self_att = MultiHeadAttention(embed_dim, n_heads)
self.drop_self_att = nn.Dropout(p=dropout)
# Encoder attention
self.layer_norm_enc_att = nn.LayerNorm(embed_dim)
self.enc_att = MultiHeadAttention(embed_dim, n_heads)
self.drop_enc_att = nn.Dropout(p=dropout)
# Feed forward
self.layer_norm_ff = nn.LayerNorm(embed_dim)
self.ff = FeedForwardTransducer(embed_dim, hidden_dim, dropout)
self.drop_ff = nn.Dropout(p=dropout)
def forward(self, x, encodings, src_mask=None):
"""
:param x: Input to this layer. Tensor of shape n x b x embed_dim where
n is the length dimension and b the batch dimension
:param encodings: Output from the encoder. Tensor of shape
n x b x embed_dim where n is the length dimension and b the batch
dimension
:param src_mask: Mask of shape n x b indicating padding tokens in
the source sentences (for masking in encoder-attention)
"""
# TODO 1: Implement the forward pass of a transformer decoder layer
# Remember, there are 3 modules: self-attention, encoder attention
# and position-wise feed forward
# Don't forget layer normalization and residual connections!
raise NotImplementedError("TODO 1")
def decode_step(
self,
x,
encodings,
state,
src_mask=None,
):
"""
This performs a forward pass on a single vector.
This is used during decoding.
:param x: Tensor of shape 1 x b x embed_dim where b is the batch
dimension. This is the input at the current position only
:param encodings: Output from the encoder. Tensor of shape
n x b x embed_dim where n is the length dimension and b the batch
dimension
:param src_mask: Mask of shape n x b indicating padding tokens in
the source sentences (for masking in self-attention)
:param state: This is either None or a n x b x embed_dim tensor
containing the inputs to the self attention layers up until
this position. This method returns an updated state
"""
# TODO 2: implement a decode step of the transformer decoder layer
# this is more or less the same as the forward pass except for 2 facts:
# 1. The input is now a single vector (or a batch of vector)
# 2. You need to handle the state of the decoder. At decoding step t
# (for batch size bsz), the state for this layer will have shape
# t x bsz x embed_dim. It represents the input to the self attention
# layer for all previous positions. This is the only onformation we
# need to compute the layer's aoutput at step t. You need to both
# use the state during the forward pass and update it to account for
# the current step. Finally, for the 1st step, the state will be None
# (you should handle this case)
raise NotImplementedError()
def sin_embeddings(max_pos, dim):
"""Returns sinusoidal embedings(for position embeddings)"""
# Scale for each dimension
dim_scale = 2 * (th.arange(dim) / 2).long().float() / dim
dim_scale = th.pow(th.full((dim,), 10000.0), dim_scale).view(1, -1)
# Phase to change sine to cosine every other dim
phase = th.zeros((1, dim))
phase[0, 1::2] = pi / 2
# Position value
pos = th.arange(max_pos).float().view(-1, 1)
# Embeddings
embeds = th.sin(pos / dim_scale + phase)
return embeds
class Transformer(nn.Module):
"""The full transformer model"""
def __init__(
self,
n_layers,
embed_dim,
hidden_dim,
n_heads,
vocab,
dropout=0.0
):
"""
:param n_layers: Number of layers (both encoder and decoder)
:param embed_dim: Embedding dimension (dimension throughout the model)
:param hidden_dim: Dimension of the hidden layer in position-wise
feed-forward layers
:param n_heads: Number of attention heads
:param vocab: Vocabulary object (see data.py)
:param dropout: Dopout probability
"""
super(Transformer, self).__init__()
# Hyper-parameters
self.n_layers = n_layers
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.n_heads = n_heads
self.vocab = vocab
# Token embeddings (this will be shared for encoder/decoder)
self.embeds = nn.Embedding(len(vocab), embed_dim, 0)
nn.init.normal_(self.embeds.weight, std=1/sqrt(embed_dim))
self.embed_drop = nn.Dropout(p=dropout)
# Positional embeddings
self.pos_embeds = sin_embeddings(2048, embed_dim)
# Encoder Layers
self.encoder_layers = nn.ModuleList([
EncoderLayer(embed_dim, n_heads, hidden_dim, dropout=dropout)
for l in range(n_layers)
])
# Final encoder layer norm
self.layer_norm_enc = nn.LayerNorm(embed_dim)
# Output proj (this is important because the embeddings are tied)
# and the output has been "layer-normalized".
# this layer can adjust the scale before the logits.
self.out_proj = GlorotLinear(embed_dim, embed_dim)
# Decoder Layers
self.decoder_layers = nn.ModuleList([
DecoderLayer(embed_dim, n_heads, hidden_dim, dropout=dropout)
for l in range(n_layers)
])
# Final decoder layer norm
self.layer_norm_dec = nn.LayerNorm(embed_dim)
# Output projection for the logits
self.logits = GlorotLinear(embed_dim, len(vocab))
# Share embedding and softmax weights
self.logits.weight = self.embeds.weight
def encode(self, src_tokens, src_mask=None):
"""
This encodes a batch of tokens (for feeding into the decoder)
:param src_tokens: Tensor of integers of shape n x b representing
the source tokens
:param src_mask: Tensor of shape n x b identifying the padding
tokens for masking
"""
# Embed and rescale
x = self.embeds(src_tokens) * sqrt(self.embed_dim)
# Apply dropout
x = self.embed_drop(x)
# Add position embedding
pos_offset = self.pos_embeds[:x.size(0)].view(-1, 1, self.embed_dim)
x += pos_offset.to(x.device).detach()
# Run through the encoder
for layer in self.encoder_layers:
x = layer(x, src_mask=src_mask)
# Layer normalize
# (to prevent all the residual connections from blowing up)
return self.layer_norm_enc(x)
def forward(self, src_tokens, tgt_tokens, src_mask=None):
"""
Returns a tensor log_p of shape m x b x |V| where log_p[i, k, w]
corresponds to the log probability of word w being at position i
in the bth target sentence (conditioned on the bth source sentence
and all the tokens at positions <i).
:param src_tokens: Tensor of integers of shape n x b representing
the source tokens
:param src_tokens: Tensor of integers of shape m x b representing
the target tokens
:param src_mask: Tensor of shape n x b identifying the padding
tokens for masking
"""
# Encode source tokens
encodings = self.encode(src_tokens, src_mask)
# Embed target tokens
h = self.embeds(tgt_tokens) * sqrt(self.embed_dim)
h = self.embed_drop(h)
# Add position embeddings
pos_offset = self.pos_embeds[:h.size(0)].view(-1, 1, self.embed_dim)
h += pos_offset.to(h.device).detach()
# Pass through all decoder layers
for layer in self.decoder_layers:
h = layer(h, encodings, src_mask=src_mask)
# Final layer norm so things don't blow up
h = self.layer_norm_dec(h)
# Output proj (into the embedding dimension).
# This is necessary to make the model expressive enough since the
# softmax weights are shared with the embeddings
h = self.out_proj(h)
# obtain logits for every word
logits = self.logits(h)
# Return log probs
return nn.functional.log_softmax(logits, dim=-1)
def decode_step(
self,
tgt_token,
encodings,
states,
src_mask=None,
):
"""
This performs a forward pass on a single vector.
This is used during decoding.
:param x: Tensor of shape 1 x b x embed_dim where b is the batch
dimension. This is the input at the current position only
:param src_mask: Mask of shape n x b indicating padding tokens in
the source sentences (for masking in self-attention)
:param states: This is a list of either None or a n x b x embed_dim
tensors containing the inputs to each self attention layers up
until this position. This method returns an updated state.
"""
new_states = []
h = self.embeds(tgt_token) * sqrt(self.embed_dim)
h = self.embed_drop(h)
# Add position embedding
pos = 0 if states[0] is None else states[0].size(0)
pos_offset = self.pos_embeds[pos].view(1, 1, -1)
h += pos_offset.to(h.device).detach()
# Pass through all layers
for layer, state in zip(self.decoder_layers, states):
h, new_state = layer.decode_step(
h,
encodings,
state,
src_mask=src_mask,
)
new_states.append(new_state)
# Final layer norm so things don't blow up
h = self.layer_norm_dec(h)
# Output proj
h = self.out_proj(h)
logits = self.logits(h)
# Log prob at this position
log_p = nn.functional.log_softmax(logits, dim=-1)
return log_p, new_states
def initial_state(self):
"""Returns the initial state for decoding (a list of None)"""
return [None for _ in range(self.n_layers)]