self extend / longlm #186
Conversation
@tgaddair could you please take a look at the failing Rust test?
Tested and working :)
Nice! I have some doubts about the way we're computing attention. Happy to discuss further.
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

def self_extend_forward(
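For context, the quoted lines follow the standard rotary-embedding helpers used in Hugging Face-style model code. A minimal sketch of that formulation (not the PR diff itself) looks like this:

import torch

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Apply RoPE to queries and keys using precomputed cos/sin tables.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed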
This signature is quite different from what the current implementation without self-extend uses. Instead of using flash attention and paged attention, it uses the more conventional attention computation with past_key_values. As such, because the rest of the FlashMistral, etc. classes don't pass in past_key_values, my expectation is that the attention computation will be incorrect during the decode phase.
I think what we need is a variation on this function that works with the existing Flash / Paged Attention computation.
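To illustrate the contrast being described, here is a rough sketch of the two signature styles; the argument names below are illustrative assumptions, not the exact lorax API:

# Conventional attention: decode state is carried in past_key_values and each
# step attends over the accumulated key/value tensors.
def self_extend_forward(hidden_states, position_ids, past_key_value=None, **kwargs):
    ...

# Flash / paged attention (the style the FlashMistral classes use): decode
# state lives in a paged KV cache addressed by block tables and slots, so no
# past_key_values tensor is ever passed in.
def flash_paged_forward(hidden_states, cos, sin, kv_cache, block_tables, slots, input_lengths, **kwargs):
    ...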
I tried with and without self extend and the generated response was the same.
The actual implementation is taken from the authors, I think (https://github.com/datamllab/LongLM).
I definitely think that we should add the flash attention version, but I don't know if it would make sense to wait until they release it with tested results.
From the paper:

Limitation: The limitation of the proposed Self-Extend includes the lack of implementation of Flash Attention (Dao et al., 2022) and the performance degradation with too large group size, which means the context window still cannot be extended to infinity with current SelfExtend. Meanwhile, like many regular tasks, there is still no consensus at present about how to do evaluation for long context tasks, which may cause problematic evaluation results.

Future Work: For future work, we will implement Flash Attention for Self-Extend to enhance its efficiency. We are also interested in testing SelfExtend on models using other positional encoding. Larger models, longer context and more challenging tasks will be tested if we can have access to more computational resources in the future. In the meantime, more sophisticated mapping methods will be considered as the replacement of the simple FLOOR operation, so as to achieve better long context understanding abilities and longer extended context window length.
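For reference, the FLOOR operation mentioned above is the core of SelfExtend's position mapping: tokens inside a neighbor window keep their exact relative positions, while more distant tokens are mapped onto grouped positions so relative distances stay within the pretraining range. A minimal sketch of that idea (illustrative names, not the LongLM code):

import torch

def self_extend_position_ids(position_ids, group_size):
    # Exact positions for the neighbor (local) attention pass.
    neighbor_pos = position_ids
    # FLOOR-grouped positions for the long-range pass; the real method also
    # re-aligns the grouped query positions at the window boundary, which is
    # omitted here for brevity.
    grouped_pos = position_ids // group_size
    return neighbor_pos, grouped_pos

def merge_scores(neighbor_scores, grouped_scores, q_pos, k_pos, neighbor_window):
    # Keep exact-position scores where the query-key distance is within the
    # neighbor window; fall back to grouped-position scores elsewhere.
    dist = q_pos[:, None] - k_pos[None, :]
    return torch.where(dist.abs() <= neighbor_window, neighbor_scores, grouped_scores)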
Hey @flozi00, it looks like there was a bug in the code causing the self extend code path to never be executed. I fixed the issue with plumbing through the self_extend_attention param, and now there are some errors showing up. So my suspicion is that the answers were identical because we were executing the non-extended code in both cases.
Here's the current error:
File "/data/lorax/server/lorax_server/models/flash_mistral.py", line 427, in forward
logits = model.forward(
File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 565, in forward
hidden_states = self.model(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 508, in forward
hidden_states, residual = layer(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 431, in forward
attn_output = self.self_attn(
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
TypeError: self_extend_forward() got multiple values for argument 'group_size_1'
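For what it's worth, this kind of TypeError usually means the caller passes a value positionally that the callee also receives as a keyword. A contrived example (not the lorax code) that reproduces the same message:

def self_extend_forward(query, key, group_size_1=8, group_size_2=1024):
    ...

# The third positional argument already binds to group_size_1, so passing the
# keyword as well collides with it and raises the same TypeError.
self_extend_forward("q", "k", 8, group_size_1=8)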
Oh, okay.
Will take care of this tomorrow.
Thanks for finding this missing piece.
"Another good news: the flash attention version will come in days!" Will wait for that :) |
Hi, I don't have to set up the vars because I am editing branches from this repo and not in a fork |
What does this PR do?
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.