Closed
Labels
need-user-input — The issue needs more information from the reporter before moving forward
Description
When the ExecuTorch Llama model is created, `freqs_cos` is computed in FP32 by `precompute_freqs_cis`:

```python
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim))
```
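To make the precision concern concrete, here is a small sketch (using an illustrative `dim = 128` and the standard `theta = 10000.0`; the exact values in ExecuTorch may differ) that computes the frequencies in FP32 and measures the error introduced by a round-trip through BF16, which is what a whole-model dtype cast does to this buffer:

```python
import torch

# Illustrative config; real Llama uses dim = head_dim (often 128) and theta = 10000.0.
dim, theta = 128, 10000.0

# FP32 computation, matching the quoted precompute_freqs_cis line.
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

# Round-trip through BF16, as happens when the whole model is cast to BF16.
freqs_bf16 = freqs.to(torch.bfloat16)
err = (freqs - freqs_bf16.float()).abs().max()
print(f"max abs error after BF16 round-trip: {err.item():.3e}")
```

BF16 keeps only 8 mantissa bits, so most of these frequencies cannot be represented exactly; the rotary angles (and hence `freqs_cos`/`freqs_sin`) inherit that rounding error at every position.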
The model is then loaded by `get_eager_model` and converted from FP32 to BF16 (the dtype of the original Llama checkpoint). This conversion loses precision in `freqs_cos` and can affect inference:

```python
def get_eager_model(self) -> torch.nn.Module:
    if self.dtype:
        return self.model_.to(self.dtype)
```
I think `freqs_cos` should always be kept in FP32.
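One way the suggestion could look in practice (a minimal sketch, not ExecuTorch's actual code: `TinyRope` and `to_dtype_keep_rope_fp32` are hypothetical names, and the real model registers its buffers elsewhere) is to cast the model as `get_eager_model` does today, then restore the rope buffer to FP32 afterwards:

```python
import torch

class TinyRope(torch.nn.Module):
    # Minimal stand-in for the Llama transformer: one parameter plus the
    # precomputed rope buffer. Names here are illustrative only.
    def __init__(self, dim: int = 128, seq_len: int = 16):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)
        freqs = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(seq_len).float()
        self.register_buffer("freqs_cos", torch.cos(torch.outer(t, freqs)))

def to_dtype_keep_rope_fp32(model: torch.nn.Module, dtype: torch.dtype) -> torch.nn.Module:
    # Sketch of the proposed fix: cast everything to the target dtype
    # (as self.model_.to(self.dtype) does), then pin the rope buffer
    # back to FP32 so no rotary precision is lost.
    model = model.to(dtype)
    model.freqs_cos = model.freqs_cos.float()  # re-assignment updates the registered buffer
    return model

m = to_dtype_keep_rope_fp32(TinyRope(), torch.bfloat16)
print(m.linear.weight.dtype, m.freqs_cos.dtype)
```

The attention code would then cast activations (or the rotated slice) as needed at the point of use, paying a small cast cost per step instead of baking the rounding error into the table.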