Fix for the LM_HEAD issue #475
Conversation
This was a good catch! In terms of making this more general, I think we can safely move this into the `BatchLoraWeights.load` function, which stores off unique `segment_indices` per layer.
One thing we would just need to change is to (1) plumb through `batch.prefill_head_indices` here and (2) plumb through the layer name (which is `k`) here.
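A rough sketch of what that plumbing could look like (the surrounding parameter names are assumptions for illustration; the actual `BatchLoraWeights.load` signature in the codebase may differ):

```python
# Hypothetical sketch only; `adapter_weights` and `meta` stand in for the
# existing parameters of BatchLoraWeights.load. The two new arguments let
# load() adjust its per-layer segment_indices for lm_head itself instead
# of special-casing lm_head at the call site.
class BatchLoraWeights:
    @classmethod
    def load(
        cls,
        adapter_weights,
        meta,
        prefill: bool,
        prefill_head_indices,  # (1) plumbed from batch.prefill_head_indices
        layer_name: str,       # (2) the layer key `k`, e.g. "lm_head"
    ):
        ...
```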
Also, looks like there are a couple merge conflicts that need to be cleaned up.
Sounds good. Will make those changes and do a basic testing round with a few different model+LoRA combinations.
Refactored the fix to be in the `BatchLoraWeights.load` method and added plumbing for it in `generate_token` in `flash_causal_lm.py` so that all models get this fix. This also required updating:
1. `AdapterBatchData.from_meta` calls in tests and `causal_lm.py`
2. `BatchLoraWeights.load` calls in tests
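For reference, a hedged sketch of what the `generate_token` plumbing amounts to (argument names and order are assumptions; the real call in `flash_causal_lm.py` may differ):

```python
# Sketch of the call-site change in generate_token (names are assumptions).
# During prefill, only the rows selected by batch.prefill_head_indices are
# forwarded through lm_head, so that index tensor is handed down to the
# adapter-loading path, which can then remap the lm_head segment boundaries.
adapter_data = AdapterBatchData.from_meta(
    batch.adapter_meta,
    self.layer_to_adapter_weights,
    prefill,
    batch.prefill_head_indices if prefill else None,
)
```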
LGTM! Thanks for the fix!
Fix for the LM_HEAD issue
**Root Cause.** The error is caused by incorrect segments passed to the `lora_b_sgmv` kernel during the prefill stage. This happens because we do not want to forward all the tokens in the prompt through `lm_head` and its associated adapters; the goal is to save compute and memory by not forwarding tokens that are not involved in generating the next token. Only the last token of each sequence is needed for this purpose. However, doing this changes the net batch size (batch size times number of tokens) that the LoRA kernel sees, while the segment start and end tensors are left unchanged. Since these tensors are used to slice/segment the batch inside the kernel, keeping them unchanged is incorrect. Additionally, the resulting error is not reported correctly, because the kernel has a catch-all condition that reports a generic `kernel not found for the dtype` message. My guess is that the underlying failure is an out-of-bounds memory access, but confirming this would require changing the kernel, which is out of the scope of this investigation.

**Description of the Fix.** The fix simply goes over the adapters for `lm_head` and adjusts the segment start and end tensors to correctly point to the segments in the batch. For now, this is done only during the prefill stage and when a LoRA adapter for the `lm_head` is present.

**Tests.** TBD
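To make the remapping concrete, here is a minimal, self-contained sketch of the idea (not the verbatim code from this PR; the function name and tensor layout are assumptions). Given segment boundaries defined over the full token batch and the sorted indices of the tokens kept for `lm_head`, it produces boundaries that are valid in the reduced batch:

```python
import torch

def adjust_segments_for_lm_head(
    segment_starts: torch.Tensor,        # per-adapter starts into the full batch
    segment_ends: torch.Tensor,          # per-adapter ends into the full batch
    prefill_head_indices: torch.Tensor,  # sorted token rows kept for lm_head
):
    """Remap segment boundaries from full-batch to reduced-batch coordinates."""
    new_starts, new_ends = [], []
    for start, end in zip(segment_starts.tolist(), segment_ends.tolist()):
        # Rows kept before this segment determine its new start; rows kept
        # inside it determine its new length in the reduced batch.
        kept_before = int((prefill_head_indices < start).sum())
        kept_inside = int(
            ((prefill_head_indices >= start) & (prefill_head_indices < end)).sum()
        )
        new_starts.append(kept_before)
        new_ends.append(kept_before + kept_inside)
    return (
        torch.tensor(new_starts, dtype=segment_starts.dtype),
        torch.tensor(new_ends, dtype=segment_ends.dtype),
    )
```

For example, with two sequences of 5 prompt tokens each (segments `[0, 5)` and `[5, 10)`) and `prefill_head_indices = torch.tensor([4, 9])`, this yields `[0, 1)` and `[1, 2)`, matching the two rows that actually reach `lm_head` during prefill.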
Fixes #163
Before submitting
Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.