# transformer.py
"""
The original transformer model with encoder
and masked decoder, and with post-LayerNorm.
"""
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
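# Sinusoidal positional encoding from "Attention Is All You Need":
#   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
# The div_term in PositionalEncoding.__init__ computes 10000^(-2i / d_model) in log space
# for numerical stability.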
class PositionalEncoding(nn.Module):
def __init__(self, max_len, d_model, dropout=0.1):
"""
        :param max_len: Maximum input sequence length.
:param d_model: Embedding dimension.
:param dropout: Dropout value (default=0.1)
"""
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, x):
"""
Inputs of forward function
:param x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [batch size, sequence length, embed dim]
            output: [batch size, sequence length, embed dim]
"""
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
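# Note: `pe` is registered as a buffer, so it is saved in the state_dict and moved by
# .to(device), but it is not a trainable parameter. The forward pass expects batch-first
# input of shape [batch size, sequence length, embed dim].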
class Embedding(nn.Module):
def __init__(self, vocab_size, embed_dim):
"""
        :param vocab_size: Size of the vocabulary, i.e. the number of
            unique tokens in the dataset.
:param embed_dim: The embedding layer dimension.
"""
super(Embedding, self).__init__()
self.embed = nn.Embedding(vocab_size, embed_dim)
def forward(self, x):
"""
:param x: Input vector.
Returns:
out: Embedding vector.
"""
out = self.embed(x)
return out
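# Design note for the attention below: rather than projecting the full embedding with
# (embed_dim x embed_dim) matrices, this implementation first reshapes the input into
# n_heads chunks of size head_dim and then applies (head_dim x head_dim) projections whose
# weights are shared across heads. With embed_dim=512 and n_heads=8 this still gives
# 8 heads of 64 dimensions each.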
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim=512, n_heads=8):
"""
:param embed_dim: Embedding dimension.
        :param n_heads: Number of attention heads.
"""
super(MultiHeadAttention, self).__init__()
self.embed_dim = embed_dim
self.n_heads = n_heads
        assert embed_dim % n_heads == 0, \
            f"Embedding dimension ({embed_dim}) should be divisible by number of heads ({n_heads})"
        self.head_dim = self.embed_dim // self.n_heads
# Query matrix (64, 64).
self.q = nn.Linear(self.head_dim, self.head_dim)
# Key matrix (64, 64).
self.k = nn.Linear(self.head_dim, self.head_dim)
# Value matrix (64, 64).
self.v = nn.Linear(self.head_dim, self.head_dim)
self.out = nn.Linear(self.n_heads*self.head_dim, self.embed_dim)
def forward(self, key, query, value, mask=None):
"""
:param key: key vector.
:param query: query vector.
:param value: value vector.
        :param mask: Optional attention mask (padding and/or look-ahead mask for the decoder).
"""
batch_size = key.size(0) # Batch size.
seq_len = key.size(1) # Max. sequence length.
inp_emb = key.size(2) # Embedding dim.
assert inp_emb == self.embed_dim, \
f"Input embedding {inp_emb} should match layer embedding {self.embed_dim}"
seq_len_query = query.size(1)
key = key.view(
batch_size, seq_len, self.n_heads, self.head_dim
) # [bs, seq_len, n_heads, head_dim] ~ [32, 1024, 8, 64]
query = query.view(
batch_size, seq_len_query, self.n_heads, self.head_dim
) # [bs, seq_len, n_heads, head_dim] ~ [32, 1024, 8, 64]
value = value.view(
batch_size, seq_len, self.n_heads, self.head_dim
) # [bs, seq_len, n_heads, head_dim] ~ [32, 1024, 8, 64]
k = self.k(key)
q = self.q(query)
v = self.v(value)
k = k.transpose(1, 2) # [batch_size, n_heads, seq_len, head_dim]
q = q.transpose(1, 2) # [batch_size, n_heads, seq_len, head_dim]
v = v.transpose(1, 2) # [batch_size, n_heads, seq_len, head_dim]
# Scaled-dot product attention.
# Transposed key for matrix multiplication.
k_transposed = k.transpose(-1, -2)
dot = torch.matmul(q, k_transposed)
if mask is not None:
dot = dot.masked_fill_(mask == 0, float('-1e20'))
# Scaling.
        dot = dot / math.sqrt(self.head_dim)  # E.g. sqrt(64) = 8 for the default head_dim.
scores = F.softmax(dot, dim=-1)
        # Dot product with value matrix.
scores = torch.matmul(scores, v)
concat = scores.transpose(1,2).contiguous().view(
batch_size, seq_len_query, self.head_dim*self.n_heads
)
out = self.out(concat)
return out
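# The forward pass above implements scaled dot-product attention:
#   Attention(Q, K, V) = softmax(Q @ K^T / sqrt(head_dim)) @ V
# Masked positions are filled with -1e20 before scaling, so they contribute ~0 after softmax.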
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, expansion_factor=4, n_heads=8, dropout=0.3):
        """
        :param embed_dim: Embedding dimension.
        :param expansion_factor: Factor determining the output dimension
            of the linear layer.
        :param n_heads: Number of attention heads.
        :param dropout: Dropout probability.
        """
        super(TransformerBlock, self).__init__()
self.attention = MultiHeadAttention(embed_dim, n_heads)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.ffn = nn.Sequential(
nn.Linear(embed_dim, expansion_factor*embed_dim),
nn.ReLU(),
nn.Linear(expansion_factor*embed_dim, embed_dim)
)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, key, query, value, mask=None):
"""
:param key: Key vector.
:param query: Query vector.
        :param value: Value vector.
        :param mask: Optional attention mask.
Returns:
out: Output of the transformer block.
"""
# Apply attention, then add residual connection, followed by normalization
attn_out = self.attention(key, query, value, mask)
x = key + self.dropout1(attn_out)
x = self.norm1(x)
# Apply feed-forward network, then add residual connection, followed by normalization
ffn_out = self.ffn(x)
x = x + self.dropout2(ffn_out)
x = self.norm2(x)
return x
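# TransformerBlock uses the post-LayerNorm arrangement of the original Transformer:
#   x = LayerNorm(x + Dropout(SubLayer(x)))
# applied once after attention and once after the position-wise feed-forward network.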
class TransformerEncoder(nn.Module):
def __init__(
self,
seq_len,
vocab_size,
embed_dim,
num_layers=6,
expansion_factor=4,
n_heads=8,
dropout=0.3
):
"""
:param seq_len: Input sequence length.
:param vocab_size: Number of unique tokens.
:param embed_dim: Embedding dimension.
:param num_layers: Number of encoder layers.
:param expansion_factor: Factor determining the output feature
dimension of the linear layers.
        :param n_heads: Number of attention heads.
        :param dropout: Dropout probability.
        """
super(TransformerEncoder, self).__init__()
self.embedding = Embedding(vocab_size, embed_dim)
self.positional_encoding = PositionalEncoding(seq_len, embed_dim)
self.layers = nn.ModuleList(
[TransformerBlock(embed_dim, expansion_factor, n_heads, dropout) \
for _ in range(num_layers)]
)
    def forward(self, x, mask=None):
        """
        :param x: Input token indices.
        :param mask: Optional source (padding) mask.
        Returns:
            out: Transformer encoder output.
        """
x = self.embedding(x)
out = self.positional_encoding(x)
for layer in self.layers:
out = layer(out, out, out, mask) # Query, Key, Value are the same.
return out
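# Illustrative shapes: for token indices of shape [batch, seq_len], the encoder returns a
# tensor of shape [batch, seq_len, embed_dim], since every TransformerBlock preserves the
# embedding dimension.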
class DecoderBlock(nn.Module):
def __init__(self, embed_dim, expansion_factor=4, n_heads=8, dropout=0.3):
"""
:param embed_dim: Embedding dimension.
:param expansion_factor: Factor determining the feature dimension
of linear layers.
        :param n_heads: Number of attention heads.
        :param dropout: Dropout probability.
"""
super(DecoderBlock, self).__init__()
self.attention = MultiHeadAttention(embed_dim, n_heads)
self.norm1 = nn.LayerNorm(embed_dim)
self.norm2 = nn.LayerNorm(embed_dim)
self.norm3 = nn.LayerNorm(embed_dim)
self.dropout = nn.Dropout(dropout)
self.transformer_block = TransformerBlock(
embed_dim, expansion_factor, n_heads, dropout
)
def forward(self, x, enc_out, src_mask=None, tgt_mask=None):
"""
:param x: Input target vector.
:param enc_out: Encoder output.
:param src_mask: Mask for encoder output.
:param tgt_mask: Mask for decoder self-attention.
Returns:
out: Output of the decoder block.
"""
# Self-attention on target sequence
attn_out = self.attention(x, x, x, mask=tgt_mask)
x = x + self.dropout(attn_out)
x = self.norm1(x)
# Cross-attention with encoder output
attn_out = self.attention(enc_out, x, enc_out, mask=src_mask)
x = x + self.dropout(attn_out)
x = self.norm2(x)
# Feed-forward network
ffn_out = self.transformer_block.ffn(x)
x = x + self.dropout(ffn_out)
x = self.norm3(x)
return x
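# DecoderBlock wiring: masked self-attention over the target, then cross-attention in which
# the queries come from the decoder and the keys/values come from the encoder output, then
# the position-wise feed-forward network (only the ffn sub-module of the inner
# TransformerBlock is used). Note that the same MultiHeadAttention instance serves both
# self- and cross-attention, so their projection weights are shared.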
class TransformerDecoder(nn.Module):
def __init__(
self,
tgt_vocab_size,
embed_dim,
seq_len,
num_layers=6,
expansion_factor=4,
n_heads=8,
dropout=0.3
):
"""
        :param tgt_vocab_size: Target vocabulary size.
        :param embed_dim: Embedding dimension.
        :param seq_len: Input sequence length.
        :param num_layers: Number of transformer layers.
        :param expansion_factor: Factor to determine the intermediate
            output feature dimension of linear layers.
        :param n_heads: Number of self attention heads.
        :param dropout: Dropout probability.
        """
super(TransformerDecoder, self).__init__()
self.embedding = Embedding(tgt_vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(seq_len, embed_dim)
self.layers = nn.ModuleList(
[
DecoderBlock(embed_dim, expansion_factor, n_heads, dropout) \
for _ in range(num_layers)
]
)
self.fc = nn.Linear(embed_dim, tgt_vocab_size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, enc_out, src_mask, tgt_mask):
"""
:param x: Input target vector.
:param enc_out: Encoder layer output.
        :param src_mask: Mask for the encoder output (source padding mask).
        :param tgt_mask: Decoder self-attention mask.
Returns:
out: Output vector.
"""
x = self.embedding(x)
        x = self.positional_encoding(x)
x = self.dropout(x)
for layer in self.layers:
x = layer(x, enc_out, src_mask, tgt_mask)
out = self.fc(x)
return out
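# The final nn.Linear maps embed_dim -> tgt_vocab_size, so TransformerDecoder returns raw,
# unnormalized logits of shape [batch, seq_len, tgt_vocab_size]; apply softmax/log_softmax
# (as decode() does below) to obtain probabilities.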
class Transformer(nn.Module):
def __init__(
self,
embed_dim,
src_vocab_size,
tgt_vocab_size,
seq_len,
num_layers=6,
expansion_factor=4,
n_heads=8,
dropout=0.3,
device='cpu'
):
"""
        :param embed_dim: Embedding dimension.
        :param src_vocab_size: Source vocabulary size.
        :param tgt_vocab_size: Target vocabulary size.
        :param seq_len: Input sequence length.
        :param num_layers: Number of transformer layers.
        :param expansion_factor: Factor to determine the intermediate
            output feature dimension of linear layers.
        :param n_heads: Number of self attention heads.
        :param dropout: Dropout probability.
        :param device: Device the attention masks are moved to in forward().
        """
super(Transformer, self).__init__()
self.tgt_vocab_size = tgt_vocab_size
self.encoder = TransformerEncoder(
seq_len,
src_vocab_size,
embed_dim,
num_layers,
expansion_factor,
n_heads,
dropout
)
self.decoder = TransformerDecoder(
tgt_vocab_size,
embed_dim,
seq_len,
num_layers,
expansion_factor,
n_heads,
dropout
)
        self.device = device
def make_tgt_mask(self, tgt, pad_token_id=1):
"""
:param tgt: Target sequence.
:param pad_token_id: Padding token ID, default 1.
Returns:
tgt_mask: Target mask.
"""
batch_size = tgt.shape[0]
device = tgt.device
# Some help from here:
# https://github.com/gordicaleksa/pytorch-original-transformer/blob/main/utils/data_utils.py
        # Same as src_mask, but we additionally prevent tokens from attending to future tokens.
        # Note: wherever the mask value is True we attend to that token; otherwise we mask (ignore) it.
        sequence_length = tgt.shape[1]  # tgt shape = (B, T), where T is the max target sequence length.
trg_padding_mask = (tgt != pad_token_id).view(batch_size, 1, 1, -1) # shape = (B, 1, 1, T)
trg_no_look_forward_mask = torch.triu(torch.ones((
1, 1, sequence_length, sequence_length), device=device
) == 1).transpose(2, 3)
# logic AND operation (both padding mask and no-look-forward must be true to attend to a certain target token)
tgt_mask = trg_padding_mask & trg_no_look_forward_mask # final shape = (B, 1, T, T)
return tgt_mask
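    # Worked example (no padding, T=3): the no-look-forward mask is lower triangular,
    #   [[1, 0, 0],
    #    [1, 1, 0],
    #    [1, 1, 1]]
    # so position i can attend only to positions <= i; AND-ing with the padding mask then
    # hides pad tokens as well.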
def make_src_mask(self, src, pad_token_id=1):
"""
:param src: Source sequence.
:param pad_token_id: Padding token ID, default 1.
Returns:
src_mask: Source mask.
"""
batch_size = src.shape[0]
# Some help from here:
# https://github.com/gordicaleksa/pytorch-original-transformer/blob/main/utils/data_utils.py
        # src_mask shape = (B, 1, 1, S); it is applied by masked_fill in MultiHeadAttention.forward.
        # src_mask only masks pad tokens, as we want to ignore their representations (no information in there).
src_mask = (src != pad_token_id).view(batch_size, 1, 1, -1)
return src_mask
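    # The (B, 1, 1, S) shape broadcasts over the head and query-position dimensions when the
    # mask is applied to the (B, n_heads, S_q, S) attention scores in MultiHeadAttention.forward.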
    def decode(self, src, tgt):
        """
        Greedy decoding: repeatedly feed the growing target sequence to the
        decoder and append the most probable next token.
        :param src: Encoder input.
        :param tgt: Decoder input (initial target tokens, e.g. the start token).
        Returns:
            out_labels: Final prediction sequence.
        """
        src_mask = self.make_src_mask(src)
        enc_out = self.encoder(src, src_mask)
        out_labels = []
        seq_len = src.shape[1]
        out = tgt
        for i in range(seq_len):
            # Rebuild the target mask for the current (growing) target sequence.
            tgt_mask = self.make_tgt_mask(out)
            dec_out = F.log_softmax(self.decoder(out, enc_out, src_mask, tgt_mask), dim=-1)
            # Keep only the prediction for the last position. Shape: [batch_size].
            next_token = dec_out[:, -1, :].argmax(dim=-1)
            # .item() assumes a batch size of 1.
            out_labels.append(next_token.item())
            # Append the prediction and feed the extended sequence back in.
            out = torch.cat([out, next_token.unsqueeze(1)], dim=1)
        return out_labels
def forward(self, src, tgt):
"""
:param src: Encoder input.
        :param tgt: Decoder input.
        Returns:
            out: Output logits over the target vocabulary for each position.
"""
src_mask = self.make_src_mask(src).to(self.device)
tgt_mask = self.make_tgt_mask(tgt).to(self.device)
enc_out = self.encoder(src, src_mask)
out = self.decoder(tgt, enc_out, src_mask, tgt_mask)
return out
if __name__ == "__main__":
# Parameters for testing
embed_dim = 512
src_vocab_size = 10000
tgt_vocab_size = 10000
seq_len = 512
num_layers = 6
expansion_factor = 4
n_heads = 8
dropout = 0.3
# Create a dummy input tensor for testing
src_input = torch.randint(0, src_vocab_size, (1, seq_len))
tgt_input = torch.randint(0, tgt_vocab_size, (1, seq_len))
# Initialize the model with provided parameters
transformer_model = Transformer(
embed_dim,
src_vocab_size,
tgt_vocab_size,
seq_len,
num_layers=num_layers,
expansion_factor=expansion_factor,
n_heads=n_heads,
dropout=dropout
)
# Forward pass through the encoder and decoder to check shapes
src_mask = transformer_model.make_src_mask(src_input)
tgt_mask = transformer_model.make_tgt_mask(tgt_input)
enc_output = transformer_model.encoder(src_input, src_mask)
print(f"Encoder output shape: {enc_output.shape}")
dec_output = transformer_model.decoder(tgt_input, enc_output, src_mask, tgt_mask)
print(f"Decoder output shape: {dec_output.shape}")
# Forward pass through the entire model
    output = transformer_model(src_input, tgt_input)
    print(f"Model output shape: {output.shape}")