diff --git a/examples/llama3-8b-ec2/llama3_ec2.py b/examples/llama3-8b-ec2/llama3_ec2.py
index 524ca41be..542ad0888 100644
--- a/examples/llama3-8b-ec2/llama3_ec2.py
+++ b/examples/llama3-8b-ec2/llama3_ec2.py
@@ -1,15 +1,17 @@
-# # Deploy Llama2 13B Chat Model Inference on AWS EC2
+# # Deploy Llama3 8B Chat Model Inference on AWS EC2
 # This example demonstrates how to deploy a
-# [LLama2 13B model from Hugging Face](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
+# [LLama3 8B model from Hugging Face](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
 # on AWS EC2 using Runhouse.
 #
+# Make sure to sign the waiver on the model page so that you can access it.
+#
 # ## Setup credentials and dependencies
 #
 # Optionally, set up a virtual environment:
 # ```shell
-# $ conda create -n llama-demo-apps python=3.8
-# $ conda activate llama-demo-apps
+# $ conda create -n llama3-rh
+# $ conda activate llama3-rh
 # ```
 # Install the few required dependencies:
 # ```shell
@@ -22,7 +24,7 @@
 # $ aws configure
 # $ sky check
 # ```
-# We'll be downloading the Llama2 model from Hugging Face, so we need to set up our Hugging Face token:
+# We'll be downloading the Llama3 model from Hugging Face, so we need to set up our Hugging Face token:
 # ```shell
 # $ export HF_TOKEN=
 # ```
@@ -55,7 +57,7 @@ def load_model(self):
         self.pipeline = transformers.pipeline(
             "text-generation",
             model=self.model_id,
-            model_kwargs={"torch_dtype": torch.bfloat16},
+            model_kwargs=self.model_kwargs,
             device="cuda",
         )

@@ -108,7 +110,6 @@ def predict(self, prompt_text, **inf_kwargs):
     gpu = rh.cluster(
         name="rh-a10x", instance_type="A10G:1", memory="32+", provider="aws"
     )
-    # gpu.restart_server(restart_ray=True)

     # Next, we define the environment for our module. This includes the required dependencies that need
     # to be installed on the remote machine, as well as any secrets that need to be synced up from local to remote.
@@ -124,38 +125,34 @@ def predict(self, prompt_text, **inf_kwargs):
             "safetensors",
             "scipy",
         ],
-        secrets=["huggingface"],  # Needed to download Llama2
-        name="llama2inference",
+        secrets=["huggingface"],  # Needed to download Llama3 from HuggingFace
+        name="llama3inference",
         working_dir="./",
     )

     # Finally, we define our module and run it on the remote cluster. We construct it normally and then call
     # `get_or_to` to run it on the remote cluster. Using `get_or_to` allows us to load the exiting Module
-    # by the name `llama-13b-model` if it was already put on the cluster. If we want to update the module each
+    # by the name `llama3-8b-model` if it was already put on the cluster. If we want to update the module each
     # time we run this script, we can use `to` instead of `get_or_to`.
     #
     # Note that we also pass the `env` object to the `get_or_to` method, which will ensure that the environment is
     # set up on the remote machine before the module is run.
     remote_hf_chat_model = HFChatModel(
-        load_in_4bit=True,  # Ignored right now
-        torch_dtype=torch.bfloat16,  # Ignored right now
-        device_map="auto",  # Ignored right now
-    ).get_or_to(gpu, env=env, name="llama-13b-model")
+        torch_dtype=torch.bfloat16,
+    ).get_or_to(gpu, env=env, name="llama3-8b-model")

     # ## Calling our remote function
     #
     # We can call the `predict` method on the model class instance if it were running locally.
     # This will run the function on the remote cluster and return the response to our local machine automatically.
     # Further calls will also run on the remote machine, and maintain state that was updated between calls, like
-    # `self.model` and `self.tokenizer`.
-    prompt = "Who are you?"
-    print(remote_hf_chat_model.predict(prompt))
-    # while True:
-    #     prompt = input(
-    #         "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
-    #     )
-    #     if prompt.lower() == "exit":
-    #         break
-    #     output = remote_hf_chat_model.predict(prompt)
-    #     print("\n\n... Model Output ...\n")
-    #     print(output)
+    # `self.pipeline`.
+    while True:
+        prompt = input(
+            "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
+        )
+        if prompt.lower().strip() == "exit":
+            break
+        output = remote_hf_chat_model.predict(prompt)
+        print("\n\n... Model Output ...\n")
+        print(output)
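For context, here is a minimal sketch of how the patched example hangs together end to end. The cluster definition, the `model_kwargs=self.model_kwargs` change, the `get_or_to` call, and the chat loop are taken from the hunks above; the `HFChatModel` body, the `rh.env(...)` call, and the head of the `reqs` list are not visible in the diff and are assumptions, not the file's exact contents.

```python
import runhouse as rh
import torch


class HFChatModel(rh.Module):
    # Assumed skeleton: only load_model's pipeline call and the predict
    # signature appear in the hunks above.
    def __init__(self, model_id="meta-llama/Meta-Llama-3-8B-Instruct", **model_kwargs):
        super().__init__()
        self.model_id, self.model_kwargs = model_id, model_kwargs
        self.pipeline = None

    def load_model(self):
        import transformers  # imported remotely, where the env installs it

        self.pipeline = transformers.pipeline(
            "text-generation",
            model=self.model_id,
            model_kwargs=self.model_kwargs,  # e.g. {"torch_dtype": torch.bfloat16}
            device="cuda",
        )

    def predict(self, prompt_text, **inf_kwargs):
        if not self.pipeline:
            self.load_model()
        return self.pipeline(prompt_text, **inf_kwargs)[0]["generated_text"]


if __name__ == "__main__":
    gpu = rh.cluster(
        name="rh-a10x", instance_type="A10G:1", memory="32+", provider="aws"
    )

    env = rh.env(
        # The reqs list is truncated in the hunk; entries other than
        # "safetensors" and "scipy" are assumptions.
        reqs=["torch", "transformers", "safetensors", "scipy"],
        secrets=["huggingface"],  # Needed to download Llama3 from HuggingFace
        name="llama3inference",
        working_dir="./",
    )

    remote_hf_chat_model = HFChatModel(
        torch_dtype=torch.bfloat16,
    ).get_or_to(gpu, env=env, name="llama3-8b-model")

    while True:
        prompt = input(
            "\n\n... Enter a prompt to chat with the model, and 'exit' to exit ...\n"
        )
        if prompt.lower().strip() == "exit":
            break
        print("\n\n... Model Output ...\n")
        print(remote_hf_chat_model.predict(prompt))
```

One design note implied by the patch: keeping the `transformers` import inside `load_model` means only `runhouse` and `torch` need to exist locally, while the heavy inference dependencies are resolved inside the remote `env` that `get_or_to` sets up on the cluster.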