-
Notifications
You must be signed in to change notification settings - Fork 282
/
_component_builders.py
424 lines (385 loc) · 14.7 KB
/
_component_builders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
from typing import List, Optional
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
from torch import nn
from torchtune.models.llama2._model_utils import scale_hidden_dim_for_mlp
from torchtune.modules import (
CausalSelfAttention,
FeedForward,
RMSNorm,
RotaryPositionalEmbeddings,
TransformerDecoder,
TransformerDecoderLayer,
)
from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
"""
Component builders for the Llama2 model and popular variants such as LoRA.
torchtune provides composable building blocks. Builder functions help
stitch these building blocks into higher-level components. This design has
two benefits:
- The building blocks themselves are very flexible. For example, ``CausalSelfAttention``
can take either ``nn.Linear`` or ``LoRALinear`` for ``q_proj``.
- Builder functions expose a set of configurable params which keep the constructors of
the building blocks simple.
"""
# ------------------ Vanilla Llama2 ------------------
def llama2(
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    intermediate_dim: Optional[int] = None,
    norm_eps: float = 1e-5,
) -> TransformerDecoder:
    """
    Build the ``TransformerDecoder`` for the Llama2 architecture. It consists of:

    - a token embedding table (``nn.Embedding``)
    - ``num_layers`` TransformerDecoderLayer blocks (one layer is built here and
      handed to ``TransformerDecoder`` together with ``num_layers``)
    - an RMSNorm applied to the output of the transformer
    - a final bias-free linear projection into token (vocabulary) space

    Args:
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
            A falsy value (e.g. 0 or None) falls back to `num_heads`.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified
            (or falsy), it is computed with ``scale_hidden_dim_for_mlp`` from
            ``torchtune.models.llama2._model_utils``. Default: None
        norm_eps (float): epsilon in RMS norms. Default: 1e-5

    Returns:
        TransformerDecoder: Instantiation of Llama2 model.
    """
    per_head = embed_dim // num_heads
    kv_heads = num_kv_heads or num_heads
    q_width = num_heads * per_head
    kv_width = kv_heads * per_head

    rotary = RotaryPositionalEmbeddings(dim=per_head, max_seq_len=max_seq_len)

    # Modules are created in a fixed order: q -> k -> v -> output, then the MLP,
    # the embedding and the output head. nn.Linear / nn.Embedding draw their
    # initial weights from the global RNG, so this order determines initialization.
    q_proj = nn.Linear(embed_dim, q_width, bias=False)
    k_proj = nn.Linear(embed_dim, kv_width, bias=False)
    v_proj = nn.Linear(embed_dim, kv_width, bias=False)
    attn_out = nn.Linear(embed_dim, embed_dim, bias=False)
    attention = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=kv_heads,
        head_dim=per_head,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=attn_out,
        pos_embeddings=rotary,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )

    mlp_hidden = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim)
    decoder_layer = TransformerDecoderLayer(
        attn=attention,
        mlp=llama2_mlp(dim=embed_dim, hidden_dim=mlp_hidden),
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )

    embedding = nn.Embedding(vocab_size, embed_dim)
    lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
    return TransformerDecoder(
        tok_embeddings=embedding,
        layer=decoder_layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=per_head,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=lm_head,
    )
def llama2_mlp(dim: int, hidden_dim: int) -> FeedForward:
    """
    Build the feed-forward (MLP) layer used by Llama2 models.

    Args:
        dim (int): input and output feature dimension of the MLP.
        hidden_dim (int): output dimension of ``gate_proj`` / ``up_proj`` and
            input dimension of ``down_proj``.

    Returns:
        FeedForward: MLP whose ``gate_proj``, ``down_proj`` and ``up_proj`` are
        bias-free ``nn.Linear`` layers.
    """
    # Dict literal entries are evaluated left to right, so the layers are still
    # created (and RNG-initialized) in gate -> down -> up order.
    projections = {
        "gate_proj": nn.Linear(dim, hidden_dim, bias=False),
        "down_proj": nn.Linear(hidden_dim, dim, bias=False),
        "up_proj": nn.Linear(dim, hidden_dim, bias=False),
    }
    return FeedForward(**projections)
# ------------------ LoRA Llama2 ------------------
def lora_llama2(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    *,
    # llama2 args
    vocab_size: int,
    num_layers: int,
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    max_seq_len: int,
    intermediate_dim: Optional[int] = None,
    attn_dropout: float = 0.0,
    norm_eps: float = 1e-5,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    # Quantization args
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Build a Llama2 ``TransformerDecoder`` with LoRA adapters placed according to
    the given configuration.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): which linear layers of each
            self-attention block get LoRA. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether the MLP of each transformer layer uses
            LoRA projections. Default: False
        apply_lora_to_output (bool): whether the model's final output projection is
            a LoRA layer. Default: False
        vocab_size (int): number of tokens in vocabulary.
        num_layers (int): number of layers in the transformer decoder.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
        embed_dim (int): embedding dimension for self-attention.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        intermediate_dim (Optional[int]): intermediate dimension for MLP. If not specified
            (or falsy), it is computed with ``scale_hidden_dim_for_mlp`` from
            ``torchtune.models.llama2._model_utils``. Default: None
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        norm_eps (float): epsilon in RMS norms. Default: 1e-5
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability, used by every LoRA layer built
            here (attention, MLP and output). Default: 0.0
        quantize_base (bool): whether to quantize base model weights. Only applied to base
            weights within linear layers LoRA is applied to. The final output linear
            projection is not supported for quantization currently. Default: False

    Returns:
        TransformerDecoder: Instantiation of Llama2 model with LoRA applied to
        a subset of the attention projections in each layer (and optionally to the
        MLPs and the output projection).
    """
    attention = lora_llama2_self_attention(
        lora_modules=lora_attn_modules,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        quantize_base=quantize_base,
    )

    hidden_dim = intermediate_dim or scale_hidden_dim_for_mlp(embed_dim)
    if not apply_lora_to_mlp:
        feed_forward = llama2_mlp(dim=embed_dim, hidden_dim=hidden_dim)
    else:
        feed_forward = lora_llama2_mlp(
            dim=embed_dim,
            hidden_dim=hidden_dim,
            lora_rank=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            quantize_base=quantize_base,
        )

    decoder_layer = TransformerDecoderLayer(
        attn=attention,
        mlp=feed_forward,
        sa_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
        mlp_norm=RMSNorm(dim=embed_dim, eps=norm_eps),
    )

    tok_embeddings = nn.Embedding(vocab_size, embed_dim)

    # TODO: quantize_base is not applied to final output_proj currently.
    if apply_lora_to_output:
        output_head = LoRALinear(
            embed_dim, vocab_size, rank=lora_rank, alpha=lora_alpha, dropout=lora_dropout
        )
    else:
        output_head = nn.Linear(embed_dim, vocab_size, bias=False)

    model = TransformerDecoder(
        tok_embeddings=tok_embeddings,
        layer=decoder_layer,
        num_layers=num_layers,
        max_seq_len=max_seq_len,
        num_heads=num_heads,
        head_dim=embed_dim // num_heads,
        norm=RMSNorm(embed_dim, eps=norm_eps),
        output=output_head,
    )

    if not quantize_base:
        return model

    # QLoRA: register a state-dict hook that reparametrizes the quantized (4-bit)
    # base weights to higher precision and offloads them to CPU on the fly, so
    # producing a state dict does not increase peak memory.
    model._register_state_dict_hook(
        partial(
            reparametrize_as_dtype_state_dict_post_hook,
            # TODO: find a cleaner way to determine the precision of the rest of
            # the model; the token embedding dtype is used as a proxy for now.
            dtype=tok_embeddings.weight.dtype,
            offload_to_cpu=True,
        )
    )
    return model
def lora_llama2_self_attention(
    lora_modules: List[LORA_ATTN_MODULES],
    *,
    # CausalSelfAttention args
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    max_seq_len: int,
    attn_dropout: float = 0.0,
    # LoRA args
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> CausalSelfAttention:
    """
    Return an instance of :func:`~torchtune.modules.CausalSelfAttention` with LoRA
    applied to a subset of its linear layers.

    Args:
        lora_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to. Options are ``{"q_proj", "k_proj", "v_proj",
            "output_proj"}``.
        embed_dim (int): embedding dimension for self-attention.
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value.
        num_kv_heads (int): number of key and value heads. User should ensure
            `num_heads` % `num_kv_heads` == 0. For standard MHA set `num_kv_heads` == `num_heads`,
            for GQA `num_kv_heads` < `num_heads`, and for MQA set `num_kv_heads` == 1.
            A falsy value (e.g. 0 or None) falls back to `num_heads`.
        max_seq_len (int): maximum sequence length the model will be run with, as used
            by :func:`~torchtune.modules.KVCache`.
        attn_dropout (float): dropout value passed onto scaled_dot_product_attention.
            Default: 0.0
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): Whether to quantize base model parameters for linear layers
            LoRA is being applied to. Default is ``False``.

    Returns:
        CausalSelfAttention: instantiation of self-attention module with LoRA
        applied to a subset of Q, K, V, output projections.

    Raises:
        ValueError: If ``lora_modules`` is empty, or if it contains a name other than
            ``"q_proj"``, ``"k_proj"``, ``"v_proj"`` or ``"output_proj"``.
    """
    if not lora_modules:
        raise ValueError(
            f"Must pass one or more of {LORA_ATTN_MODULES} as lora_modules"
        )
    # Unknown names used to be ignored silently. A typo such as "q" instead of
    # "q_proj" would then produce attention with no LoRA adapter at all.
    supported = ("q_proj", "k_proj", "v_proj", "output_proj")
    unknown = [name for name in lora_modules if name not in supported]
    if unknown:
        raise ValueError(
            f"Unsupported lora_modules {unknown}; expected a subset of {list(supported)}"
        )

    head_dim = embed_dim // num_heads
    num_kv_heads = num_kv_heads if num_kv_heads else num_heads

    def _projection(name: str, in_dim: int, out_dim: int) -> nn.Module:
        """Return a LoRALinear if ``name`` is in ``lora_modules``, else a bias-free nn.Linear."""
        if name in lora_modules:
            return LoRALinear(
                in_dim,
                out_dim,
                rank=lora_rank,
                alpha=lora_alpha,
                dropout=lora_dropout,
                quantize_base=quantize_base,
            )
        return nn.Linear(in_dim, out_dim, bias=False)

    # Built in q -> k -> v -> output order, matching the previous per-projection
    # construction, so weight initialization consumes the global RNG identically.
    q_proj = _projection("q_proj", embed_dim, num_heads * head_dim)
    k_proj = _projection("k_proj", embed_dim, num_kv_heads * head_dim)
    v_proj = _projection("v_proj", embed_dim, num_kv_heads * head_dim)
    output_proj = _projection("output_proj", embed_dim, embed_dim)

    rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len)
    self_attn = CausalSelfAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=q_proj,
        k_proj=k_proj,
        v_proj=v_proj,
        output_proj=output_proj,
        pos_embeddings=rope,
        kv_cache=None,
        max_seq_len=max_seq_len,
        attn_dropout=attn_dropout,
    )
    return self_attn
def lora_llama2_mlp(
    *,
    dim: int,
    hidden_dim: int,
    lora_rank: int,
    lora_alpha: float,
    lora_dropout: float = 0.0,
    quantize_base: bool = False,
) -> FeedForward:
    """
    Build the Llama2 feed-forward (MLP) layer with LoRA applied to all three projections.

    Args:
        dim (int): input and output feature dimension of the MLP.
        hidden_dim (int): output dimension of ``gate_proj`` / ``up_proj`` and
            input dimension of ``down_proj``.
        lora_rank (int): rank of each low-rank approximation.
        lora_alpha (float): scaling factor for the low-rank approximation.
        lora_dropout (float): LoRA dropout probability. Default: 0.0
        quantize_base (bool): whether to quantize the base weights of each
            LoRALinear. Default: False

    Returns:
        FeedForward: MLP whose ``gate_proj``, ``down_proj`` and ``up_proj`` are
        all ``LoRALinear`` layers sharing the same LoRA settings.
    """
    lora_kwargs = dict(
        rank=lora_rank,
        alpha=lora_alpha,
        dropout=lora_dropout,
        quantize_base=quantize_base,
    )
    # Keyword arguments are evaluated left to right, so the layers are still
    # created in gate -> down -> up order.
    return FeedForward(
        gate_proj=LoRALinear(in_dim=dim, out_dim=hidden_dim, **lora_kwargs),
        down_proj=LoRALinear(in_dim=hidden_dim, out_dim=dim, **lora_kwargs),
        up_proj=LoRALinear(in_dim=dim, out_dim=hidden_dim, **lora_kwargs),
    )