diff --git a/build/builder.py b/build/builder.py
index 85773a6c1..d8ba6e019 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -498,7 +498,7 @@ def _initialize_model(
         device_sync(device=builder_args.device)
 
         try:
-            from build.model_et import PTEModel
+            from build.model import PTEModel
 
             model = PTEModel(model.config, builder_args.pte_path)
         except Exception:
diff --git a/build/model.py b/build/model.py
index 2401e5724..27d1500bb 100644
--- a/build/model.py
+++ b/build/model.py
@@ -418,3 +418,36 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
     x_out2 = x_out2.flatten(3)
     return x_out2.type_as(x)
+
+
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# ExecuTorch model components
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+try:
+    from executorch.extension.pybindings import portable_lib as exec_lib
+
+    # ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
+    from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa
+
+    class PTEModel(nn.Module):
+        def __init__(self, config, path) -> None:
+            super().__init__()
+            self.config = config
+            self.model_ = exec_lib._load_for_executorch(str(path))
+
+        def forward(self, x, input_pos):
+            # model_.forward expects inputs to be wrapped in a tuple
+            forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
+            logits = self.model_.forward(forward_inputs)
+
+            # After wrapping in a tuple, we get a list back, so we need to grab
+            # the first element to get the tensor
+            assert len(logits) == 1
+            logits = logits[0]
+            return logits
+
+        def setup_caches(self, max_batch_size, max_seq_length):
+            pass
+except:
+    pass
diff --git a/build/model_et.py b/build/model_et.py
deleted file mode 100644
index 54317124d..000000000
--- a/build/model_et.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-from executorch.extension.pybindings import portable_lib as exec_lib
-
-# ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
-from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa
-
-class PTEModel(nn.Module):
-    def __init__(self, config, path) -> None:
-        super().__init__()
-        self.config = config
-        self.model_ = exec_lib._load_for_executorch(str(path))
-
-    def forward(self, x, input_pos):
-        # model_.forward expects inputs to be wrapped in a tuple
-        forward_inputs = (x.to(torch.long), input_pos.to(torch.long))
-        logits = self.model_.forward(forward_inputs)
-
-        # After wrapping in a tuple, we get a list back, so we need to grab
-        # the first element to get the tensor
-        assert len(logits) == 1
-        logits = logits[0]
-        return logits
-
-    def setup_caches(self, max_batch_size, max_seq_length):
-        pass
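For reference, a minimal usage sketch (not part of the patch) of how the relocated PTEModel is consumed after this change, mirroring the builder.py hunk above. The `config` object, the token tensors, and the "model.pte" path are placeholders, and the import assumes ExecuTorch's pybindings are installed so the `try` block in build/model.py succeeds.

    import torch
    from build.model import PTEModel

    # `config` stands in for the Transformer config the builder already holds.
    model = PTEModel(config, "model.pte")         # placeholder .pte path
    tokens = torch.tensor([[1]], dtype=torch.long)    # prompt token ids
    input_pos = torch.tensor([0], dtype=torch.long)   # current decode position
    logits = model(tokens, input_pos)             # delegates to the ExecuTorch runtime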