Commit 24eb005

Update base for Update on "[llava][15/N] Extract out text decoder runner"
The previous PR, #4556, refactored run_model_step() so that it can be extracted into a separate class. The new `TextDecoderRunner` provides two APIs:

* step(tokens, start_pos): takes one or more tokens with start_pos and feeds them into the Module. Returns a tensor of logits.
* logits_to_token(logits): samples the result and returns a token.

We don't expect this logic to change across different runners.

Differential Revision: [D60856571](https://our.internmc.facebook.com/intern/diff/D60856571)

[ghstack-poisoned]
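The two-call shape described in the commit message can be illustrated with a minimal, self-contained sketch. This is not the ExecuTorch `TextDecoderRunner`: the class name `SketchTextDecoderRunner`, the `ForwardFn` callback, the `toy_forward` model, and the use of `std::vector<float>` in place of a tensor are all hypothetical stand-ins, and greedy argmax is assumed here in place of whatever sampling the real runner performs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iterator>
#include <utility>
#include <vector>

// Hypothetical stand-in for the model's forward pass: receives the tokens
// and their start position, returns one logit per vocabulary entry.
using ForwardFn =
    std::function<std::vector<float>(const std::vector<uint64_t>&, int64_t)>;

// Illustrative sketch only; NOT the ExecuTorch TextDecoderRunner.
class SketchTextDecoderRunner {
 public:
  explicit SketchTextDecoderRunner(ForwardFn forward)
      : forward_(std::move(forward)) {}

  // Feeds one or more tokens, starting at start_pos, into the model and
  // returns the resulting logits.
  std::vector<float> step(const std::vector<uint64_t>& tokens,
                          int64_t start_pos) const {
    assert(!tokens.empty());
    return forward_(tokens, start_pos);
  }

  // Turns logits into a token. Greedy argmax is an assumption made for this
  // sketch; a real runner may sample with temperature or top-p instead.
  uint64_t logits_to_token(const std::vector<float>& logits) const {
    assert(!logits.empty());
    auto best = std::max_element(logits.begin(), logits.end());
    return static_cast<uint64_t>(std::distance(logits.begin(), best));
  }

 private:
  ForwardFn forward_;
};

// Toy forward pass over a 4-entry vocabulary: it favors the token that
// follows the last input token, wrapping around at the vocabulary size.
inline std::vector<float> toy_forward(const std::vector<uint64_t>& tokens,
                                      int64_t /*start_pos*/) {
  std::vector<float> logits(4, 0.0f);
  logits[(tokens.back() + 1) % 4] = 1.0f;
  return logits;
}
```

A caller would construct the runner with a forward function, call `step` once on the whole prompt (prefill), and then alternate `logits_to_token` and `step` on the newest token to generate. Keeping token selection in its own method matches the commit's remark that this logic is not expected to vary across runners.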
1 parent 306b656 commit 24eb005

1 file changed: 4 additions and 2 deletions
examples/models/llama2/runner/runner.cpp

Lines changed: 4 additions & 2 deletions
@@ -152,7 +152,7 @@ Result<uint64_t> Runner::prefill(
 
     ManagedTensor managed_start_pos(&start_pos, {1}, ScalarType::Long);
 
-    Result<torch::executor::Tensor> outputs_res =
+    Result<exec_aten::Tensor> outputs_res =
         run_model_step(managed_tokens, managed_start_pos);
 
     ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
@@ -164,6 +164,7 @@ Result<uint64_t> Runner::prefill(
         num_prompt_tokens,
         outputs_res.get().size(1));
     // insert new token into prompt_tokens
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     uint64_t prev = prompt_tokens[0];
     uint64_t cur;
     for (int i = 1; i < prompt_tokens.size(); i++) {
@@ -177,6 +178,7 @@ Result<uint64_t> Runner::prefill(
     uint64_t prev_token;
     // token & pos
     int64_t pos_data = 0;
+    // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
     cur_token = prompt_tokens[0];
 
     // initialize tensor wrappers
@@ -188,7 +190,7 @@ Result<uint64_t> Runner::prefill(
       // Run the model
       pos_data = start_pos + pos;
 
-      Result<torch::executor::Tensor> logits_res =
+      Result<exec_aten::Tensor> logits_res =
           run_model_step(managed_tokens, managed_start_pos);
 
       ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
