
Modularize Llama2 #137

Merged: 41 commits merged from modular-llama2 into main on Jan 10, 2024
Conversation

joecummings (Contributor) commented Dec 28, 2023

Summary

Based on the RFC principles outlined in #102, Llama2 (and derived models) can now be built from composable, shared modules.

Changelog

  • Separated out llama2/*.py into modules/
    • Added these modules to __init__.py so they are importable from the modules/ directory
    • Updated the README to correct the import path
  • Created Llama2 components from these modules in a single file
  • Updated all tests to exercise these components through the Llama2 interface
  • Added modules and models to the docs
  • Converted weights from the old format to the new format and uploaded them to S3 at pytorch-multimodal/llama2-7b-01052024

Testing

  1. pytest tests/
  2. pytest recipes/

Notes

  1. This is a major redesign of the Llama2 model. Note that the llama2 builder function constructs a TransformerDecoder directly from the underlying classes rather than from a set of builder functions. This is the hybrid approach mentioned by @ebsmothers and @kartikayk, and it needs more feedback on whether this is the direction we want to go in.
  2. It is not possible to pass a single layer instance together with a number of layers to TransformerDecoder, because every copy would point at the same data_ptr (i.e. share the same parameters). As a result, the for loop over layers now lives in the llama2 builder function (see the sketch below). Not sure if this is ideal; curious to hear thoughts.
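As an illustration of the data_ptr issue, here is a minimal sketch using a generic nn.Linear in place of the real decoder layer:

import copy

import torch.nn as nn

# Reusing one layer instance N times shares parameters: every "copy" points
# at the same storage, so the stack is effectively a single layer repeated.
shared = nn.Linear(16, 16)
tied = nn.ModuleList([shared for _ in range(4)])
assert tied[0].weight.data_ptr() == tied[1].weight.data_ptr()

# Constructing (or deep-copying) a fresh layer per iteration gives each layer
# its own parameters, which is why the for loop lives in the llama2 builder.
independent = nn.ModuleList([copy.deepcopy(shared) for _ in range(4)])
assert independent[0].weight.data_ptr() != independent[1].weight.data_ptr()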

Docs

(Screenshot of the rendered docs page, 2024-01-08, 1:50 PM.)

netlify bot commented Dec 28, 2023

Deploy Preview for torchtune-preview ready!

🔨 Latest commit: 5d37cd9
🔍 Latest deploy log: https://app.netlify.com/sites/torchtune-preview/deploys/659f2b6488189f000872ed7f
😎 Deploy Preview: https://deploy-preview-137--torchtune-preview.netlify.app

@facebook-github-bot added the "CLA Signed" label Dec 28, 2023
@joecummings changed the title from "modular llama2" to "Modularize Llama2" Dec 28, 2023
@joecummings marked this pull request as ready for review December 28, 2023 21:38
@rohan-varma (Member) left a comment

Not a complete review, just one comment for thought.

output_proj (nn.Module): projection layer for output.
pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g. RotaryPositionalEmbeddings.
If not specified, then no positional embeddings are used.
kv_cache (Optional[KVCache]): KVCache object used to cache key and value.
Member:

Curious about the tradeoffs here. Before, we'd detect when running inference and enable the KV-cache optimization out of the box. Now we're relying on users to pass this in, which could cause some friction when switching between training and inference.

Member:

cc @ebsmothers on this as well

Contributor:

I may be being dense here, but before didn't we just use a different flag (max_batch_size) to control whether KV caching was enabled?

Contributor Author:

Before, users needed to pass a max_batch_size for this to be included. I think if we pick and choose which modules we include by default and which ones we allow to be passed in, it signals inconsistent design principles.

Contributor:

I agree we should strive for a consistent design here. Ideally we'd have a relatively clear flag that indicates whether a component (be it self-attention or the full transformer decoder) is in KV-caching mode. Either way, I don't think we have to tackle it in this PR.
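As a rough sketch of what an explicit toggle could look like (the names CachedSelfAttention, enable_kv_cache, and cache_enabled are hypothetical, not anything in this PR):

from typing import Optional

import torch
import torch.nn as nn


class CachedSelfAttention(nn.Module):
    # Toy stub: only the cache-toggling logic is of interest here.
    def __init__(self, embed_dim: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.kv_cache: Optional[dict] = None

    def enable_kv_cache(self, batch_size: int, max_seq_len: int) -> None:
        # Allocate cache buffers explicitly when switching to inference.
        self.kv_cache = {
            "k": torch.zeros(batch_size, max_seq_len, self.embed_dim),
            "v": torch.zeros(batch_size, max_seq_len, self.embed_dim),
        }

    def disable_kv_cache(self) -> None:
        # Drop the buffers when going back to training.
        self.kv_cache = None

    @property
    def cache_enabled(self) -> bool:
        return self.kv_cache is not None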

@rohan-varma rohan-varma self-requested a review January 2, 2024 08:28
@@ -13,8 +13,8 @@

import torch

from torchtune.models.llama2.tokenizer import Tokenizer
from torchtune.models.llama2.transformer import TransformerDecoder
from torchtune.modules.tokenizer import Tokenizer
Contributor:

I know it's not the main point of this PR, but I wonder if tokenizers could/should live in something like torchtune/transforms instead.

Contributor Author:

I'm not opposed - good for a follow-up.

@@ -9,7 +9,7 @@
import torch
from torch.nn import Module

from torchtune.models.llama2.models import llama2_7b, llama2_tokenizer
from .llama2 import llama2_7b, llama2_tokenizer, small_test_ckpt # noqa
Contributor:

Why relative import? And why do we need to import small_test_ckpt here?

Contributor Author:

This is pretty standard, at least for __init__.py files. See https://github.com/pytorch/pytorch/blob/main/torch/optim/__init__.py for one example; there are more.

Contributor Author:

small_test_ckpt was used in tests. I can move it out of here or make it private if desired.

Member:

Let's stick with absolute imports IMO. They're much cleaner and should probably be a design principle / coding best practice if they're not already.

Contributor Author:

For everything except __init__.py files, I absolutely agree. I think the assumption is that the __init__.py for a specific submodule will ALWAYS be in the same place. That way, any changes to higher-level folder structure won't mess up these simple imports.

Contributor:

Yeah I am OK with the relative imports; I think they are often used to define the importable modules from a given package. I see you defined __all__ in the other __init__.py file, so this makes sense to me. Btw to scrap the noqa tags you can try this
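For reference, listing re-exports in __all__ is the usual way to drop the # noqa tags, since pyflakes treats names that appear in __all__ as used. A sketch, with module names taken from the imports quoted later in this thread:

from .position_embeddings import RotaryPositionalEmbeddings
from .rms_norm import RMSNorm
from .tokenizer import Tokenizer
from .transformer import TransformerDecoder, TransformerDecoderLayer

__all__ = [
    "RotaryPositionalEmbeddings",
    "RMSNorm",
    "Tokenizer",
    "TransformerDecoder",
    "TransformerDecoderLayer",
]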

)


def small_test_ckpt(vocab_size: int) -> TransformerDecoder:
Contributor:

vocab_size not used

Contributor:

This has been moved to unit tests in main

torchtune/modules/__init__.py (thread resolved)
return tokenizer


class Llama2FeedForward(nn.Module):
Contributor:

This is one case where I don't like the use of an extra class wrapping FeedForward. It's basically a passthrough for everything except the activation; I don't think we should be nesting modules here.

Contributor:

Nit on the name: technically all layer types are feedforward if they're not recurrent. Isn't this either a Linear layer or an MLP?

torchtune/modules/transformer.py (thread resolved)
(1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=curr_pos + 1)
return self.model(tokens, mask, curr_pos)
Contributor:

So once we get all the way up to the top-level model, if I want to get an individual MLP block I will need to do self.model.layers[0].layer.mlp.ff, right? IMO this is unintuitive.

Contributor Author:

This is a consequence of using classes as opposed to builder functions. Not sure that discussion was officially closed. I do see how accessing the module this way is confusing.

Comment on lines 167 to 171
if seq_len > 1 and self.max_batch_size is not None:
mask = torch.full(
(1, 1, seq_len, seq_len), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=curr_pos + 1)
Contributor:

Can we offload mask construction to a utility or something?
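A small sketch of such a utility, lifted directly from the quoted lines (the name get_causal_mask is made up for illustration):

import torch


def get_causal_mask(seq_len: int, curr_pos: int, device: torch.device) -> torch.Tensor:
    # Additive causal mask: -inf above the diagonal offset by curr_pos, 0 elsewhere.
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), device=device)
    return torch.triu(mask, diagonal=curr_pos + 1)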

@@ -198,17 +180,19 @@ def forward(
k = k.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)
v = v.expand(bsz, seq_len, self.num_kv_heads, q_per_kv, self.head_dim)

# Apply RoPE embeddings
# if self.pos_embeddings is not None:
Contributor:

Why is this commented out? I think in __init__ you have pos_embeddings as Optional.

Contributor Author:

Will be fixed.

Contributor:

Bump

@@ -13,8 +13,8 @@

Contributor:

Is it worth making this a recipe? It could be tested then, and it could take advantage of the CLI.

Contributor:

I'm still not 100% clear on the purpose of this file, to be honest. As is, it reads like a parity check across inference with no KV caching, inference with KV caching, and the HF version of the model. If that's the case, we should not make it a recipe. But if we drop the transformers dependency and make the KV caching configurable, I agree this would make a nice recipe.

Contributor Author:

This file should be deleted eventually, before MVP.

super().__init__()
self.max_batch_size = max_batch_size
token_embeddings = nn.Embedding(vocab_size, embed_dim)
layer = Llama2DecoderLayer(
Contributor:

If this layer type were passed in, it would make this code more modular: a user could write a custom Llama decoder layer and pass it in.

Comment on lines +28 to +29
self.w1 = linear_class(dim, hidden_dim, bias=False)
self.w2 = linear_class(hidden_dim, dim, bias=False)
self.w3 = linear_class(dim, hidden_dim, bias=False)
Contributor:

Need to change these (I think this is what you were mentioning earlier?)

Contributor:

Bumping this comment

Contributor Author:

Does this not work? Would we want to do more than just allow a different linear class? What interface does LoRALinear support?

Contributor:

I would pass three different nn.Modules instead of dim and hidden_dim. Also, the way you're passing linear_class right now, it is not an nn.Module, it is just a type (since you have not actually initialized it outside of FeedForward).
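A sketch of the suggested signature, keeping the w1/w2/w3 names from the diff; the forward assumes the usual Llama2 SwiGLU structure, which is not shown in the quoted lines:

import torch
import torch.nn as nn


class FeedForward(nn.Module):
    # Owns no sizing logic: callers pass in already-constructed projections.
    def __init__(
        self,
        gate_proj: nn.Module,
        down_proj: nn.Module,
        up_proj: nn.Module,
        activation: nn.Module,
    ) -> None:
        super().__init__()
        self.w1 = gate_proj
        self.w2 = down_proj
        self.w3 = up_proj
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(self.activation(self.w1(x)) * self.w3(x))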

@@ -91,37 +99,19 @@ def __init__(
if attn_dropout < 0 or attn_dropout > 1:
raise ValueError(f"attn_dropout ({embed_dim}) must be between 0.0 and 1.0")
Contributor:

It might be nice to offload all these checks to a self._validate_parameters method or something to keep __init__ clean.
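For example, a helper along these lines (the name and the divisibility check are assumptions; only the attn_dropout check appears in the diff):

def _validate_attention_args(embed_dim: int, num_heads: int, attn_dropout: float) -> None:
    # Gather the constructor sanity checks in one place so __init__ stays readable.
    if attn_dropout < 0 or attn_dropout > 1:
        raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0")
    if embed_dim % num_heads != 0:
        raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")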

layer: TransformerDecoderLayer,
num_layers: int,
norm: nn.Module,
output: nn.Linear,
Contributor:

This could also just be a generic nn.Module (e.g. if someone wants to use an MLP instead of a single linear layer). Not a huge deal either way, though.

# shape: [b, s, d]
h = self.tok_embeddings(tokens)

if seq_len > 1 and self.layers[0].attn.kv_cache is not None:
@ebsmothers (Contributor) commented Jan 3, 2024

In general I would try to avoid accessing lower-level components' attributes as much as possible. I know you've typed things all the way down so this isn't going to directly break anything, but still...
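One way to avoid this at call sites would be a small property on TransformerDecoder that encapsulates the check (a hypothetical sketch, not part of this PR):

@property
def kv_caching_enabled(self) -> bool:
    # Call sites ask the decoder instead of reaching into self.layers[0].attn.
    return any(layer.attn.kv_cache is not None for layer in self.layers)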

Contributor Author:

This sucks and should be fixed; will do so in a follow-up PR.

Contributor:

Can you add a TODO?

num_kv_heads=32,
embed_dim=4096,
max_seq_len=2048,
max_batch_size=32, # Need to figure out the actual default used by Llama2
Contributor:

This should depend on the user's setup, right (i.e. how much memory they have)? Also, does this mean the default here is for inference mode? Not that there's anything wrong with that, but we may want to distinguish it somehow.

from .position_embeddings import RotaryPositionalEmbeddings # noqa
from .rms_norm import RMSNorm # noqa
from .tokenizer import Tokenizer # noqa
from .transformer import TransformerDecoder, TransformerDecoderLayer # noqa
@NicolasHug (Member) commented Jan 4, 2024

[no need to address in this PR!]

As an illustration of what I'm advocating for in #25 (comment), these imports in this __init__.py file would look like:

from ._tokenizer import Tokenizer  # noqa
from ._transformer import TransformerDecoder, TransformerDecoderLayer  # noqa

And everything else stays unchanged.

This allows us to write whatever we want within _tokenizer.py or _transformer.py without worrying about whether what we're writing should be public or private.

scripts/llama2_checkpoint/convert_llama2_to_native.py (two outdated threads, resolved)
Comment on lines 87 to 90
hidden_dim = 4 * int(2 * embed_dim / 3)
# Round hidden dimension to nearest multiple of `multiple_of`
multiple_of = 256
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
Contributor:

This is one case where I think a feedforward-specific builder will be useful. Otherwise we are repeating this logic a lot (or using magic numbers, as in the transformer decoder test). Why not let FeedForward take three nn.Modules plus an activation as args, then provide a single feedforward builder that takes hidden_dim and embed_dim and does all the math?
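A sketch of such a builder using the rounding math from the quoted lines; the name llama2_mlp and the FeedForward(gate_proj=..., ...) signature echo the refactor sketched earlier in this thread and are not the final API:

import torch.nn as nn


def llama2_mlp(embed_dim: int, multiple_of: int = 256) -> "FeedForward":
    # Llama2 sizes the hidden dim at 4 * (2/3) * embed_dim, rounded up to a
    # multiple of `multiple_of`, so callers never repeat this arithmetic.
    hidden_dim = 4 * int(2 * embed_dim / 3)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    # Assumes the FeedForward class from the earlier sketch is importable here.
    return FeedForward(
        gate_proj=nn.Linear(embed_dim, hidden_dim, bias=False),
        down_proj=nn.Linear(hidden_dim, embed_dim, bias=False),
        up_proj=nn.Linear(embed_dim, hidden_dim, bias=False),
        activation=nn.SiLU(),
    )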

Contributor Author:

Done!

TorchTune seeks to leverage the power of native PyTorch including PyTorch-native distributed APIs such as FSDP and Tensor Parallelism. To train Llama-2 models using TorchTune, checkpoints must be
converted as a result.

TorchTune seeks to leverage the power of native PyTorch including PyTorch-native distributed APIs such as FSDP and Tensor Parallelism. To train Llama-2 models using TorchTune, checkpoints must be converted as a result.
Contributor:

checkpoints must be converted as a result.

converted from what to what?

Contributor Author:

This will be remedied in a follow-up PR.

@ebsmothers (Contributor) left a comment

A few more comments, but overall this looks good. There are still a bunch of open comments, so please make sure to address those (I explicitly bumped the couple most important ones, IMO). Thanks for your diligence in working through all these changes and addressing everything! Accepting now so you're not blocked on me.

num_kv_heads = num_kv_heads if num_kv_heads else num_heads
qkv_dim = (num_heads + 2 * num_kv_heads) * head_dim
layers = nn.ModuleList()
for _ in range(num_layers):
Contributor:

(Not blocking for this PR.) There are going to be a lot of similar for loops if we aren't using deepcopy, which I don't love. Short-term: maybe some simple utility to slightly reduce boilerplate? Longer-term I would really like to figure out a better solution, though.

Contributor Author:

Fixed back to _get_clones.
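For reference, _get_clones in torch.nn.modules.transformer is just a deepcopy loop; a torchtune equivalent would presumably look the same:

import copy

import torch.nn as nn


def _get_clones(module: nn.Module, n: int) -> nn.ModuleList:
    # n independent copies: deepcopy duplicates parameters rather than sharing them.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])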

torchtune/modules/position_embeddings.py (thread resolved)
torchtune/modules/transformer.py (outdated; thread resolved)

@@ -202,11 +201,21 @@ def compare_attention(
attn_out_ref = attn_ref(input_t, freq_cis, mask)

# current implementation; initialize with constant to compare outputs
attn = LlamaSelfAttention(
head_dim = embed_dim // num_heads
Contributor:

Not to be too much of a broken record, but this block (plus L122-152 of compare_decoder_layer.py, plus several blocks of test_attention.py, plus L69-94 of test_transformer_decoder.py) is why I would like to see builder functions for intermediate components as well.

@joecummings force-pushed the modular-llama2 branch 3 times, most recently from 7799d50 to cd06fd8, on January 10, 2024 23:02
@joecummings merged commit 86c2318 into main on Jan 10, 2024
15 checks passed
joecummings added a commit that referenced this pull request Jan 11, 2024
Co-authored-by: Evan Smothers <ebs@fb.com>
Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com>
attn_dropout: float = 0.0,
max_batch_size: Optional[int] = None,
norm_eps: float = 1e-6,
):
Contributor:

Missing docstring
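A possible docstring for the visible parameters; descriptions are inferred from how these arguments are discussed elsewhere in the PR, not taken from the final code:

"""
Args:
    attn_dropout (float): dropout probability applied to the attention weights.
        Default: 0.0 (no dropout).
    max_batch_size (Optional[int]): maximum batch size used to size the KV cache;
        if None, no KV cache is allocated. Default: None.
    norm_eps (float): epsilon used in RMSNorm for numerical stability. Default: 1e-6.
"""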
