
Llama's freqs_cos data loss when converting dtype #9393

@WeiMa01

Description

  • The freqs_cos tensor of ExecuTorch's Llama model structure is FP32 at creation time; it is derived from the frequencies computed in precompute_freqs_cis:

    ```python
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim))
    ```

  • The model is loaded by get_eager_model and converted from FP32 to BF16 (the original Llama checkpoint's data type). This can lose precision in freqs_cos and affect inference (a small repro follows this list):

    ```python
    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            return self.model_.to(self.dtype)
    ```
    
  • I think freqs_cos should always stay in FP32 (the second sketch after this list shows one way to restore it after the cast).
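
To make the loss concrete, here is a minimal repro sketch. Only the freqs line is taken from precompute_freqs_cis as quoted above; dim, theta, and max_seq_len are illustrative assumptions, and the cosine table is built with the usual RoPE construction:

```python
import torch

# Illustrative values, not the model's actual config.
dim, theta, max_seq_len = 128, 10000.0, 2048

# Copied from precompute_freqs_cis (quoted above).
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device="cpu")[: (dim // 2)].float() / dim))

# Build the FP32 cosine table (the usual RoPE construction, assumed here).
t = torch.arange(max_seq_len, device="cpu").float()
freqs_cos_fp32 = torch.cos(torch.outer(t, freqs))

# model.to(torch.bfloat16) casts every float buffer, including freqs_cos.
freqs_cos_bf16 = freqs_cos_fp32.to(torch.bfloat16)

# BF16 keeps only ~8 bits of mantissa, so values in [-1, 1] typically lose
# absolute precision on the order of 1e-3.
err = (freqs_cos_bf16.float() - freqs_cos_fp32).abs().max()
print(f"max abs error after BF16 cast: {err.item():.3e}")
```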
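
And a sketch of one possible mitigation, assuming freqs_cos/freqs_sin are registered as module buffers; the helper name to_dtype_keep_rope_fp32 is hypothetical, not an ExecuTorch API:

```python
import torch

def to_dtype_keep_rope_fp32(model: torch.nn.Module, dtype: torch.dtype) -> torch.nn.Module:
    """Cast the model to `dtype` but keep the RoPE tables in FP32 (sketch)."""
    # Snapshot FP32 copies of the tables before the global cast touches them.
    saved = {
        name: buf.detach().clone()
        for name, buf in model.named_buffers()
        if name.endswith(("freqs_cos", "freqs_sin"))
    }
    model = model.to(dtype)
    # Walk to each owning submodule and re-register the FP32 table.
    # persistent=False is an assumption about how the originals were registered.
    for qual_name, fp32_buf in saved.items():
        *path, leaf = qual_name.split(".")
        owner = model
        for attr in path:
            owner = getattr(owner, attr)
        owner.register_buffer(leaf, fp32_buf, persistent=False)
    return model
```

get_eager_model could then call this helper instead of the blanket self.model_.to(self.dtype). Note that keeping the tables in FP32 assumes the RoPE application either casts them at use or runs in mixed precision, which is model-specific.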

Metadata

Assignees

No one assigned

Labels

need-user-input: The issue needs more information from the reporter before moving forward
