Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions examples/models/llama/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,6 @@ runtime.python_library(
runtime.python_binary(
name = "eval_llama",
main_function = "executorch.examples.models.llama.eval_llama.main",
preload_deps = [
"//executorch/extension/llm/custom_ops:custom_ops_aot_lib",
"//executorch/kernels/quantized:aot_lib",
],
deps = [
":eval_library",
"//caffe2:torch",
Expand Down
3 changes: 2 additions & 1 deletion examples/models/llama/eval_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def eval_llama(

# Needed for loading mmlu dataset.
# See https://github.com/EleutherAI/lm-evaluation-harness/pull/1998/files
# pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
if args.tasks and "mmlu" in args.tasks:
import datasets

Expand All @@ -302,7 +303,7 @@ def eval_llama(
with torch.no_grad():
eval_results = simple_evaluate(
model=eval_wrapper,
tasks=args.tasks, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `tasks`
tasks=args.tasks,
num_fewshot=args.num_fewshot, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `num_fewshot`
limit=args.limit, # pyre-ignore: Undefined attribute [16]: `argparse.ArgumentParser` has no attribute `limit`
)
Expand Down
29 changes: 29 additions & 0 deletions examples/models/llama/runner/TARGETS
Original file line number Diff line number Diff line change
@@ -1,8 +1,37 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

runtime.python_library(
name = "eager_runner_library",
srcs = [
"eager.py",
"generation.py"
],
_is_external_target = True,
base_module = "executorch.examples.models.llama.runner",
visibility = [
"//bento/...",
"//bento_kernels/...",
"//executorch/examples/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//executorch/examples/models/llama:export_library",
],
)

runtime.python_binary(
name = "eager",
main_function = "executorch.examples.models.llama.runner.eager.main",
deps = [
":eager_runner_library",
"//caffe2:torch",
],
)
9 changes: 4 additions & 5 deletions examples/models/llama/runner/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
from typing import Optional

import torch

from examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.export_llama_lib import (
_prepare_for_llama_export,
build_args_parser as _build_args_parser,
)
from executorch.examples.models.llama.llama_transformer import ModelArgs
from executorch.examples.models.llama.runner.generation import LlamaRunner
from executorch.extension.llm.export import LLMEdgeManager
from executorch.extension.llm.export.builder import LLMEdgeManager


class EagerLlamaRunner(LlamaRunner):
Expand All @@ -43,8 +42,8 @@ def __init__(self, args):

def forward(
self,
tokens: Optional[torch.LongTensor] = None,
input_pos: Optional[torch.LongTensor] = None,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.model.forward(tokens=tokens, input_pos=input_pos)

Expand Down
7 changes: 4 additions & 3 deletions examples/models/llama/runner/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class CompletionPrediction(TypedDict, total=False):
generation: str
tokens: List[str] # not required
tokens: List[int] # not required


def sample_top_p(probs, p):
Expand Down Expand Up @@ -47,6 +47,7 @@ def next_token(logits: torch.Tensor, temperature: float, top_p: float) -> int:
if temperature > 0:
probs = torch.softmax(logits / temperature, dim=-1)
return sample_top_p(probs, top_p).item()
# Pyre-ignore[7]: Incompatible return type [7]: Expected `int` but got `Union[bool, float, int]`
return torch.argmax(logits, dim=-1).item()


Expand All @@ -60,8 +61,8 @@ def __init__(self, tokenizer_path: str, model_args: ModelArgs, device: str = "cp
@abstractmethod
def forward(
self,
tokens: Optional[torch.LongTensor] = None,
input_pos: Optional[torch.LongTensor] = None,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
pass

Expand Down
4 changes: 2 additions & 2 deletions examples/models/llama/runner/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def __init__(self, args):

def forward(
self,
tokens: Optional[torch.LongTensor] = None,
input_pos: Optional[torch.LongTensor] = None,
tokens: torch.Tensor,
input_pos: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return (
self.model.forward((tokens, input_pos))
Expand Down
Loading