From cc3eb77f5e55c2446fc3f6339b2525961fd4934f Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 21 Oct 2024 13:54:48 -0700
Subject: [PATCH] fix eager run for cuda

ghstack-source-id: 8278f05c85f41b237d82148dca24648cd67d2b15
Pull Request resolved: https://github.com/pytorch/executorch/pull/6365
---
 examples/models/llama/runner/eager.py      | 12 ++++++------
 examples/models/llama/runner/generation.py | 22 +++++++++++++++-------
 2 files changed, 21 insertions(+), 13 deletions(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index cff5c4f8023..e116e08a099 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -33,13 +33,13 @@ def __init__(self, args):
             use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
-        self.model = (
-            manager.model.eval().to(device="cuda")
-            if torch.cuda.is_available()
-            else manager.model.eval().to(device="cpu")
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
+            model_args=model_args,
+            device="cuda" if torch.cuda.is_available() else "cpu",
         )
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = manager.model.eval().to(device=self.device)
 
     def forward(
         self,
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 3f7937cd5a8..e332e0ebe2e 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate(  # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,14 +88,21 @@ def generate(  # noqa: C901
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor(
+                        [[current_token]], dtype=torch.long, device=self.device
+                    ),
+                    input_pos=torch.tensor(
+                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
-                hasattr(self, "stop_tokens") and current_token in self.stop_tokens
+                hasattr(self.tokenizer, "stop_tokens")
+                and current_token in self.tokenizer.stop_tokens
             ):
                 break
             tokens.append(current_token)
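
Note on the pattern the patch applies: the runner resolves a single device string up front ("cuda" when available, otherwise "cpu"), moves the eager model to it, and then creates every input tensor on that same self.device, which is what avoids CPU/CUDA device-mismatch errors during generation. A minimal, self-contained sketch of that idea follows; TinyRunner and step are hypothetical stand-ins for illustration, not the ExecuTorch LlamaRunner API.

    import torch


    class TinyRunner:
        def __init__(self, model: torch.nn.Module, device: str = "cpu"):
            # Resolve the device once and remember it; the model is moved here too.
            self.device = device
            self.model = model.eval().to(device=self.device)

        @torch.no_grad()
        def step(self, token_ids: list) -> torch.Tensor:
            # Inputs must be created on the same device as the model's weights.
            tokens = torch.tensor([token_ids], dtype=torch.long, device=self.device)
            return self.model(tokens)


    if __name__ == "__main__":
        device = "cuda" if torch.cuda.is_available() else "cpu"
        runner = TinyRunner(torch.nn.Embedding(32, 8), device=device)
        print(runner.step([1, 2, 3]).shape)  # torch.Size([1, 3, 8])

The same `"cuda" if torch.cuda.is_available() else "cpu"` resolution is what eager.py now passes to the base LlamaRunner constructor instead of moving the model itself.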