Modularize Llama2 #137
Conversation
Not a complete review, just one comment for thought.
output_proj (nn.Module): projection layer for output.
pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
    If not specified, then no positional embeddings are used.
kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
curious about the tradeoffs here. Before, we'd detect when running for inference and enable kv cache optimization out of the box. Now we're relying on users to pass this in. This could cause some friction when trying to switch btwn training and inference.
cc @ebsmothers on this as well
I may be being dense here, but before didn't we just use a different flag (max_batch_size) to control whether KV caching was enabled?
Before, we needed them to include a max_batch_size for this to be enabled. I think if we pick-and-choose which modules we include by default and which ones we allow to be passed in, it signals inconsistent design principles.
I agree we should strive for a consistent design here. Ideally we have some flag that's relatively clear and indicates whether a component (be it self-attention or full transformer decoder) is in KV caching mode. Either way, idk if we have to tackle it in this PR.
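As a rough illustration of what such a toggle could look like (setup_caches and caches_enabled here are hypothetical names, not existing torchtune APIs): the module stays cache-free by default for training, and inference code opts in explicitly.

```python
from typing import Optional

import torch
from torch import nn


class KVCache(nn.Module):
    """Illustrative pre-allocated key/value cache."""

    def __init__(self, batch_size: int, max_seq_len: int, num_heads: int, head_dim: int) -> None:
        super().__init__()
        self.register_buffer("k", torch.zeros(batch_size, max_seq_len, num_heads, head_dim))
        self.register_buffer("v", torch.zeros(batch_size, max_seq_len, num_heads, head_dim))


class SelfAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.kv_cache: Optional[KVCache] = None  # None => training mode, no caching

    def setup_caches(self, batch_size: int, max_seq_len: int) -> None:
        # Hypothetical switch into KV-caching (inference) mode, called once before generation
        self.kv_cache = KVCache(batch_size, max_seq_len, self.num_heads, self.head_dim)

    @property
    def caches_enabled(self) -> bool:
        # Single explicit flag for forward() to branch on
        return self.kv_cache is not None
```

Training code would never touch setup_caches; a generation script would call it once before decoding and the same forward() could branch on caches_enabled.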
@@ -13,8 +13,8 @@
import torch

from torchtune.models.llama2.tokenizer import Tokenizer
from torchtune.models.llama2.transformer import TransformerDecoder
from torchtune.modules.tokenizer import Tokenizer
I know it's not the main point of this PR, but wonder if tokenizers could/should go in something like torchtune/transforms instead
I'm not opposed - good for a follow-up.
torchtune/models/__init__.py
Outdated
@@ -9,7 +9,7 @@
import torch
from torch.nn import Module

from torchtune.models.llama2.models import llama2_7b, llama2_tokenizer
from .llama2 import llama2_7b, llama2_tokenizer, small_test_ckpt  # noqa
Why relative import? And why do we need to import small_test_ckpt here?
This is pretty standard for at least __init__.py files. See https://github.com/pytorch/pytorch/blob/main/torch/optim/__init__.py for a single example, but there are more.
small_test_ckpt was used in tests. I can move it out of here to private if desired.
Let's stick w/absolute imports IMO. They're much cleaner and should probably be a design principle / coding best practice if it's not already.
For everything except __init__.py files, I absolutely agree. I think the assumption is that the __init__.py for a specific submodule will ALWAYS be in the same place. That way, any changes to higher-level folder structure won't mess up these simple imports.
Yeah I am OK with the relative imports; I think they are often used to define the importable modules from a given package. I see you defined __all__ in the other __init__.py file, so this makes sense to me. Btw to scrap the noqa tags you can try this:
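One way this typically works (an illustration, since the original example isn't shown here): flake8/pyflakes treats names listed in __all__ as used, so re-exported imports in an __init__.py no longer need per-line noqa tags.

```python
# torchtune/modules/__init__.py (sketch)
from .position_embeddings import RotaryPositionalEmbeddings
from .rms_norm import RMSNorm
from .tokenizer import Tokenizer
from .transformer import TransformerDecoder, TransformerDecoderLayer

# Listing the names here marks the imports above as intentional re-exports,
# so F401 ("imported but unused") is not raised and the noqa comments can go.
__all__ = [
    "RotaryPositionalEmbeddings",
    "RMSNorm",
    "Tokenizer",
    "TransformerDecoder",
    "TransformerDecoderLayer",
]
```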
torchtune/models/llama2.py
Outdated
)


def small_test_ckpt(vocab_size: int) -> TransformerDecoder:
vocab_size not used
This has been moved to unit tests in main
torchtune/models/llama2.py
Outdated
return tokenizer


class Llama2FeedForward(nn.Module):
This is one case where I do not like the usage of an extra class wrapping FeedForward. This is basically just a passthrough for everything but activation, I don't think we should be nesting modules here
nit on name, technically all layer types are feedforward if they're not recurrent. Isn't this either a Linear layer or an MLP?
torchtune/models/llama2.py
Outdated
    (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=curr_pos + 1)
return self.model(tokens, mask, curr_pos)
So once we get all the way up to the top level model, if I wanna get an individual MLP block I will need to do self.model.layers[0].layer.mlp.ff, right? Imo this is unintuitive.
This is a consequence of utilizing classes as opposed to builder functions. Not sure if this discussion was officially closed. I see the confusion here in accessing the class through this method.
torchtune/models/llama2.py
Outdated
if seq_len > 1 and self.max_batch_size is not None:
    mask = torch.full(
        (1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
    )
    mask = torch.triu(mask, diagonal=curr_pos + 1)
Can we offload mask construction to a utility or something?
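A sketch of such a utility (the name get_causal_mask is illustrative), pulling the logic from the snippet above into one place:

```python
from typing import Optional

import torch


def get_causal_mask(
    seq_len: int, curr_pos: int = 0, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Additive causal attention mask: positions above the (offset) diagonal get -inf."""
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
    return torch.triu(mask, diagonal=curr_pos + 1)
```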
torchtune/modules/attention.py
Outdated
@@ -198,17 +180,19 @@ def forward(
k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)

# Apply RoPE embeddings
# if self.pos_embeddings is not None:
Why is this commented out? I think in init you have pos_embeddings as optional
Will be fixed.
Bump
@@ -13,8 +13,8 @@
Is it worth making this a recipe? It could be tested then and take advantage of the cli
I'm still not 100% clear on the purpose of this file tbh. As is it reads like a parity check across inference with no KV caching, inference with KV caching, and the HF version of the model. If that's the case we should not make it a recipe. But if we drop the transformers dep and make the KV caching configurable I agree this would make a nice recipe
This file should be deleted eventually. Before MVP
torchtune/models/llama2.py
Outdated
super().__init__()
self.max_batch_size = max_batch_size
token_embeddings = nn.Embedding(vocab_size, embed_dim)
layer = Llama2DecoderLayer(
If this layer type was passed in, it would make this code more modular. A user could write a custom Llama decoder layer and pass that in.
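A sketch of that idea (names are illustrative): the builder takes a layer factory instead of hard-coding Llama2DecoderLayer, so a custom decoder layer only changes the call site.

```python
from typing import Callable

from torch import nn


def build_layers(num_layers: int, layer_factory: Callable[[], nn.Module]) -> nn.ModuleList:
    # Each factory call produces an independently initialized layer (no shared weights)
    return nn.ModuleList([layer_factory() for _ in range(num_layers)])


# Usage with a hypothetical custom layer:
# layers = build_layers(num_layers=32, layer_factory=lambda: MyLlamaDecoderLayer(...))
```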
self.w1 = linear_class(dim, hidden_dim, bias=False)
self.w2 = linear_class(hidden_dim, dim, bias=False)
self.w3 = linear_class(dim, hidden_dim, bias=False)
Need to change these (I think this is what you were mentioning earlier?)
Bumping this comment
Does this not work? Would we want to do more than just allow a different linear class?? What interface does LoRALinear support?
I would pass three different nn.Modules instead of dim and hidden_dim. Also the way you're passing linear_class rn it is not an nn.Module, it is just a type (since you have not actually initialized it outside of FeedForward).
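A sketch of that suggestion (not the code in this PR, just an illustration): the three projections are constructed by the caller and passed in as modules, so swapping in something like a LoRALinear, or anything else with a Linear-like forward, only changes the call site.

```python
import torch
from torch import nn


class FeedForward(nn.Module):
    """Gated feed-forward block that receives its projection layers pre-constructed."""

    def __init__(self, w1: nn.Module, w2: nn.Module, w3: nn.Module, activation: nn.Module) -> None:
        super().__init__()
        self.w1, self.w2, self.w3 = w1, w2, w3
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: down(act(gate(x)) * up(x))
        return self.w2(self.activation(self.w1(x)) * self.w3(x))


# The caller decides what the projections are (plain nn.Linear here):
dim, hidden_dim = 4096, 11008
ff = FeedForward(
    w1=nn.Linear(dim, hidden_dim, bias=False),
    w2=nn.Linear(hidden_dim, dim, bias=False),
    w3=nn.Linear(dim, hidden_dim, bias=False),
    activation=nn.SiLU(),
)
```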
@@ -91,37 +99,19 @@ def __init__(
if attn_dropout < 0 or attn_dropout > 1:
    raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0")
It might be nice to offload all these checks to a self._validate_parameters method or something to keep the init clean.
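For example (a sketch; the specific checks shown are representative, and note that the quoted message above interpolates embed_dim where attn_dropout was presumably intended):

```python
from torch import nn


class LlamaSelfAttention(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, num_kv_heads: int, attn_dropout: float = 0.0
    ) -> None:
        super().__init__()
        self._validate_parameters(embed_dim, num_heads, num_kv_heads, attn_dropout)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.attn_dropout = attn_dropout

    @staticmethod
    def _validate_parameters(
        embed_dim: int, num_heads: int, num_kv_heads: int, attn_dropout: float
    ) -> None:
        # All input validation lives in one place so __init__ stays readable
        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")
        if embed_dim % num_heads != 0:
            raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
        if num_heads % num_kv_heads != 0:
            raise ValueError(f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
```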
layer: TransformerDecoderLayer,
num_layers: int,
norm: nn.Module,
output: nn.Linear,
Could also just be generic nn.Module (e.g. if someone wants to use an MLP instead of a single linear). Not a huge deal either way though
# shape: [b, s, d]
h = self.tok_embeddings(tokens)

if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
In general I would try to avoid accessing lower-level components' attributes as much as possible. I know you've typed things all the way down so this isn't gonna directly break anything, but still..
This sucks and should be fixed; will do in a follow-up PR.
Can you add a todo?
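One possible shape for that follow-up (hypothetical method name): the decoder owns a small helper so forward() never reaches into a layer's attention module directly.

```python
from torch import nn


class TransformerDecoder(nn.Module):
    def __init__(self, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers

    def caches_are_enabled(self) -> bool:
        # TODO (from review): replace the direct self.layers[0].attn.kv_cache access
        # in forward() with this method (or a flag set at construction time)
        return any(getattr(layer.attn, "kv_cache", None) is not None for layer in self.layers)
```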
torchtune/models/llama2.py
Outdated
num_kv_heads=32,
embed_dim=4096,
max_seq_len=2048,
max_batch_size=32,  # Need to figure out the actual default used by Llama2
Should depend on the user's setup, right? (I.e. how much memory they have) Also does this mean that the default here is for inference mode? (Not that there's anything wrong with that, but we may want to distinguish this somehow)
from .position_embeddings import RotaryPositionalEmbeddings  # noqa
from .rms_norm import RMSNorm  # noqa
from .tokenizer import Tokenizer  # noqa
from .transformer import TransformerDecoder, TransformerDecoderLayer  # noqa
[no need to address in this PR!]
As an illustration of what I'm advocating for in #25 (comment), these imports in this __init__.py file would look like:
from ._tokenizer import Tokenizer  # noqa
from ._transformer import TransformerDecoder, TransformerDecoderLayer  # noqa
And everything else stays unchanged. This allows us to write whatever we want within _tokenizer.py or _transformer.py without worrying about whether what we're writing should be public or private.
torchtune/models/llama2.py
Outdated
hidden_dim = 4 * int(2 * embed_dim / 3)
# Round hidden dimension to nearest multiple of `multiple_of`
multiple_of = 256
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
This is one case where I think a feedforward-specific builder will be useful. Otherwise we are repeating this logic a lot (or using magic #s as in the transformer decoder test). Why not let FeedForward take three nn.Modules + activation as args, then provide a single feedforward builder that takes hidden_dim and embed_dim and does all the math?
Done!
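A sketch of what such a builder could look like (name illustrative), reusing the hidden-dim arithmetic quoted above together with the module-injection FeedForward sketched earlier in this thread:

```python
from torch import nn


def llama2_feed_forward(embed_dim: int, multiple_of: int = 256) -> "FeedForward":
    # Llama2 sizes the hidden dim at 2/3 of 4 * embed_dim, rounded up to a
    # multiple of `multiple_of`; e.g. embed_dim=4096 -> hidden_dim=11008
    hidden_dim = 4 * int(2 * embed_dim / 3)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    # FeedForward here refers to the module-injection sketch shown earlier
    return FeedForward(
        w1=nn.Linear(embed_dim, hidden_dim, bias=False),
        w2=nn.Linear(hidden_dim, embed_dim, bias=False),
        w3=nn.Linear(embed_dim, hidden_dim, bias=False),
        activation=nn.SiLU(),
    )
```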
TorchTune seeks to leverage the power of native PyTorch including PyTorch-native distributed APIs such as FSDP and Tensor Parallelism. To train Llama-2 models using TorchTune, checkpoints must be converted as a result.
"checkpoints must be converted as a result."
converted from what to what?
This will be remedied in a follow-up PR.
A few more comments but overall looks good. There are still a bunch of open comments so please make sure to address those (I explicitly bumped the couple most important ones imo). Thanks for your diligence in working through all these changes and addressing everything! Accepting now so you're not blocked on me
torchtune/models/llama2.py
Outdated
num_kv_heads = num_kv_heads if num_kv_heads else num_heads
qkv_dim = (num_heads + 2 * num_kv_heads) * head_dim
layers = nn.ModuleList()
for _ in range(num_layers):
(Not blocking for this PR) There are gonna be a lot of similar for loops if we aren't using deepcopy, which I don't love. Short-term: maybe some simple utility to slightly reduce boilerplate? Longer-term I would really like to figure out a better solution though.
Fixed back to _get_clones.
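For reference, the _get_clones pattern (the same idea as the private helper in torch.nn.modules.transformer) deep-copies a template module so every layer gets its own parameter storage:

```python
import copy

from torch import nn


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    # deepcopy gives each clone distinct parameters (different data_ptrs),
    # unlike appending the same instance n times
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


# layers = _get_clones(Llama2DecoderLayer(...), num_layers)
```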
@@ -202,11 +201,21 @@ def compare_attention(
attn_out_ref = attn_ref(input_t, freq_cis, mask)

# current implementation; initialize with constant to compare outputs
attn = LlamaSelfAttention(
head_dim = embed_dim // num_heads
Not to be too much of a broken record, but this block (plus L122-152 of compare_decoder_layer.py, plus several blocks of test_attention.py, plus L69-94 of test_transformer_decoder.py) is why I would like to see builder functions for intermediate components as well.
Co-authored-by: Evan Smothers <ebs@fb.com> Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
attn_dropout: float = 0.0,
max_batch_size: Optional[int] = None,
norm_eps: float = 1e-6,
):
Missing docstring
Summary
Based on RFC principles outlined in #102, llama2 (and derived models) can be built using principles of composability and shared modules.
Changelog
- Added shared components to the modules/ __init__.py so as to be importable from the modules/ dir
- Built llama2 components based on these modules in a single file
- pytorch-multimodal/llama2-7b-01052024
Testing
- pytest tests/
- pytest recipes/
Notes
- The llama2 builder function constructs a TransformerDecoder model from all classes, rather than a bunch of builder functions. This is the hybrid approach mentioned by @ebsmothers and @kartikayk and requires more feedback if this is the direction we want to go in.
- A single layer instance can't simply be reused inside TransformerDecoder b/c it initializes to the same data_ptr. Therefore, a big change is that the for loop takes place in the llama2 builder function. Not sure if this is ideal. Curious to hear thoughts.
Docs