From 1dd27fa0420add0ad88d6a54b4690164c71e8711 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Fri, 18 Oct 2024 16:39:34 -0700
Subject: [PATCH] fix eager run for cuda

[ghstack-poisoned]
---
 examples/models/llama/runner/eager.py      | 12 ++++++------
 examples/models/llama/runner/generation.py | 19 +++++++++++++------
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index cff5c4f8023..e116e08a099 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -33,13 +33,13 @@ def __init__(self, args):
             use_kv_cache=args.use_kv_cache,
             **params,
         )
-        super().__init__(tokenizer_path=args.tokenizer_path, model_args=model_args)
-        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
-        self.model = (
-            manager.model.eval().to(device="cuda")
-            if torch.cuda.is_available()
-            else manager.model.eval().to(device="cpu")
+        super().__init__(
+            tokenizer_path=args.tokenizer_path,
+            model_args=model_args,
+            device="cuda" if torch.cuda.is_available() else "cpu",
         )
+        manager: LLMEdgeManager = _prepare_for_llama_export("llama", args)
+        self.model = manager.model.eval().to(device=self.device)
 
     def forward(
         self,
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 3f7937cd5a8..27e89568b02 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -51,10 +51,11 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
 
 
 class LlamaRunner(ABC):
-    def __init__(self, tokenizer_path: str, model_args: ModelArgs):
+    def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cpu"):
         self.params = model_args
         self.tokenizer = get_tokenizer(tokenizer_path)
         assert model_args.vocab_size == self.tokenizer.n_words
+        self.device = device
 
     @abstractmethod
     def forward(
@@ -73,9 +74,9 @@ def generate(  # noqa: C901
     ) -> List[int]:
         # prefill
         logits = self.forward(
-            tokens=torch.tensor([prompt_tokens], dtype=torch.long),
+            tokens=torch.tensor([prompt_tokens], dtype=torch.long).to(self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long)
+                torch.tensor([0], dtype=torch.long).to(self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -87,11 +88,17 @@
         while len(tokens) < self.params.max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
-                    tokens=torch.tensor([[current_token]], dtype=torch.long),
-                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long),
+                    tokens=torch.tensor([[current_token]], dtype=torch.long).to(
+                        self.device
+                    ),
+                    input_pos=torch.tensor([len(tokens) - 1], dtype=torch.long).to(
+                        self.device
+                    ),
                 )
             else:
-                logits = self.forward(tokens=torch.tensor([tokens], dtype=torch.long))
+                logits = self.forward(
+                    tokens=torch.tensor([tokens], dtype=torch.long).to(self.device)
+                )
             current_token = next_token(logits, temperature, top_p)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self, "stop_tokens") and current_token in self.stop_tokens
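
Note for readers outside the patch: below is a minimal, self-contained sketch of the pattern the change applies -- pick the device once ("cuda" when available, else "cpu"), store it on the runner, move the model there, and move every input tensor with .to(self.device) before calling forward. TinyRunner and its step() method are hypothetical illustrations of that pattern, not ExecuTorch APIs.

# device_pattern_sketch.py -- illustrative only; TinyRunner is not an ExecuTorch class.
import torch


class TinyRunner:
    def __init__(self, model: torch.nn.Module, device: str = "cpu"):
        # Decide the device once, remember it, and move the model there.
        self.device = device
        self.model = model.eval().to(device=self.device)

    @torch.no_grad()
    def step(self, token_ids: list) -> torch.Tensor:
        # Inputs are created on CPU and then moved to the runner's device,
        # so CPU-only and CUDA hosts share the same code path.
        tokens = torch.tensor([token_ids], dtype=torch.long).to(self.device)
        return self.model(tokens)


# Mirror the patch's constructor change: prefer CUDA when it is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
runner = TinyRunner(torch.nn.Embedding(128, 16), device=device)
print(runner.step([1, 2, 3]).shape)  # torch.Size([1, 3, 16])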