
Apply QLoRA to output projections and token embedding #1000

Open
rohan-varma opened this issue May 17, 2024 · 4 comments

Comments

@rohan-varma
Member

Currently, we don't apply QLoRA to either the output projection or the token embeddings. There's no strong reason not to quantize the output projection; we skip it only because of a limitation in torchao (the quantized large weight somehow takes up more memory than the unquantized one). We should start quantizing the output projection once this is fixed in torchao.

Code pointer to where we skip quantizing the output projection in torchtune: https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama2/_component_builders.py#L235-L239.
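
For reference, here's a minimal, self-contained sketch of the general idea (a frozen, quantized base weight for the output projection plus trainable LoRA factors). It uses a toy int8 absmax quantization in plain PyTorch rather than torchao's NF4, and none of the class or attribute names below correspond to actual torchtune/torchao APIs:

import torch
import torch.nn as nn
import torch.nn.functional as F

class QLoRAOutputProj(nn.Module):
    """Toy sketch: frozen int8-quantized base weight + trainable LoRA factors.

    A real implementation would use torchao's NF4 quantization instead of the
    absmax int8 stand-in used here.
    """

    def __init__(self, embed_dim: int, vocab_size: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        base = torch.randn(vocab_size, embed_dim)  # placeholder for the pretrained weight
        scale = base.abs().amax(dim=1, keepdim=True) / 127.0
        # Frozen, quantized base weight stored as buffers (no gradients)
        self.register_buffer("q_weight", torch.round(base / scale).to(torch.int8))
        self.register_buffer("scale", scale)
        # Trainable low-rank factors: delta_W = lora_b @ lora_a, shape (vocab_size, embed_dim)
        self.lora_a = nn.Parameter(torch.randn(rank, embed_dim) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(vocab_size, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize the frozen base on the fly, then add the low-rank update
        w = self.q_weight.to(x.dtype) * self.scale.to(x.dtype)
        base_out = F.linear(x, w)
        lora_out = F.linear(F.linear(x, self.lora_a), self.lora_b)
        return base_out + self.scaling * lora_out

# Shape check with small dims to keep the example cheap
proj = QLoRAOutputProj(embed_dim=256, vocab_size=1000)
print(proj(torch.randn(2, 8, 256)).shape)  # torch.Size([2, 8, 1000])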

On the token embedding side, applying LoRA to embeddings is relatively unexplored research-wise, but there have been some requests, e.g. huggingface/peft#349. We might want to explore this for even more memory savings.

@Optimox
Contributor

Optimox commented Jun 4, 2024

@rohan-varma is it OK to simply unfreeze the embedding layer before fine-tuning with LoRA, with something like this?

# Make the token embedding table fully trainable alongside the LoRA parameters
for param in model.tok_embeddings.parameters():
    param.requires_grad = True


@ebsmothers
Contributor

Hi @Optimox, is your idea to just fine-tune the embedding layer directly, without any additional LoRA weights? If so, this will work. If you're doing a distributed run you may need to be a bit careful about the FSDP wrapping, though (I don't think anything will break, but you could end up using extra memory if you don't change the wrapping). Btw, adding a proper LoRAEmbedding layer is still pretty high on our wishlist; if you're interested in helping out, this'd be a great contribution.
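
For anyone who wants to pick that up, here is a rough sketch of what a LoRAEmbedding could look like; the names, init, and scaling choices are illustrative (one factor starts at zero so the module initially matches the frozen base), not an agreed-upon torchtune design:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LoRAEmbedding(nn.Module):
    """Sketch: frozen base embedding table + trainable low-rank update."""

    def __init__(self, num_embeddings: int, embedding_dim: int, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = nn.Embedding(num_embeddings, embedding_dim)
        self.base.weight.requires_grad = False  # pretrained table stays frozen
        # Low-rank update to the table: delta_W = lora_a @ lora_b, shape (num_embeddings, embedding_dim)
        self.lora_a = nn.Parameter(torch.zeros(num_embeddings, rank))
        self.lora_b = nn.Parameter(torch.randn(rank, embedding_dim) * 0.01)
        self.scaling = alpha / rank

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Frozen lookup plus a low-rank correction for the same token ids
        base_out = self.base(tokens)
        lora_out = F.embedding(tokens, self.lora_a) @ self.lora_b
        return base_out + self.scaling * lora_out

emb = LoRAEmbedding(num_embeddings=32_000, embedding_dim=512)
print(emb(torch.randint(0, 32_000, (2, 16))).shape)  # torch.Size([2, 16, 512])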

@Optimox
Contributor

Optimox commented Jun 4, 2024

@ebsmothers yes, I've tried simply unfreezing the embeddings and training the rest of the model with LoRA. It seems to be working OK on a single machine. I'm just wondering whether this is good practice, or whether there is a good reason to keep the embeddings frozen when fine-tuning with LoRA?
About the LoRAEmbedding: will the initial embedding layer become a memory bottleneck at some point?

@ebsmothers
Contributor

@Optimox sorry somehow your message slipped through the cracks here.

I'm not sure what the best practice is here in terms of model quality. One thing that could matter is whether you are trying to learn new embeddings. E.g. if you have a new special token in your tokenizer and it's untrained, making the full embedding trainable during fine-tuning may be a good way to learn richer information about that special token (whereas with LoRA you would wind up learning a much lower-dimensional representation of this token than of other tokens the model has already been pretrained on).

Re memory usage: say we use a LoRA rank of 8. Taking Llama3-8B as an example, the vocab size is about 128k and the embed dim is about 4k. This means that if you fully fine-tune the embedding matrix you will have 128k * 4k gradients, which in bf16 is about 1 GB. With LoRA you would only have 128k * 8 + 4k * 8, which is more like 2-3 MB of gradients (all math here is very approximate). So the memory savings from applying LoRA to the embedding are nontrivial; it just depends on how much memory you are using elsewhere and how much you have available.
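
To make that arithmetic concrete (counting only the gradient tensors at 2 bytes per bf16 element, not optimizer states or activations):

vocab_size, embed_dim, rank, bytes_per_elem = 128_256, 4096, 8, 2  # roughly Llama3-8B, bf16

full_grads = vocab_size * embed_dim * bytes_per_elem                   # full fine-tune of the table
lora_grads = (vocab_size * rank + embed_dim * rank) * bytes_per_elem   # LoRA factors only

print(f"full embedding gradients: {full_grads / 2**30:.2f} GiB")  # ~0.98 GiB
print(f"LoRA embedding gradients: {lora_grads / 2**20:.2f} MiB")  # ~2.02 MiB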
