Commit
enable float types in pytorch for non-compute comms
drisspg committed May 17, 2024
1 parent 6891cbe commit 02a80f9
Showing 2 changed files with 2 additions and 3 deletions.
3 changes: 0 additions & 3 deletions float8_experimental/float8_ops.py
@@ -235,11 +235,8 @@ def allgather_fp8(aten_op, args, kwargs=None):
     ), f"expecting a Float8Tensor for allgather but found {type(fp8_input)}"
 
     fp8_data = fp8_input._data
-    fp8_data = fp8_data.view(torch.uint8)
-    fp8_data = fp8_data.contiguous()
     fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
     fp8_out = torch.ops._c10d_functional.wait_tensor(fp8_out)
-    fp8_out = fp8_out.view(fp8_input._data.dtype)
     return Float8Tensor(
         fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
     )
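The three deleted lines were a workaround for collectives that rejected float8 dtypes: the fp8 payload was reinterpreted as uint8 before the allgather and viewed back to its fp8 dtype afterwards. With float types enabled for non-compute comms, the float8 tensor can be handed to the collective directly. A minimal sketch of the before/after pattern, assuming the private torch.ops._c10d_functional.all_gather_into_tensor(input, group_size, group_name) op from recent PyTorch; the function names and group arguments here are illustrative, not part of the commit:

import torch

def allgather_fp8_before(fp8_data: torch.Tensor, group_size: int, group_name: str) -> torch.Tensor:
    # Workaround: reinterpret the fp8 payload as uint8 so the collective
    # accepts it, gather, then view the result back to the fp8 dtype.
    as_bytes = fp8_data.view(torch.uint8).contiguous()
    out = torch.ops._c10d_functional.all_gather_into_tensor(as_bytes, group_size, group_name)
    out = torch.ops._c10d_functional.wait_tensor(out)
    return out.view(fp8_data.dtype)

def allgather_fp8_after(fp8_data: torch.Tensor, group_size: int, group_name: str) -> torch.Tensor:
    # Direct path: the collective now accepts float8 tensors as-is.
    out = torch.ops._c10d_functional.all_gather_into_tensor(fp8_data, group_size, group_name)
    return torch.ops._c10d_functional.wait_tensor(out)

The two versions differ only by the view round-trip; as the diff above shows, the scale and original dtype still travel on the wrapping Float8Tensor, so only the raw _data payload goes through the collective.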
2 changes: 2 additions & 0 deletions test/test_dtensor.py
@@ -246,3 +246,5 @@ def test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
         except Exception as e:
             print(f"Test {test.__name__} failed with error: {e}")
             raise e
+
+    torch.distributed.destroy_process_group()
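The added teardown closes the default process group when the test script finishes; exiting with an initialized group can leave communicator resources live and, in recent PyTorch versions, emit a warning. A minimal sketch of the matching init/teardown pairing, assuming a gloo backend and env:// rendezvous; the address, port, and world size are illustrative:

import os
import torch.distributed as dist

def main() -> None:
    # Single-process rendezvous for illustration; real runs get these
    # values from the launcher (e.g. torchrun).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    try:
        pass  # distributed test bodies would run here
    finally:
        # Mirror of the commit's addition: always release the process group.
        dist.destroy_process_group()

if __name__ == "__main__":
    main()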
