8 changes: 2 additions & 6 deletions torchchat/export.py
@@ -199,7 +199,7 @@ def forward(self, x, freqs_cis, mask, input_pos=None, cache_lane: int = 0):
             input_pos[-1].item(),
             seqlen,
         )
-        output = output.view(bsz, seqlen, self.dim).to(dtype=q.dtype)
+        output = output.view(bsz, seqlen, self.dim).to(dtype=x.dtype)
         return self.wo(output)
 
 def replace_attention_with_custom_sdpa_attention(module: nn.Module):
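The cast target matters here: the custom SDPA path may run the attention math in a wider precision than the incoming activations, so the result is cast back to x.dtype (the module input's dtype) rather than q.dtype before the output projection. Below is a minimal sketch of that round-trip with a toy module, not the attention class this diff patches; the wq/wk/wv/wo projections and the float32 upcast are illustrative assumptions.

# Minimal sketch of the dtype round-trip (toy module, not the attention class
# this diff patches; wq/wk/wv/wo and the float32 upcast are illustrative
# assumptions).
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySDPAAttention(nn.Module):
    def __init__(self, dim: int, n_heads: int):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.wq = nn.Linear(dim, dim, bias=False)
        self.wk = nn.Linear(dim, dim, bias=False)
        self.wv = nn.Linear(dim, dim, bias=False)
        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        # Run the attention math in float32 even if x (and the weights) are bfloat16.
        q = self.wq(x).float().view(bsz, seqlen, self.n_heads, -1).transpose(1, 2)
        k = self.wk(x).float().view(bsz, seqlen, self.n_heads, -1).transpose(1, 2)
        v = self.wv(x).float().view(bsz, seqlen, self.n_heads, -1).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(bsz, seqlen, self.dim)
        # Cast back to the *input* dtype, not q.dtype: q is float32 at this point,
        # and returning float32 would leak the wider precision into the next layer.
        return self.wo(out.to(dtype=x.dtype))

With a bfloat16 model, q in this sketch is float32, so casting to q.dtype would hand the bfloat16 output projection a float32 tensor.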
@@ -291,11 +291,7 @@ def export_for_et(model, device, output_path) -> str:
         model = model.to(dtype=target_precision)
         state_dict_dtype = target_precision
 
-    # Custom SDPA does not work with bfloat16 on CPU currently. (The op doesn't
-    # support anything but bfloat32, and our attempt to use it anyway by converting
-    # to and from float causes other errors.)
-    if target_precision != torch.bfloat16:
-        replace_attention_with_custom_sdpa_attention(model)
+    replace_attention_with_custom_sdpa_attention(model)
 
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
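This hunk drops the bfloat16 guard, so the custom SDPA swap now runs for every target precision. A rough sketch of the flow around it follows, assuming a stand-in for replace_attention_with_custom_sdpa_attention and a plain torch.export.export call in place of whatever lowering export_for_et actually performs.

# Rough sketch of the export flow around this hunk (the stub and the
# torch.export.export call are stand-ins, not the code in this file).
import torch
import torch.nn as nn

def replace_attention_stub(module: nn.Module) -> None:
    # Stand-in for replace_attention_with_custom_sdpa_attention in this file:
    # the real helper walks the module tree and swaps each attention block for
    # the custom-SDPA variant in place.
    pass

def export_sketch(model: nn.Module, example_input: torch.Tensor):
    # The swap is now unconditional (bfloat16 targets included).
    replace_attention_stub(model)
    # Restrict scaled_dot_product_attention to the MATH backend while tracing,
    # mirroring the context manager shown in the hunk above.
    with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
        return torch.export.export(model, (example_input,))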