Support llama3.1-8b generation #947

Merged · 4 commits into main · Jul 24, 2024
Conversation

@Gasoonjia (Contributor) commented Jul 24, 2024

Llama3.1 8b is now supported in torchchat! 🎉

Local test: [screenshot of a local generation run]
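The local test above can presumably be reproduced with torchchat's standard CLI; the llama3.1 alias and exact flags below are assumptions based on how other Llama models are invoked, not taken from this PR:

```bash
# Hypothetical invocation; the llama3.1 alias and flags are assumed.
python3 torchchat.py download llama3.1
python3 torchchat.py generate llama3.1 --prompt "Hello, my name is"
```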

@pytorch-bot (bot) commented Jul 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/947

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1f309dd with merge base 7b4fa7c:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Jul 24, 2024
@Gasoonjia changed the title from "Support llama3.1 generation" to "Support llama3.1-8b generation" on Jul 24, 2024
build/model.py (Outdated)
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
for freq in freqs:
    wavelen = 2 * math.pi / freq
Contributor:

Suggested change:
-    wavelen = 2 * math.pi / freq
+    wavelen = 2 * torch.pi / freq
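For context, the loop under review appears to follow the public Llama 3.1 RoPE frequency-scaling recipe. Below is a self-contained sketch of that recipe with the reviewer's torch.pi suggestion applied; the function name and constants are taken from the Llama 3.1 reference release, not from this PR's diff:

```python
import torch

# Sketch of the Llama 3.1 RoPE scaling recipe; the constants and function
# name are assumptions from the public reference, not this PR's exact code.
def apply_scaling(freqs: torch.Tensor) -> torch.Tensor:
    scale_factor = 8
    low_freq_factor = 1
    high_freq_factor = 4
    old_context_len = 8192  # original Llama 3 context length

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    new_freqs = []
    for freq in freqs:
        wavelen = 2 * torch.pi / freq
        if wavelen < high_freq_wavelen:
            new_freqs.append(freq)  # high-frequency band: keep as-is
        elif wavelen > low_freq_wavelen:
            new_freqs.append(freq / scale_factor)  # low-frequency band: scale down
        else:
            # Mid band: interpolate smoothly between scaled and unscaled.
            smooth = (old_context_len / wavelen - low_freq_factor) / (
                high_freq_factor - low_freq_factor
            )
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.stack(new_freqs)
```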

build/model.py (Outdated), comment on lines 8 to 9:

import math

Contributor:

Use pi from torch rather than math.

Suggested change:
-import math
@@ -40,6 +40,12 @@
        "distribution_path": "meta-llama/Meta-Llama-3-70B-Instruct",
        "transformer_params_key": "Meta-Llama-3-70B"
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
Contributor:

nit: technically you could also add support for the pre-trained model (meta-llama/Meta-Llama-3.1-8B) and Llama Guard (meta-llama/Llama-Guard-3-8B). Not a requirement, though.

Contributor:

Probably want support for the pre-trained (non-Instruct) version to match our support for Llama 3.0?

Contributor (Author):

Oh yeah, of course. This PR is focused on enabling llama3.1 in torchchat, so it doesn't cover all possible models. Will have another PR to handle that.
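For illustration, entries for those extra variants would presumably mirror the pattern of the existing Meta-Llama-3 entries visible in the hunk above; the values below are hypothetical:

```python
# Hypothetical entries for a follow-up PR; keys and values are assumed
# from the pattern of the existing Meta-Llama-3 entries, not from this PR.
"meta-llama/Meta-Llama-3.1-8B": {
    "distribution_path": "meta-llama/Meta-Llama-3.1-8B",
    "transformer_params_key": "Meta-Llama-3.1-8B",
},
"meta-llama/Llama-Guard-3-8B": {
    "distribution_path": "meta-llama/Llama-Guard-3-8B",
    "transformer_params_key": "Llama-Guard-3-8B",
},
```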


low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
new_freqs = []
@malfet (Contributor) commented Jul 24, 2024

Hmm, it feels like one can write this logic in a much more PyTorch-y style:

new_freqs = 360 / freq.rad2deg()
new_freqs[new_freqs > high_freq_wavelen] /= scale_factor
new_freqs[new_freqs < low_freq_wavelen * scale_factor] = apply_smooth_here
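Filling in the placeholder, a runnable version of that masked-assignment idea might look like the sketch below. It computes wavelengths directly as 2π/freqs (equivalent to 360 / freq.rad2deg()) and takes the smoothing formula from the same Llama 3.1 recipe as the loop sketched earlier; the name and defaults are assumptions, not the code this PR landed:

```python
import torch

def apply_scaling_vectorized(
    freqs: torch.Tensor,
    scale_factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,
) -> torch.Tensor:
    wavelen = 2 * torch.pi / freqs  # wavelength of each rotary frequency
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = freqs.clone()
    # Low-frequency band (long wavelengths): scale down.
    new_freqs[wavelen > low_freq_wavelen] /= scale_factor
    # Mid band: interpolate smoothly between scaled and unscaled.
    mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    smooth = (old_context_len / wavelen[mid] - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    new_freqs[mid] = (1 - smooth) * freqs[mid] / scale_factor + smooth * freqs[mid]
    return new_freqs
```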

@byjlw (Contributor) commented Jul 24, 2024

Awesome! Please update the README so people can discover it :)

@Gasoonjia merged commit 3e28e5d into main on Jul 24, 2024
51 checks passed