XLA mul with bf16×bf16 upcasts to f32 — op math type and option to disable? #9662

@sshonTT

Description

❓ Questions and Help

Hi folks, I have a question about the XLA mul op.

When both inputs are bf16, the generated graph converts them to f32, performs the multiply in f32, and then converts the result back to bf16 (see the repro sketch after the questions). Two questions:

1. In this case, is the op's effective math type f32 rather than bf16?

2. If this upcast exists primarily for TPU accuracy/stability, would it be acceptable to gate it behind a flag (e.g., an env option) so we can treat that path as a no-op and keep the op in native bf16 when desired? (A hypothetical usage sketch follows the reference link below.)
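
For reference, here is a minimal sketch of how I'm inspecting this (assumes a recent torch_xla; `_get_xla_tensors_hlo` is an internal helper whose exact name may vary across versions):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(8, 8, dtype=torch.bfloat16, device=device)
b = torch.randn(8, 8, dtype=torch.bfloat16, device=device)
c = a * b  # bf16 x bf16 multiply

# Dump the pending HLO: with the upcast in place, the text shows
# convert(bf16 -> f32) on both operands, an f32 multiply, and a
# convert back to bf16 on the result.
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))
```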

Reference code path:
https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L187-L211
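
To make the proposal concrete, here is a hypothetical usage sketch; `XLA_BF16_MUL_NATIVE` is an assumed flag name, not an existing torch_xla option:

```python
import os

# Hypothetical opt-out flag (assumed name, not a real torch_xla option).
# Env flags are generally read when the runtime initializes, so it would
# need to be set before torch_xla is imported.
os.environ["XLA_BF16_MUL_NATIVE"] = "1"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(8, 8, dtype=torch.bfloat16, device=device)
b = torch.randn(8, 8, dtype=torch.bfloat16, device=device)

# If such a flag were honored in the mul lowering, this multiply would
# stay in native bf16, with no convert ops in the dumped HLO.
c = a * b
```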

If there’s a better approach, please let me know. Thanks!

Metadata

Assignees: no one assigned
Labels: enhancement (New feature or request), tracing (Lazy Tensor tracing)
