
self extend / longlm #186

Closed
wants to merge 9 commits into from

Conversation

@flozi00 (Collaborator) commented Jan 16, 2024

What does this PR do?

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a GitHub issue or the Discord / Slack channel? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@flozi00 linked an issue Jan 16, 2024 that may be closed by this pull request
@flozi00 (Collaborator, Author) commented Jan 16, 2024

@tgaddair could you please take a look at the failing Rust test?
And since, as always, it isn't tested at the moment, could you provide a Docker image again?

@flozi00 requested a review from tgaddair January 17, 2024 14:32
@flozi00 (Collaborator, Author) commented Jan 17, 2024

tested and working :)

@tgaddair (Contributor) left a comment

Nice! I have some doubts about the way we're computing attention. Happy to discuss further.

k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed

def self_extend_forward(
@tgaddair (Contributor) commented:

This signature is quite different from what the current implementation without self-extend uses. Instead of using flash attention and paged attention, it uses the more conventional attention computation with past_key_values. As such, because the rest of the FlashMistral, etc. classes don't pass in the past_key_values, my expectation is that the attention computation will be incorrect during the decode phase.

I think what we need is a variation on this function that works with the existing Flash / Paged Attention computation.
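
For context, Self-Extend essentially needs two sets of attention scores per layer, one computed with the ordinary RoPE positions for nearby tokens and one with floored ("grouped") positions for distant tokens, merged by a neighbor-window mask. Below is a rough, non-flash sketch of that merge step; the tensor names and position arguments are hypothetical, not this PR's actual code.

import torch

def merge_self_extend_scores(
    neighbor_scores: torch.Tensor,  # [batch, heads, q_len, k_len], scores from exact RoPE positions
    grouped_scores: torch.Tensor,   # [batch, heads, q_len, k_len], scores from floored positions
    q_pos: torch.Tensor,            # [q_len] absolute query positions
    k_pos: torch.Tensor,            # [k_len] absolute key positions
    neighbor_window: int,
) -> torch.Tensor:
    # Relative distance between every query/key pair.
    rel = q_pos[:, None] - k_pos[None, :]
    # Keys inside the neighbor window keep the exact-position scores;
    # everything farther away falls back to the grouped-position scores.
    use_neighbor = (rel >= 0) & (rel <= neighbor_window)
    return torch.where(use_neighbor, neighbor_scores, grouped_scores)

Merging full score matrices like this is part of what makes a drop-in variant for the flash / paged attention kernels non-trivial.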

@flozi00 (Collaborator, Author) commented:

I tried with and without self-extend, and the generated responses were the same.
The actual implementation is taken from the authors, I think (https://github.com/datamllab/LongLM).

I definitely think we should add the flash attention version, but I don't know whether it would make sense to wait until they release it with tested results.

from the paper:

Limitation: The limitation of the proposed Self-Extend includes the lack of implementation of Flash Attention (Dao et al., 2022) and the performance degradation with too large group size, which means the context window still cannot be extended to infinity with current SelfExtend. Meanwhile, like many regular tasks, there is still no consensus at present about how to do evaluation for long context tasks, which may cause problematic evaluation results.

Future Work: For future work, we will implement Flash Attention for Self-Extend to enhance its efficiency. We are also interested in testing SelfExtend on models using other positional encoding. Larger models, longer context and more challenging tasks will be tested if we can have access to more computational resources in the future. In the meantime, more sophisticated mapping methods will be considered as the replacement of the simple FLOOR operation, so as to achieve better long context understanding abilities and longer extended context window length.
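
To make the FLOOR operation mentioned above concrete: distant positions are mapped onto a smaller range by integer division with the group size, so relative distances stay inside the range the model saw during pretraining. A minimal sketch follows; group_size here is an illustrative parameter, not a value from this PR.

import torch

def grouped_positions(position_ids: torch.Tensor, group_size: int) -> torch.Tensor:
    # The simple FLOOR mapping from the paper: positions 0..G-1 collapse to 0,
    # G..2G-1 collapse to 1, and so on, shrinking the effective position range
    # by a factor of group_size.
    return position_ids // group_size

# Example: with group_size=4, positions 0..11 map to 0,0,0,0,1,1,1,1,2,2,2,2
print(grouped_positions(torch.arange(12), 4))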

@tgaddair (Contributor) commented:

Hey @flozi00, it looks like there was a bug in the code causing the self extend code path to never be executed. I fixed the issue with plumbing through the self_extend_attention param, and now there are some errors showing up. So my suspicion is that the answers were identical because we were executing the non-extended code in both cases.

Here's the current error:

  File "/data/lorax/server/lorax_server/models/flash_mistral.py", line 427, in forward                                                                                            
    logits = model.forward(                                                                                                                                                       
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 565, in forward                                                                   
    hidden_states = self.model(                                                                                                                                                   
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                     
    return self._call_impl(*args, **kwargs)                                                                                                                                       
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                             
    return forward_call(*args, **kwargs)                                                                                                                                          
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 508, in forward                                                                   
    hidden_states, residual = layer(                                                                                                                                              
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                     
    return self._call_impl(*args, **kwargs)                                                                                                                                       
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                             
    return forward_call(*args, **kwargs)                                                                                                                                          
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 431, in forward                                                                   
    attn_output = self.self_attn(                                                                                                                                                 
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl                                                                     
    return self._call_impl(*args, **kwargs)                                                                                                                                       
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl                                                                             
    return forward_call(*args, **kwargs)                                                                                                                                          
TypeError: self_extend_forward() got multiple values for argument 'group_size_1' 
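
That TypeError is Python's usual complaint when a parameter is supplied both positionally and as a keyword, which happens when an extra positional argument shifts into the group_size_1 slot. A minimal repro of the pattern (the signature below is hypothetical, not the one in this PR):

def self_extend_forward(hidden_states, group_size_1=8, group_size_2=1024):
    return hidden_states

# An unexpected extra positional argument lands in group_size_1,
# which is then also passed as a keyword:
self_extend_forward("states", "extra positional", group_size_1=8)
# TypeError: self_extend_forward() got multiple values for argument 'group_size_1'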

@flozi00 (Collaborator, Author) commented:

Oh, okay.
I'll take this up tomorrow.
Thanks for catching this missing piece.

@flozi00 marked this pull request as draft January 19, 2024 16:24
@flozi00 (Collaborator, Author) commented Jan 20, 2024

"Another good news: the flash attention version will come in days!"
https://github.com/datamllab/LongLM#:~:text=Another%20good%20news%3A%20the%20flash%20attention%20version%20will%20come%20in%20days!

Will wait for that :)

@akelch11 commented:

Hi @flozi00, I saw that you had a recent pull request that passed all the tests. Do you have to do anything to supply environment variables or configs to the tests? My PR is having issues with this: the server tests are failing because they cannot log in to HuggingFace and connect to Llama2.

@flozi00 (Collaborator, Author) commented Jan 23, 2024

Hi, I don't have to set up the vars because I am editing branches in this repo itself, not in a fork.

@flozi00 closed this Feb 18, 2024
Labels: none
Projects: none
Successfully merging this pull request may close these issues: LongLM
3 participants