Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions test/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ function run_xla_op_tests1 {
run_test "$CDIR/dynamo/test_graph_input_matcher.py"
run_save_tensor_ir "$CDIR/dynamo/test_dynamo_graph_dump.py"
run_test "$CDIR/test_data_type.py"
run_test "$CDIR/test_fp8.py"
run_xla_ir_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/test_env_var_mapper.py"
run_xla_hlo_debug "$CDIR/stablehlo/test_stablehlo_save_load.py"
Expand Down
51 changes: 51 additions & 0 deletions test/test_fp8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import os
import re

import torch
import torch_xla
import unittest


class Fp8Test(unittest.TestCase):

def test_fp8(self):
device = torch_xla.device()
fp8_types = [torch.float8_e5m2]
for dtype in fp8_types:
t = torch.rand(2, 2).to(dtype)
xla_t = t.to(device)
torch_t = xla_t.cpu()
self.assertEqual(xla_t.dtype, dtype)
self.assertEqual(torch_t.dtype, dtype)
# Need to cast to float32 since allclose doesn't work with fp8.
self.assertTrue(
torch.allclose(t.to(torch.float32), torch_t.to(torch.float32)))

def test_fp8_matmul(self):
device = torch_xla.device()
fp8_types = [torch.float8_e5m2]
for dtype in fp8_types:
t = torch.rand(3, 2).to(dtype)
w = torch.rand(2, 5).to(dtype)
torch_matmul = torch.matmul(t, w)
xla_t = t.to(device)
xla_w = w.to(device)
xla_matmul = torch.matmul(xla_t, xla_w)
xla_matmul = xla_matmul.cpu()
# Need to cast to float32 since allclose doesn't work with fp8.
self.assertTrue(
torch.allclose(
xla_matmul.to(torch.float32), torch_matmul.to(torch.float32)))

def test_fp8_hlo(self):
device = torch_xla.device()
x = torch.randn((3, 5)).to(torch.float8_e5m2).to(device)
w = torch.randn((5, 8)).to(torch.float8_e5m2).to(device)
output = torch.matmul(x, w)
hlo = torch_xla._XLAC._get_xla_tensors_hlo([output])
self.assertTrue(re.search(r'f8e5m2.*dot.*f8e5m2.*f8e5m2', hlo) is not None)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
1 change: 1 addition & 0 deletions test/tpu/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ python3 test/spmd/test_fsdp_v2.py
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shape_models.py -v
XLA_EXPERIMENTAL=nonzero:masked_select:nms python3 test/ds/test_dynamic_shapes.py -v
python3 test/test_autocast.py
python3 test/test_fp8.py
python3 test/test_grad_checkpoint.py
python3 test/dynamo/test_dynamo.py
python3 test/dynamo/test_dynamo_dynamic_shape.py
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/dtype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ at::ScalarType TorchTypeFromXlaType(xla::PrimitiveType xla_type) {
switch (xla_type) {
case xla::PrimitiveType::BF16:
return at::ScalarType::BFloat16;
case xla::PrimitiveType::F8E5M2:
return at::ScalarType::Float8_e5m2;
case xla::PrimitiveType::F16:
return at::ScalarType::Half;
case xla::PrimitiveType::F32:
Expand Down Expand Up @@ -49,6 +51,8 @@ xla::PrimitiveType XlaTypeFromTorchType(at::ScalarType scalar_type) {
return xla::PrimitiveType::BF16;
case at::ScalarType::Half:
return xla::PrimitiveType::F16;
case at::ScalarType::Float8_e5m2:
return xla::PrimitiveType::F8E5M2;
case at::ScalarType::Bool:
return xla::PrimitiveType::PRED;
case at::ScalarType::Byte:
Expand Down
51 changes: 51 additions & 0 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ struct Caster<tsl::bfloat16> {
return static_cast<D>(static_cast<float>(value));
}
};

template <>
struct Caster<at::Float8_e5m2> {
template <typename D>
D cast(const at::Float8_e5m2& value) const {
return static_cast<D>(static_cast<float>(value));
}
};

template <>
struct Caster<tsl::float8_e5m2> {
template <typename D>
D cast(const tsl::float8_e5m2& value) const {
return static_cast<D>(static_cast<float>(value));
}
};

template <>
struct Caster<at::Half> {
template <typename D>
Expand Down Expand Up @@ -185,6 +202,14 @@ struct NeedCast<at::BFloat16> {
static constexpr bool value = true;
};
template <>
struct NeedCast<tsl::float8_e5m2> {
static constexpr bool value = true;
};
template <>
struct NeedCast<at::Float8_e5m2> {
static constexpr bool value = true;
};
template <>
struct NeedCast<xla::half> {
static constexpr bool value = true;
};
Expand Down Expand Up @@ -248,6 +273,18 @@ void CopyData<tsl::bfloat16, at::BFloat16>(tsl::bfloat16* dest,
int64_t n, const CopyCasted&) {
CheckedMemcpy<tsl::bfloat16, at::BFloat16>(dest, source, n);
}
template <>
void CopyData<at::Float8_e5m2, tsl::float8_e5m2>(at::Float8_e5m2* dest,
const tsl::float8_e5m2* source,
int64_t n, const CopyCasted&) {
CheckedMemcpy<at::Float8_e5m2, tsl::float8_e5m2>(dest, source, n);
}
template <>
void CopyData<tsl::float8_e5m2, at::Float8_e5m2>(tsl::float8_e5m2* dest,
const at::Float8_e5m2* source,
int64_t n, const CopyCasted&) {
CheckedMemcpy<tsl::float8_e5m2, at::Float8_e5m2>(dest, source, n);
}

std::vector<int64_t> GetIterationDimensions(const xla::Shape& shape) {
// We want to favor the most minor dimension as core iteration dimension, as
Expand Down Expand Up @@ -414,6 +451,10 @@ void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
TensorToBuffer<SType, double>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F8E5M2:
TensorToBuffer<SType, tsl::float8_e5m2>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::PRED:
TensorToBuffer<SType, bool>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
Expand Down Expand Up @@ -537,6 +578,9 @@ at::Tensor XlaLiteralToTensorHelper(const xla::Literal& literal,
dest_element_type);
case at::ScalarType::Half:
return XlaLiteralToTensor<SType, at::Half>(literal, dest_element_type);
case at::ScalarType::Float8_e5m2:
return XlaLiteralToTensor<SType, at::Float8_e5m2>(literal,
dest_element_type);
case at::ScalarType::ComplexFloat:
return XlaLiteralToTensor<SType, c10::complex<float>>(literal,
dest_element_type);
Expand Down Expand Up @@ -567,6 +611,10 @@ void PopulateTensorBuffer(const at::Tensor& tensor,
TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case at::ScalarType::Float8_e5m2:
TensorToBufferSType<at::Float8_e5m2>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case at::ScalarType::Half:
TensorToBufferSType<at::Half>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
Expand Down Expand Up @@ -626,6 +674,9 @@ at::Tensor MakeTensorFromXlaLiteral(const xla::Literal& literal,
case xla::PrimitiveType::BF16:
return XlaLiteralToTensorHelper<tsl::bfloat16>(literal,
dest_element_type);
case xla::PrimitiveType::F8E5M2:
return XlaLiteralToTensorHelper<tsl::float8_e5m2>(literal,
dest_element_type);
case xla::PrimitiveType::F16:
return XlaLiteralToTensorHelper<xla::half>(literal, dest_element_type);
case xla::PrimitiveType::F32:
Expand Down