-
Notifications
You must be signed in to change notification settings - Fork 272
/
transformer.py
247 lines (210 loc) · 9.12 KB
/
transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Optional
import torch
from torch import nn, Tensor
from torchtune.modules import CausalSelfAttention, KVCache
class TransformerDecoderLayer(nn.Module):
    """Pre-norm transformer decoder block modeled on the Llama2 layer.

    The block has two halves, self-attention and feed-forward. Each half
    normalizes its input, runs the sub-layer on the normalized tensor, and adds
    the result back onto the un-normalized input (residual connection).

    Args:
        attn (CausalSelfAttention): Self-attention module.
        mlp (nn.Module): Feed-forward module.
        sa_norm (nn.Module): Normalization applied to the input of self-attention.
        mlp_norm (nn.Module): Normalization applied to the input of the
            feed-forward module.
    """

    def __init__(
        self,
        attn: CausalSelfAttention,
        mlp: nn.Module,
        sa_norm: nn.Module,
        mlp_norm: nn.Module,
    ) -> None:
        super().__init__()
        # Sub-modules are registered in execution order: norm, then the
        # sub-layer it feeds, for each half of the block.
        self.sa_norm = sa_norm
        self.attn = attn
        self.mlp_norm = mlp_norm
        self.mlp = mlp

    def forward(
        self,
        x: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the attention and feed-forward halves of the block on ``x``.

        Args:
            x (Tensor): input tensor with shape
                [batch_size x seq_length x embed_dim]
            mask (Optional[Tensor]): Optional boolean attention mask with shape
                [batch_size x seq_length x seq_length], applied after the
                query-key multiplication and before the softmax. ``True`` at
                row i, column j lets token i attend to token j; ``False``
                prevents it. When no mask is given, a causal mask is used by
                default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor with the position id
                of each token. During training it gives the position of each
                token relative to its sample when packed, shape [b x s].
                During inference it gives the position of the current token.
                When None, the index of a token is taken as its position id.
                Default is None.

        Returns:
            Tensor: output tensor with same shape as input
                [batch_size x seq_length x embed_dim]

        TODO:
            - Make position of norm configurable
        """
        # Attention half. Every intermediate keeps the input shape [b, s, d].
        normed_x = self.sa_norm(x)
        attn_out = self.attn(normed_x, mask=mask, input_pos=input_pos)
        # Skip connection around self-attention.
        h = attn_out + x

        # Feed-forward half: normalize, transform, then add the skip connection.
        return h + self.mlp(self.mlp_norm(h))
def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
"""
Return a list of ``n`` identical layers.
Args:
module (nn.Module): module to be cloned
n (int): number of clones
Returns:
nn.ModuleList: list of ``n`` identical layers
"""
# FIXME: copy.deepcopy() is not defined on nn.module
return nn.ModuleList([copy.deepcopy(module) for i in range(n)])
class TransformerDecoder(nn.Module):
    """
    Transformer Decoder derived from the Llama2 architecture.

    Args:
        tok_embeddings (nn.Embedding): PyTorch embedding layer, to be used to move
            tokens to an embedding space.
        layer (TransformerDecoderLayer): Transformer Decoder layer.
        num_layers (int): Number of Transformer Decoder layers.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value. This is used to setup the
            :func:`~torchtune.modules.KVCache`
        head_dim (int): embedding dimension for each head in self-attention. This is used
            to setup the :func:`~torchtune.modules.KVCache`
        norm (nn.Module): Callable that applies normalization to the output of the decoder,
            before the output projection.
        output (nn.Linear): Callable that applies a linear transformation to the output of
            the decoder.

    Note:
        Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
        in the module where they are used. This helps reduce the number of raise
        statements in code and improves readability.
    """

    def __init__(
        self,
        tok_embeddings: nn.Embedding,
        layer: TransformerDecoderLayer,
        num_layers: int,
        max_seq_len: int,
        num_heads: int,
        head_dim: int,
        norm: nn.Module,
        output: nn.Linear,
    ) -> None:
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.layers = _get_clones(layer, num_layers)
        self.norm = norm
        self.output = output
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        # ``causal_mask`` stays None until ``setup_caches()`` is called; a
        # non-None value is what switches ``forward()`` into inference mode.
        # It is registered as a non-persistent buffer so that, once populated,
        # it moves together with the module on ``.to()`` / ``.cuda()`` while
        # staying out of ``state_dict()``.
        self.register_buffer("causal_mask", None, persistent=False)

    def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None:
        """Setup key value caches for attention calculation.

        Installs a fresh ``KVCache`` on the attention module of every layer and
        builds the lower-triangular boolean mask used during inference.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
        """
        for layer in self.layers:
            layer.attn.kv_cache = KVCache(
                batch_size=batch_size,
                max_seq_len=self.max_seq_len,
                num_heads=self.num_heads,
                head_dim=self.head_dim,
                dtype=dtype,
            )

        # causal_mask is used during inference to ensure we're attending
        # to the right tokens. Assigning to the registered buffer name keeps
        # it a (non-persistent) buffer.
        self.causal_mask = torch.tril(
            torch.ones(self.max_seq_len, self.max_seq_len, dtype=torch.bool)
        )

    def reset_caches(self) -> None:
        """Reset the key value caches of every layer.

        A decoder without layers has nothing to reset, so the call is a no-op.

        Raises:
            RuntimeError: if the key value cache of any layer has not been set up.
        """
        # Check every layer before touching any cache, so a partially
        # configured model is rejected without resetting some caches first.
        if any(layer.attn.kv_cache is None for layer in self.layers):
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )

        for layer in self.layers:
            layer.attn.kv_cache.reset()

    def forward(
        self,
        tokens: Tensor,
        *,
        mask: Optional[Tensor] = None,
        input_pos: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Args:
            tokens (Tensor): input tensor with shape [b x s]
            mask (Optional[Tensor]): Optional boolean tensor which contains the attention mask
                with shape [b x s x s]. This is applied after the query-key multiplication and
                before the softmax. A value of True in row i and column j means token i attends
                to token j. A value of False means token i does not attend to token j. If no
                mask is specified, a causal mask is used by default. Default is None.
            input_pos (Optional[Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Note: At the very first step of inference, when the model is provided with a prompt,
        ``input_pos`` would contain the positions of all of the tokens in the prompt
        (eg: ``torch.arange(prompt_length)``). This is because we will need to compute the
        KV values for each position.

        Returns:
            Tensor: output tensor with shape [b x s x v]

        Raises:
            ValueError: if ``tokens`` is not a 2-D tensor of shape [b x s]
            ValueError: if causal_mask is set but input_pos is None
            ValueError: if causal_mask is set and an explicit ``mask`` is also passed

        Notation used for tensor shapes:
            - b: batch size
            - s: sequence length
            - v: vocab size
            - d: embed dim
            - m_s: max seq len
        """
        # input tensor of shape [b, s]; reject any other rank up front
        if tokens.dim() != 2:
            raise ValueError(
                f"Expected tokens of shape [b, s], but got shape {tuple(tokens.shape)}"
            )

        # shape: [b, s, d]
        h = self.tok_embeddings(tokens)

        # A populated causal_mask means setup_caches() was called, i.e. the
        # model is being used for cached inference.
        if self.causal_mask is not None:
            if input_pos is None:
                raise ValueError(
                    "Caches are setup, but the position of input token is missing"
                )
            if mask is not None:
                raise ValueError(
                    "An attention mask was set. Cannot use a non-causal mask for inference"
                )
            # Select the mask rows for the requested positions.
            # shape: [1, input_pos_len, m_s]
            # in most cases input_pos_len should be 1
            mask = self.causal_mask[None, input_pos]

        for layer in self.layers:
            # shape: [b, s, d]
            h = layer(h, mask=mask, input_pos=input_pos)

        # shape: [b, s, d]
        h = self.norm(h)

        # shape: [b, s, out_dim] - out_dim is usually the vocab size
        output = self.output(h).float()
        return output