Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for the LM_HEAD issue #475

Merged
merged 7 commits into from
May 23, 2024
Merged

Fix for the LM_HEAD issue #475

merged 7 commits into from
May 23, 2024

Conversation

ajtejankar
Copy link
Contributor

@ajtejankar ajtejankar commented May 18, 2024

Fix for the LM_HEAD issue

Root Cause. The error is caused by incorrect segments passed to the lora_b_sgmv kernel during the prefill stage. This happens because we do not want to forward all the tokens in the prompt through lm_head and associated adapters. The goal is to save compute and memory by not having to forward the tokens that are not involved in generating the next token. Only the last token is needed for this purpose. However, doing this changes the shape of the net batch size (batch size times number of tokens) that the lora kernel sees, but the segment start and end tensors are not changed. These tensors are used to slice/segment the batch in the kernel. Hence, keeping them unchanged is incorrect. Additionally, the resulting error from this problem not reported correctly since the kernel has a catch-call condition that reports a generic kernel not found for the dtype message. My guess is that it's an out-of-bounds memory access, but confirming this would require changing the kernel. This is out of the scope of this investigation.

Description of the Fix. The fix simply goes over the adapters for lm_head and adjusts the segment start and end tensors to correctly point to the segments in the batch. For now, this is done only during the prefill stage and when a LoRA adapter for the lm_head is present.

Tests. TBD

Fixes #163

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue or the discord / slack channel? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Copy link
Contributor

@tgaddair tgaddair left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a good catch! In terms of making this more general, I think we can safely move this into the BatchLoraWeights.load function which stores off unique segment_indices per layer.

One thing we would just need to change is to (1) plumb through the batch.prefill_head_indices here and (2) plumb through the layer name (which is k) here.

Also, looks like there are a couple merge conflicts that need to be cleaned up.

@ajtejankar
Copy link
Contributor Author

Sounds good. Will make those changes and do a basic testing round with a few different model+LoRA combinations.

ajtejankar and others added 6 commits May 20, 2024 18:49
Refactored the fix to be in `BatchLoraWeights.load` method and added
plumbing for it in `generate_token` in flash_causal_lm.py so that all
models get this fix.
1. `AdaterBatchData.from_meta` calls in tests and `causal_lm.py`
2. `BatchLoraWeights.load` calls in tests
Copy link
Contributor

@tgaddair tgaddair left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the fix!

@tgaddair tgaddair merged commit da90421 into main May 23, 2024
1 check passed
@tgaddair tgaddair deleted the lm-head branch May 23, 2024 20:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

CUDA error when for sgmv_lora_b for LM_head with many concurrent requests
2 participants