
Added Phi #132

Merged
tgaddair merged 12 commits into main from phi on Dec 15, 2023

Conversation

@tgaddair (Contributor) commented Dec 15, 2023:

Example:

lorax-launcher --model-id microsoft/phi-2

Prompt:

curl 127.0.0.1:8080/generate \
    -X POST \
    -d '{"inputs": "Instruct: Write a detailed analogy between mathematics and a lighthouse.\nOutput:", "parameters": {"max_new_tokens": 64}}' \
    -H 'Content-Type: application/json'

Response:

Mathematics is like a lighthouse. Just as a lighthouse guides ships safely to shore, mathematics provides a guiding light in the world of numbers and logic. It helps us navigate through complex problems and find solutions. Just as a lighthouse emits a steady beam of light, mathematics provides a consistent framework for reasoning and problem-solving.
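
For illustration, the same request can be sent from Python. This is a minimal sketch using the requests library; the payload mirrors the curl call above, and the generated_text field follows the usual text-generation-inference response schema:

import requests

# Same /generate request as the curl example above, sent from Python.
resp = requests.post(
    "http://127.0.0.1:8080/generate",
    json={
        "inputs": "Instruct: Write a detailed analogy between mathematics and a lighthouse.\nOutput:",
        "parameters": {"max_new_tokens": 64},
    },
)
print(resp.json()["generated_text"])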

@@ -325,6 +328,19 @@ def get_model(
trust_remote_code=trust_remote_code,
)
raise NotImplementedError("Qwen model requires flash attention v2")

if model_type == "phi-msft" or model_type == "phi":
Collaborator:
What do you think about using model_type in [] instead?

Contributor Author:
Done.
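
For illustration, the membership-style check suggested above might look like this (a minimal sketch; PHI_MODEL_TYPES is a hypothetical name, not an identifier from this PR):

# Hypothetical membership check, equivalent to the two equality tests above.
PHI_MODEL_TYPES = ["phi-msft", "phi"]

if model_type in PHI_MODEL_TYPES:
    ...  # dispatch to the Phi model implementation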

super().__init__()
self.num_heads = config.n_head
self.hidden_size = config.n_embd
self.head_size = self.hidden_size // self.num_heads
Contributor:
Is it intentional to use self.num_heads like this here, but then update self.num_heads by the number of process groups on line 110?

Contributor Author:
Yes, when creating the modules we need the pre-split num_heads value; later on we use num_heads post-split.

Contributor Author:
Added comment.
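
For context, a minimal sketch of the pre-/post-split pattern described above, as it commonly appears in tensor-parallel attention layers (the class name and process_group are assumptions, not the exact code from this PR):

# Sketch of the pre-/post-split num_heads pattern (assumed names).
class PhiAttentionSketch:
    def __init__(self, config, process_group):
        self.num_heads = config.n_head  # pre-split: full head count
        self.hidden_size = config.n_embd
        self.head_size = self.hidden_size // self.num_heads
        # ... projection layers are sharded across process_group here,
        # sized with the full head count ...
        # post-split: each rank keeps only its shard of the heads
        self.num_heads = self.num_heads // process_group.size()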

revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
Contributor:
Since we're adding trust_remote_code, do we already install the einops library to convert weights when using microsoft/phi-1.5?

Contributor Author:
trust_remote_code is False by default, and einops is already installed (for Falcon). The current implementation works with the Microsoft version of the weights, rather than the modified version HF made.
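
For illustration, a sketch of the tokenizer initialization the diff hunk above belongs to. AutoTokenizer.from_pretrained is the standard transformers API; model_id and revision are assumed to come from the surrounding get_model arguments:

from transformers import AutoTokenizer

model_id = "microsoft/phi-2"  # as in the launch example above
revision = None
trust_remote_code = False     # the default, per the reply above

# Tokenizer load with left-side padding/truncation, as in the diff above.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=revision,
    padding_side="left",
    truncation_side="left",
    trust_remote_code=trust_remote_code,
)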

@tgaddair merged commit 549bbb8 into main on Dec 15, 2023 (1 check passed).
@tgaddair deleted the phi branch on December 15, 2023 at 17:29.
@tgaddair mentioned this pull request on Jan 26, 2024.