
Gemma #630

Merged: 73 commits merged into pytorch:main on Apr 4, 2024

Conversation

solitude-alive
Contributor

Context

  • To support full fine-tune with Gemma-2B model

Changelog

  • Create gemma module, update for loading with .safetensors, support tied weight model.

Test plan

  • ....


pytorch-bot bot commented Apr 1, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/630

Note: Links to docs will display an error until the docs builds have been completed.

❌ 8 New Failures

As of commit 321f59e with merge base aacaadd:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Apr 1, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@solitude-alive solitude-alive reopened this Apr 1, 2024
Contributor

@joecummings joecummings left a comment

This is great - thanks for the contribution! I left a couple comments, but generally looks good.

Can you add a screenshot of running a distributed full finetune with Gemma to confirm it works?

return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation)


def lora_gemma(
Contributor

Looks like you included the LoRA version of Gemma for this PR. Are you planning on including LoRA, as well, or just starting with the full fine-tuning version?

Contributor Author

The LoRA version of Gemma may take longer to complete. Could I submit a PR with just the full fine-tuning version first?

Contributor

I think it should be fine to start with just full fine-tune for now

Contributor

@solitude-alive Can you remove all the LoRA code since we won't be addressing it in this PR?

Contributor Author

Yeah, I removed them in the latest version.

):
super().__init__()
self.w1 = gate_proj
self.w2 = down_proj
self.w3 = up_proj
self.activation = F.silu
Contributor

Good abstraction!
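For context, this constructor wires up the standard gated-MLP pattern (w1=gate_proj, w2=down_proj, w3=up_proj plus an activation). A minimal sketch of the implied forward pass, assuming the forward method (not shown in this diff) follows that pattern:

import torch
import torch.nn as nn

class FeedForwardSketch(nn.Module):
    # Hypothetical stand-in mirroring the constructor quoted above.
    def __init__(self, gate_proj: nn.Module, down_proj: nn.Module, up_proj: nn.Module, activation=nn.SiLU()):
        super().__init__()
        self.w1, self.w2, self.w3 = gate_proj, down_proj, up_proj
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gated MLP: down_proj(activation(gate_proj(x)) * up_proj(x))
        return self.w2(self.activation(self.w1(x)) * self.w3(x))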

@@ -11,14 +11,17 @@
import torch
import torch.nn as nn
import torch.optim as optim
from safetensors import safe_open
Contributor

Can you add this to requirements.txt?

Contributor Author

Yeah, thank you for your suggestion; I added it in the latest version.

TransformerDecoder: Instantiation of Gemma 2B model
"""
return gemma(
vocab_size=256_000,
Contributor

Still shocked by this vocab size - so large!

Contributor Author

Yeah, 😂

Contributor

Does this mean the embedding(/output projection since they're tied) constitutes a full 25% of their params?!

Contributor Author

I calculated it with count_trainable_parameters; the embed_tokens parameters account for about 21%.

def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
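As a rough cross-check of that figure (a back-of-the-envelope sketch; the ~2.5B total parameter count for Gemma 2B is an approximation, not taken from this PR):

embed_params = 256_000 * 2048       # vocab_size * embed_dim = 524,288,000
total_params = 2.5e9                # approximate Gemma 2B total (assumption)
print(embed_params / total_params)  # ~0.21, consistent with the 21% quoted above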

@@ -203,6 +210,7 @@ def _setup_model(
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
model_state_dict: Dict[str, Any],
mode_tie: bool = False,
Contributor

Suggested change
mode_tie: bool = False,
model_tie: bool = False,

Contributor Author

Thank you for your suggestion, I have fixed it in the latest version.

@@ -259,6 +267,10 @@ def _setup_model(
),
)

if mode_tie: # Tie the weights of the model if required
Contributor

Suggested change
if mode_tie: # Tie the weights of the model if required
if model_tie: # Tie the weights of the model if required

Contributor Author

Thank you for your suggestion, I have fixed it in the latest version.

@solitude-alive
Contributor Author

Yeah, this is a screenshot of running a distributed full finetune with Gemma.
[Screenshot: 2024-04-02 at 8:16:34 AM]

Contributor

@ebsmothers ebsmothers left a comment

Thanks for this PR! Really excited to see how nicely this is shaping up.

Re testing, aside from making sure training runs, let's try to get a sanity check that the model forward here lines up with the one from the original implementation on some dummy data (assuming you haven't done so already). We have a bunch of scripts we've used in the past for this with various components in the library, so you can use these as a reference if it helps. For example (note: you do not have to actually write a script like this and check it in; this is meant more as a reference if it helps you).

# --config gemma/2B_full \
# checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
Contributor

nit: I think if we are running with full_finetune_distributed recipe it will only work on 2+ GPUs

Contributor Author

Yeah, I updated it in the latest version.

recipes/configs/gemma/2B_full.yaml (outdated; resolved)

def gemma_2b() -> TransformerDecoder:
"""
Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values
Contributor

nit: add pointer to the paper or blog post here

Contributor Author

Yeah, I added it in the latest version.


from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune.utils._distributed import contains_fsdp
from transformers.utils import is_safetensors_available
Contributor

I think we shouldn't import from transformers here as it's not in our core dependencies. If you've added safetensors to our core dependencies (based on the above comment) probably don't need to do this check anyways.

Contributor Author

Yeah, I removed it in the latest version.

@@ -25,12 +26,13 @@ def __init__(
gate_proj: nn.Module,
down_proj: nn.Module,
up_proj: nn.Module,
activation: nn.Module = F.silu,
Contributor

nit: technically F.silu is a Callable, not an nn.Module

Contributor Author

Yeah, I replaced it with nn.SiLU() in the latest version.

Default: False

Returns:
FeedForward: instantiation of the MLP module with LoRA applied to
Contributor

I think the second line of this docstring got lost somewhere along the way

Contributor Author

Yeah, I fixed it in the latest version.

Comment on lines 155 to 158
if cfg.checkpointer.model_type == "GEMMA":
model_tie = True
else:
model_tie = False
Contributor

We could also consider associating a weight tying config explicitly with the model type and using that in the checkpointer. E.g.

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class ModelType:
    name: str
    weight_tying_config: Dict[str, str] = field(default_factory=dict)

Then Gemma would be ModelType(name="GEMMA", weight_tying_config={"tok_embeddings.weight": "output.weight"}).

(Anyways, not a blocker for this PR as it's more of a design question)
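For illustration, a hedged sketch of how a checkpointer might consume such a weight_tying_config, using the ModelType dataclass sketched above (the apply_weight_tying helper is hypothetical, not an existing torchtune utility):

from typing import Dict

import torch

GEMMA = ModelType(name="GEMMA", weight_tying_config={"tok_embeddings.weight": "output.weight"})

def apply_weight_tying(state_dict: Dict[str, torch.Tensor], model_type: ModelType) -> Dict[str, torch.Tensor]:
    # Point each tied destination key at its source tensor so both entries share storage.
    for src, dst in model_type.weight_tying_config.items():
        state_dict[dst] = state_dict[src]
    return state_dict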

Contributor Author

Thank you for your suggestion; I updated it in gemma_full_finetune.py.

@@ -259,6 +267,10 @@ def _setup_model(
),
)

if model_tie: # Tie the weights of the model if required
model.output.weight = model.tok_embeddings.weight
Contributor

It was pointed out by @rohan-varma that this may not actually do what we expect because FSDP has already sharded the params, so let's double-confirm via testing that the weights are tied correctly here.

Contributor Author

Thank you for pointing that out. I checked the model weights after training and they are not the same. Is there a solution? I'm not familiar with this. Tying the weights before FSDP wrapping can also cause some problems (see the linked issue).

Contributor

@ebsmothers ebsmothers Apr 2, 2024

OK sorry for the back and forth on this. Confirmed with @rohan-varma that we should not tie weights after FSDP wrapping after all. The main issue was not FSDP but the initialization on meta device. Unfortunately, weight tying + meta device is tricky because the usage of to_empty breaks existing references.

Instead, for Gemma we can do everything on CPU without using meta device at all, basically initializing the model on CPU for every rank and then defining a more vanilla FSDP without the param_init_fn we currently have. This should work fine for smaller models (at least up to 7B). @kartikayk put together a snippet on what this can look like, you can find it here.

We need to decide what the best way to expose this is, but for now feel free to create a separate recipe for Gemma, e.g. gemma_full_finetune.py. It should look pretty much the same as the existing full_finetune_distributed.py, but with the changes needed to initialize everything on CPU and perform weight tying there before wrapping with FSDP.

Thanks also to @awgu for helping debug this.
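A minimal sketch of the approach described above: build the model on CPU, tie the weights, then wrap with a vanilla FSDP (no param_init_fn). The gemma_2b import path, the TransformerDecoderLayer wrap target, and the wrap-policy details are assumptions here, not the exact snippet referenced:

import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from torchtune.models.gemma import gemma_2b            # assumed import path
from torchtune.modules import TransformerDecoderLayer  # assumed layer class

# Materialize real weights on CPU on every rank (no meta device).
model = gemma_2b()

# Tie the output projection to the token embeddings *before* FSDP shards the params.
model.output.weight = model.tok_embeddings.weight

model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerDecoderLayer},
    ),
    device_id=torch.cuda.current_device(),  # FSDP moves the shards onto the GPU
)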

Contributor Author

Thank you for your suggestion, it works well.

Contributor

@kartikayk kartikayk left a comment

Thanks for adding this PR @solitude-alive! This would be an awesome contribution to the repo!

Similar to @joecummings, I have some questions in the code. My biggest question, though, is correctness. The loss from the screenshot seems to be much higher than what we've seen with Mistral/Llama2. Have you compared this loss for Gemma with the official implementation or some other implementation? Or have you seen any issues/blogs that showcase the loss value during training that we can compare against?

Also, when adding models we provide some evidence of numerical correctness - this is really important to build confidence with our users. Please see how we did this for Llama2 13B and Mistral 7B in the context section of this PR: #571. It would be great if you can add a similar check for Gemma 2B. This check would look something like the following (see the sketch after this list):

  • Load the official implementation of Gemma 2B, take a random tensor, run forward, and get the output
  • Load the torchtune implementation, take the same tensor, run forward, and get the output
  • Compare the outputs with torch.allclose and make sure this returns True.
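A minimal sketch of such a parity check (the model handles and tolerances are assumptions; both models would be loaded from the same Gemma 2B checkpoint beforehand, and both are assumed to map token ids directly to a logits tensor):

import torch

def check_forward_parity(reference_model, tune_model, vocab_size: int = 256_000) -> bool:
    # Run both implementations on the same random token ids and compare logits.
    torch.manual_seed(0)
    tokens = torch.randint(0, vocab_size, (1, 128))
    with torch.no_grad():
        ref_out = reference_model(tokens)   # adapt if the reference returns an output object
        tune_out = tune_model(tokens)
    return torch.allclose(ref_out, tune_out, atol=1e-5, rtol=1e-5)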

@@ -259,6 +267,10 @@ def _setup_model(
),
)

if model_tie: # Tie the weights of the model if required
Contributor

I find the use of model_tie to be a bit unintuitive. Can we rename this to something like share_weights or share_embed since I don't think we'll have other modules we share?

Contributor Author

Yeah, I updated them in the latest version.

@@ -1,6 +1,7 @@
# Hugging Face Integration Reqs
datasets
huggingface_hub
safetensors
Contributor

@joecummings do we need to explicitly add this if it's already a part of huggingface_hub? I guess it's good practice to call it out explicitly?

Tie the weights of the output embeddings and the token embeddings in the model.

Args:
model (TransformerDecoder): The to tie the weights of the output embeddings and the token embeddings.
Contributor

This sentence is missing some info: "the to tie" reads a bit weird

Contributor Author

Sorry, I modified it in the latest version.

num_kv_heads=1,
embed_dim=2048,
intermediate_dim=16384,
max_seq_len=32768,
Contributor

Is this right? I thought this was 8192 for Gemma 2B

Contributor Author

Sorry, that was my mistake; I fixed it in the latest version.

@@ -383,6 +383,14 @@ def load_checkpoint(self) -> Dict[str, Any]:
dim=self._config["hidden_size"],
)

if (
self._model_type == "GEMMA"
Contributor

Hmm so I have a question about this code.

hf_to_tune makes an assumption that head_dim * num_heads = dim (see here).

But this isn't true for Gemma 7B, where num_heads=16 and head_dim=256 but dim=3072, not 4096. So we will need to differentiate between Gemma 2B and 7B here.
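A quick illustration of the mismatch using the Gemma 7B numbers quoted above:

num_heads, head_dim, dim = 16, 256, 3072
print(num_heads * head_dim)  # 4096: the actual q-projection width
print(dim // num_heads)      # 192: the head_dim that hf_to_tune would infer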

Contributor

Also, please move to a utility function in checkpointer_utils so we can keep this code clean.

Contributor Author

If I allow the num_heads parameter to be passed explicitly to hf_to_tune, would that be acceptable?

And I moved this logic to a utility function in checkpointer_utils.

Comment on lines 442 to 444
print(f"======={self._model_type}==========")
if (
self._model_type == "GEMMA"
Contributor

Can we move this to a separate utility function in checkpointer_utils? we should keep the checkpointer as clean as possible.

Contributor Author

Yeah, I did it in the latest version.

"because it is the same as the model embed_tokens weight"
)
else:
self._weight_map["lm_head.weight"] = "0002"
Contributor

When will this else block be hit? If we know the checkpoints don't contain this key, let's just work with that assumption. Anyway, we're hard-coding a bunch of stuff here like the name of the key, etc.

Contributor Author

That was because the parameters were not actually tied before. I have now removed the else block.

state_dict = torch.load(
str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
)
if str(checkpoint_path).endswith(".safetensors") and is_safetensors_available():
Contributor

Not a fan of this approach. Can we just add a key to the config, something like is_safetensors_file, and then determine based on that value whether we use torch.load or not? Also, please break this down into a sub-function (e.g. load_from_safetensor or something similar).

@joecummings WDYT?
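A minimal sketch of the suggested sub-function (the name load_from_safetensor comes from the comment above; the exact checkpointer integration is an assumption):

from typing import Dict

import torch
from safetensors import safe_open

def load_from_safetensor(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    # Load every tensor in a .safetensors file onto CPU.
    state_dict = {}
    with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return state_dict

The checkpointer could then dispatch on the file extension (or an is_safetensors_file config key) and fall back to torch.load otherwise.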

Contributor

@kartikayk kartikayk left a comment

A huge thank you @solitude-alive for adding this functionality and also patiently addressing the many review comments. This functionality makes TorchTune better and we really appreciate all of your hard work. I'll merge this into MAIN, make a few small changes to the core recipes based on some upcoming changes and then add this to our README and cite you as the author. Thanks so much for all of the hard work!

@kartikayk kartikayk merged commit 09f9d95 into pytorch:main Apr 4, 2024
12 of 20 checks passed
tcapelle pushed a commit to tcapelle/torchtune that referenced this pull request Apr 5, 2024
Co-authored-by: ebsmothers <ebs@meta.com>
state_dict = result
else:
state_dict = torch.load(
str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
Contributor

Looks like the weights_only arg is not passed through here?

Contributor Author

Oh yes, but I just looked at the latest version and it has been updated.

state_dict = torch.load(
                str(checkpoint_path),
                map_location="cpu",
                mmap=True,
                weights_only=weights_only,
            )

mmap=True,
weights_only=weights_only,
is_safetensors_file = (
True if str(checkpoint_path).endswith(".safetensors") else False
Contributor

nit: btw this seems to be the same as:

is_safetensors_file = str(...).endswith(".safetensors")

Contributor Author

Yeah.

@solitude-alive solitude-alive deleted the gemma branch April 18, 2024 02:08
Labels: CLA Signed

Successfully merging this pull request may close these issues: Could it support Gemma?