@@ -212,6 +212,7 @@ float* forward(Transformer* transformer, int token, int pos) {
           .to(torch::dtype(torch::kFloat32))
           .to(torch::kCPU);
   auto logits = result[0].data_ptr();
+  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
 #else // __ET_MODEL__
   TensorPtr pos_managed = make_tensor_ptr({1}, pos_buffer, ScalarType::Long);
   TensorPtr tokens_managed = make_tensor_ptr({1, 1}, token_buffer, ScalarType::Long);
@@ -228,10 +229,23 @@ float* forward(Transformer* transformer, int token, int pos) {
     exit(EXIT_FAILURE);
   }
   std::vector<EValue> result = outputs_res.get();
-  auto logits = result[0].toTensor().const_data_ptr();
+  // HACK: the rest of this runner assumes that logits must be float,
+  // so we simply convert them rather than plumbing
+  // templating/switch-on-type through the rest of this file.
+  const auto& result_tensor = result[0].toTensor();
+  ET_SWITCH_REALHBBF16_TYPES(
+      result_tensor.scalar_type(),
+      unused,
+      "forward",
+      CTYPE,
+      [&]() {
+        const CTYPE* logits = result_tensor.const_data_ptr<CTYPE>();
+        std::transform(logits, logits + p->vocab_size, s->logits, [](auto x) {
+          return static_cast<float>(x);
+        });
+      });
 #endif
 
-  memcpy(s->logits, logits, p->vocab_size * sizeof(float));
   return s->logits;
 }
 
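For context: ET_SWITCH_REALHBBF16_TYPES, as used here, is one of ExecuTorch's switch-on-dtype macros. It inspects the tensor's runtime scalar_type(), binds the matching C++ type to the CTYPE alias, and invokes the lambda once with that concrete type. Below is a minimal standalone sketch of the same dispatch-and-widen pattern with a hand-rolled switch; the ScalarType subset and the copy_as_float/convert_to_float helpers are illustrative stand-ins, not the ExecuTorch API.

// Sketch of the switch-on-dtype conversion the ET_SWITCH_* macro performs.
// The enum and helper names below are hypothetical, for illustration only.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

enum class ScalarType { Float, Int, Long };  // hypothetical subset

// Copy n elements out of src into out, widening each value to float.
template <typename CTYPE>
void copy_as_float(const void* src, float* out, std::size_t n) {
  const CTYPE* typed = static_cast<const CTYPE*>(src);
  std::transform(typed, typed + n, out, [](CTYPE x) {
    return static_cast<float>(x);
  });
}

// Dispatch on the runtime dtype, as the macro does on tensor.scalar_type().
void convert_to_float(ScalarType t, const void* src, float* out, std::size_t n) {
  switch (t) {
    case ScalarType::Float: copy_as_float<float>(src, out, n); break;
    case ScalarType::Int:   copy_as_float<int32_t>(src, out, n); break;
    case ScalarType::Long:  copy_as_float<int64_t>(src, out, n); break;
    default: throw std::runtime_error("unsupported dtype");
  }
}

int main() {
  std::vector<int32_t> logits_i32{1, 2, 3};
  std::vector<float> logits_f(logits_i32.size());
  convert_to_float(ScalarType::Int, logits_i32.data(), logits_f.data(),
                   logits_i32.size());
  for (float v : logits_f) std::printf("%.1f\n", v);
}

Widening every supported dtype to float at this one boundary keeps the rest of the runner monomorphic on float, at the cost of a single O(vocab_size) copy per forward call.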