Update llama3 example. #743

Merged 1 commit on Apr 18, 2024
49 changes: 23 additions & 26 deletions examples/llama3-8b-ec2/llama3_ec2.py
@@ -1,15 +1,17 @@
# # Deploy Llama2 13B Chat Model Inference on AWS EC2
# # Deploy Llama3 8B Chat Model Inference on AWS EC2

# This example demonstrates how to deploy a
# [LLama2 13B model from Hugging Face](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
# [LLama3 8B model from Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
# on AWS EC2 using Runhouse.
#
# Make sure to sign the waiver on the model page so that you can access it.
#
# ## Setup credentials and dependencies
#
# Optionally, set up a virtual environment:
# ```shell
# $ conda create -n llama-demo-apps python=3.8
# $ conda activate llama-demo-apps
# $ conda create -n llama3-rh
# $ conda activate llama3-rh
# ```
# Install the few required dependencies:
# ```shell
@@ -22,7 +24,7 @@
# $ aws configure
# $ sky check
# ```
# We'll be downloading the Llama2 model from Hugging Face, so we need to set up our Hugging Face token:
# We'll be downloading the Llama3 model from Hugging Face, so we need to set up our Hugging Face token:
# ```shell
# $ export HF_TOKEN=<your huggingface token>
# ```
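As a quick sanity check before launching anything, you can verify that the token actually has access to the gated Llama 3 weights. A minimal sketch (not part of this PR), assuming the `huggingface_hub` package is installed locally:

```python
import os

from huggingface_hub import HfApi

# Raises an error if the waiver hasn't been accepted for this account
# or if HF_TOKEN is missing or invalid.
HfApi().model_info(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    token=os.environ["HF_TOKEN"],
)
print("HF_TOKEN can access Meta-Llama-3-8B-Instruct.")
```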
@@ -55,7 +57,7 @@ def load_model(self):
self.pipeline = transformers.pipeline(
"text-generation",
model=self.model_id,
model_kwargs={"torch_dtype": torch.bfloat16},
model_kwargs=self.model_kwargs,
device="cuda",
)
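For context, this change swaps the hard-coded `model_kwargs` for whatever was passed to the class constructor. A rough reconstruction of the surrounding class, based only on the attribute names visible in this diff (the real class presumably also subclasses Runhouse's `rh.Module` so it can be shipped to the cluster with `get_or_to`):

```python
import transformers


class HFChatModel:
    # Sketch for illustration only; see the repo for the full class.
    def __init__(self, model_id="meta-llama/Meta-Llama-3-8B-Instruct", **model_kwargs):
        self.model_id = model_id
        self.model_kwargs = model_kwargs  # e.g. {"torch_dtype": torch.bfloat16}
        self.pipeline = None

    def load_model(self):
        # Mirrors the pipeline construction shown in the hunk above.
        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs=self.model_kwargs,
            device="cuda",
        )
```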

@@ -108,7 +110,6 @@ def predict(self, prompt_text, **inf_kwargs):
gpu = rh.cluster(
name="rh-a10x", instance_type="A10G:1", memory="32+", provider="aws"
)
# gpu.restart_server(restart_ray=True)

# Next, we define the environment for our module. This includes the required dependencies that need
# to be installed on the remote machine, as well as any secrets that need to be synced up from local to remote.
@@ -124,38 +125,34 @@ def predict(self, prompt_text, **inf_kwargs):
"safetensors",
"scipy",
],
secrets=["huggingface"], # Needed to download Llama2
name="llama2inference",
secrets=["huggingface"], # Needed to download Llama3 from HuggingFace
name="llama3inference",
working_dir="./",
)
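This hunk only shows the tail of the environment definition. Reconstructed from the visible keyword arguments (presumably an `rh.env(...)` call; the earlier entries of `reqs` are elided in this diff and left as a placeholder below), it is roughly:

```python
env = rh.env(
    reqs=[
        # ... earlier requirements elided in this diff ...
        "safetensors",
        "scipy",
    ],
    secrets=["huggingface"],  # Needed to download Llama3 from HuggingFace
    name="llama3inference",
    working_dir="./",
)
```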

# Finally, we define our module and run it on the remote cluster. We construct it normally and then call
# `get_or_to` to run it on the remote cluster. Using `get_or_to` allows us to load the existing Module
# by the name `llama-13b-model` if it was already put on the cluster. If we want to update the module each
# by the name `llama3-8b-model` if it was already put on the cluster. If we want to update the module each
# time we run this script, we can use `to` instead of `get_or_to`.
#
# Note that we also pass the `env` object to the `get_or_to` method, which will ensure that the environment is
# set up on the remote machine before the module is run.
remote_hf_chat_model = HFChatModel(
load_in_4bit=True, # Ignored right now
torch_dtype=torch.bfloat16, # Ignored right now
device_map="auto", # Ignored right now
).get_or_to(gpu, env=env, name="llama-13b-model")
torch_dtype=torch.bfloat16,
).get_or_to(gpu, env=env, name="llama3-8b-model")
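As the comment above notes, swapping `get_or_to` for `to` redeploys the module on every run instead of reusing the copy already on the cluster. A sketch of that variant, mirroring the keyword arguments used here:

```python
# Assumed variant based on the `to` vs `get_or_to` note above.
remote_hf_chat_model = HFChatModel(
    torch_dtype=torch.bfloat16,
).to(gpu, env=env, name="llama3-8b-model")
```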

# ## Calling our remote function
#
# We can call the `predict` method on the model class instance as if it were running locally.
# This will run the function on the remote cluster and return the response to our local machine automatically.
# Further calls will also run on the remote machine, and maintain state that was updated between calls, like
# `self.model` and `self.tokenizer`.
prompt = "Who are you?"
print(remote_hf_chat_model.predict(prompt))
# while True:
# prompt = input(
# "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
# )
# if prompt.lower() == "exit":
# break
# output = remote_hf_chat_model.predict(prompt)
# print("\n\n... Model Output ...\n")
# print(output)
# `self.pipeline`.
while True:
prompt = input(
"\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
)
if prompt.lower().strip() == "exit":
break
output = remote_hf_chat_model.predict(prompt)
print("\n\n... Model Output ...\n")
print(output)
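The body of `predict` itself isn't shown in this diff. A hypothetical implementation, assuming the pipeline built in `load_model` accepts chat-style message lists (as recent `transformers` text-generation pipelines do):

```python
def predict(self, prompt_text, **inf_kwargs):
    # Hypothetical body for illustration; the example's real method may differ.
    if self.pipeline is None:
        self.load_model()
    messages = [{"role": "user", "content": prompt_text}]
    outputs = self.pipeline(
        messages,
        max_new_tokens=inf_kwargs.get("max_new_tokens", 256),
    )
    # The pipeline returns the conversation with the model's reply appended.
    return outputs[0]["generated_text"][-1]["content"]
```

Running the script locally (`python llama3_ec2.py`) then deploys the module to the A10G box and drops you into the chat loop above.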