
fix baichuan-7b tp #598

Merged. 1 commit merged into vllm-project:main on Aug 1, 2023.

Conversation

@Sanster (Contributor) commented on Jul 27, 2023

The main modifications are in the "load_weights" function.

Before: [image]

After: [image]
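
The vocab-related part of that change pads the checkpoint's embed_tokens / lm_head weight so its row count matches the padded vocab size used by the tensor-parallel shards (see the diff excerpt later in this conversation). A minimal, self-contained sketch of the same idea; all sizes below are made up for illustration and are not Baichuan's real configuration:

import torch

# Illustrative sizes only (not Baichuan's real configuration).
vocab_size = 10        # rows in the checkpoint's embedding weight
hidden_size = 4
tp_world_size = 4      # number of tensor-parallel ranks
per_rank_rows = 3      # rows held by each rank; 3 * 4 = 12 >= vocab_size

loaded_weight = torch.randn(vocab_size, hidden_size)

# Pad the full weight up to per_rank_rows * tp_world_size rows so it can be
# split evenly across ranks, mirroring the load_weights change in this PR.
padded_vocab_size = per_rank_rows * tp_world_size
num_extra_rows = padded_vocab_size - vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1]).to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

assert loaded_weight.shape == (padded_vocab_size, hidden_size)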

@LiVincent-Zhang commented

Is it the same reason for baichuan-13b? #530

@Sanster (Contributor, Author) commented on Jul 27, 2023

> Is it the same reason for baichuan-13b? #530

Yes. I have tested both baichuan-13b and baichuan-7b, and both produce normal output under tensor parallelism (tp).
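
For reference, a typical way to exercise tensor parallelism in vLLM looks like the sketch below; the model path, prompt, and tensor_parallel_size value are illustrative rather than taken from this PR, and the exact flags may vary by vLLM version.

from vllm import LLM, SamplingParams

# Illustrative settings; adjust the model path and TP degree to your setup.
llm = LLM(
    model="baichuan-inc/Baichuan-7B",  # assumed HF model id
    trust_remote_code=True,            # Baichuan's tokenizer ships custom code
    tensor_parallel_size=2,            # shard the model across 2 GPUs
)

outputs = llm.generate(["The quick brown fox"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)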

@LiVincent-Zhang commented

> Yes. I have tested both baichuan-13b and baichuan-7b, and both produce normal output under tensor parallelism (tp).

Can I use this PR directly on 13B?

@zhuohan123 (Collaborator) left a comment

Thank you for your contribution! Can you use our official formatting script and remove the other, format-only changes?

Comment on lines 282 to 294
if "embed_tokens" in name or "lm_head" in name:
# Consider padding in the vocab size.
param = state_dict[name]
padded_vocab_size = param.shape[0] * tp_world_size
num_extra_rows = padded_vocab_size - self.config.vocab_size
extra_rows = torch.empty(num_extra_rows, loaded_weight.shape[1])
extra_rows = extra_rows.to(loaded_weight)
loaded_weight = torch.cat([loaded_weight, extra_rows], dim=0)

if "W_pack" in name:
# W_pack.weight.shape [3*hidden_size, hidden_size] [3*4096, 4096] = [12,288, 4096]
total_num_heads = self.config.num_attention_heads
hidden_size = self.config.hidden_size
head_size = hidden_size // total_num_heads
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

loaded_weight = loaded_weight.view(
3, total_num_heads, head_size, hidden_size
)
loaded_weight = loaded_weight[:, head_start:head_end, :, :]
loaded_weight = loaded_weight.reshape(-1, hidden_size)
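
The W_pack branch above shards Baichuan's packed QKV projection by attention head before a given tensor-parallel rank loads it. A minimal, self-contained sketch of the same reshape-and-slice; the sizes and the tp_rank / tp_world_size values are illustrative (the PR's own comment only notes hidden_size = 4096 for Baichuan-7B):

import torch

# Illustrative sizes only.
hidden_size = 8
total_num_heads = 4
head_size = hidden_size // total_num_heads
tp_world_size = 2      # number of tensor-parallel ranks
tp_rank = 1            # index of this rank

# Packed QKV weight of shape [3 * hidden_size, hidden_size].
w_pack = torch.randn(3 * hidden_size, hidden_size)

# Split heads evenly across ranks and keep only this rank's heads,
# preserving the Q/K/V packing along the leading dimension.
num_heads = total_num_heads // tp_world_size
head_start = tp_rank * num_heads
head_end = (tp_rank + 1) * num_heads

shard = (w_pack.view(3, total_num_heads, head_size, hidden_size)
               [:, head_start:head_end, :, :]
               .reshape(-1, hidden_size))

assert shard.shape == (3 * num_heads * head_size, hidden_size)

Each rank ends up with only its own heads' rows of W_pack, which is the per-rank slicing this PR adds to load_weights.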

Collaborator commented on the diff

Is this part the only part that actually changes the code logic? Can you remove the other format-only modifications and use the format.sh script provided by us to re-format the code? Thanks!

Contributor (Author) replied

Hi, I have updated the PR and removed the format-only changes.

@zhuohan123 (Collaborator) left a comment

LGTM! Thank you for your contribution!

@zhuohan123 merged commit d4c7755 into vllm-project:main on Aug 1, 2023. 2 checks passed.
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request on Feb 13, 2024 (Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>)
sjchoi1 pushed a commit to casys-kaist-internal/vllm that referenced this pull request on May 7, 2024 (Co-authored-by: wq.chu <wq.chu@tianrang-inc.com>)