Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/pytorch_test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@
'test_random_from_to_xla_int32', # precision, TPU does not have real F64
'test_uniform_from_to_xla_float64', # float64 limit, TPU does not have real F64
'test_topk_integral_xla_int64', # (TPU) unimplemented HLO for X64
'test_float_to_int_conversion_finite_xla' # different behavior than numpy when casting float_max/min to int types
'test_float_to_int_conversion_finite_xla', # different behavior than numpy when casting float_max/min to int types

# test_indexing.py
# TestIndexing
Expand Down
80 changes: 47 additions & 33 deletions torch_xla/csrc/random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,60 +35,74 @@ xla::XlaOp MakeSeed(xla::XlaOp seed) {
return xla::ConvertElementType(seed, xla::PrimitiveType::U64);
}

// Converts a uniform-distribution boundary (min/max) into the element type the
// underlying RNG kernel operates in: BF16 boundaries are widened to F32, and
// complex boundaries are reduced to their real component. All other types pass
// through unchanged.
xla::XlaOp MakeUniformBoundaryValue(xla::XlaOp val) {
  const xla::PrimitiveType type = XlaHelpers::TypeOfXlaOp(val);
  if (xla::primitive_util::IsComplexType(type)) {
    // Real and imaginary parts are drawn independently; only the real
    // component of the boundary is used.
    return xla::Real(val);
  }
  // BF16 has no native RNG path; generate in F32 instead.
  return type == xla::PrimitiveType::BF16
             ? xla::ConvertElementType(val, xla::PrimitiveType::F32)
             : val;
}

// Returns the shape the RNG computation is performed in for a requested output
// shape: BF16 outputs are generated as F32, and complex outputs are generated
// component-wise in the corresponding real component type. Other element types
// keep the shape unchanged.
xla::Shape MakeRngShape(const xla::Shape& shape) {
  xla::Shape rng_shape(shape);
  const xla::PrimitiveType type = shape.element_type();
  if (xla::primitive_util::IsComplexType(type)) {
    rng_shape.set_element_type(
        xla::primitive_util::ComplexComponentType(type));
  } else if (type == xla::PrimitiveType::BF16) {
    rng_shape.set_element_type(xla::PrimitiveType::F32);
  }
  return rng_shape;
}

} // namespace

xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape,
xla::XlaOp minval, xla::XlaOp maxval) {
xla::XlaOp rng_seed = MakeSeed(seed);
xla::Shape rng_shape = MakeRngShape(shape);
xla::XlaOp rng_minval = MakeUniformBoundaryValue(minval);
xla::XlaOp rng_maxval = MakeUniformBoundaryValue(maxval);
xla::XlaOp initial_state =
xla::Zero(rng_seed.builder(), xla::PrimitiveType::U64);
switch (shape.element_type()) {
case xla::PrimitiveType::BF16: {
xla::Shape f32_shape(shape);
f32_shape.set_element_type(xla::PrimitiveType::F32);
xla::XlaOp f32_minval = MaybeConvertTo(minval, xla::PrimitiveType::F32);
xla::XlaOp f32_maxval = MaybeConvertTo(maxval, xla::PrimitiveType::F32);
xla::XlaOp rng = xla::UniformFloatingPointDistribution(
rng_seed, initial_state, GetBitGenerator(),
f32_minval, f32_maxval, f32_shape)
rng_minval, rng_maxval, rng_shape)
.value;
return xla::ConvertElementType(rng, xla::PrimitiveType::BF16);
}
case xla::PrimitiveType::F32:
case xla::PrimitiveType::F64:
return xla::UniformFloatingPointDistribution(rng_seed, initial_state,
GetBitGenerator(), minval,
maxval, shape)
return xla::UniformFloatingPointDistribution(
rng_seed, initial_state, GetBitGenerator(), rng_minval,
rng_maxval, rng_shape)
.value;
case xla::PrimitiveType::C64:
case xla::PrimitiveType::C128: {
xla::XlaOp k_seed = XlaHelpers::ScalarValue<xla::uint64>(
17, XlaHelpers::TypeOfXlaOp(rng_seed), rng_seed.builder());
xla::Shape rng_shape(shape);
rng_shape.set_element_type(xla::PrimitiveType::F32);
xla::XlaOp f32_minval = MaybeConvertTo(minval, xla::PrimitiveType::F32);
xla::XlaOp f32_maxval = MaybeConvertTo(maxval, xla::PrimitiveType::F32);
xla::XlaOp rng_real = xla::UniformFloatingPointDistribution(
rng_seed, initial_state, GetBitGenerator(),
f32_minval, f32_maxval, rng_shape)
rng_minval, rng_maxval, rng_shape)
.value;
xla::XlaOp rng_imag =
xla::UniformFloatingPointDistribution(
rng_seed * k_seed, initial_state, GetBitGenerator(), f32_minval,
f32_maxval, rng_shape)
rng_seed * k_seed, initial_state, GetBitGenerator(), rng_minval,
rng_maxval, rng_shape)
.value;
xla::XlaOp rng = xla::Complex(rng_real, rng_imag);
return shape.element_type() == xla::PrimitiveType::C64
? rng
: xla::ConvertElementType(rng, xla::PrimitiveType::C128);
return xla::Complex(rng_real, rng_imag);
}
case xla::PrimitiveType::S32:
case xla::PrimitiveType::U32:
case xla::PrimitiveType::S64:
case xla::PrimitiveType::U64:
return xla::UniformIntDistribution(rng_seed, initial_state,
GetBitGenerator(), minval, maxval,
shape)
GetBitGenerator(), rng_minval,
rng_maxval, rng_shape)
.value;
default:
XLA_ERROR() << "RngUniform not implemented for type "
Expand All @@ -100,34 +114,32 @@ xla::XlaOp RngUniform(xla::XlaOp seed, const xla::Shape& shape,
xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean,
xla::XlaOp std) {
xla::XlaOp rng_seed = MakeSeed(seed);
xla::Shape rng_shape = MakeRngShape(shape);
xla::XlaOp initial_state =
xla::Zero(rng_seed.builder(), xla::PrimitiveType::U64);
switch (shape.element_type()) {
case xla::PrimitiveType::BF16: {
xla::Shape f32_shape(shape);
f32_shape.set_element_type(xla::PrimitiveType::F32);
xla::XlaOp f32_mean = MaybeConvertTo(mean, xla::PrimitiveType::F32);
xla::XlaOp f32_std = MaybeConvertTo(std, xla::PrimitiveType::F32);
xla::XlaOp rng =
xla::NormalFloatingPointDistribution(rng_seed, initial_state,
GetBitGenerator(), f32_shape)
GetBitGenerator(), rng_shape)
.value;
return xla::ConvertElementType(f32_mean + rng * f32_std,
xla::PrimitiveType::BF16);
}
case xla::PrimitiveType::F32:
case xla::PrimitiveType::F64: {
xla::XlaOp rng = xla::NormalFloatingPointDistribution(
rng_seed, initial_state, GetBitGenerator(), shape)
.value;
xla::XlaOp rng =
xla::NormalFloatingPointDistribution(rng_seed, initial_state,
GetBitGenerator(), rng_shape)
.value;
return XlaHelpers::PromotedAdd(mean, XlaHelpers::PromotedMul(rng, std));
}
case xla::PrimitiveType::C64:
case xla::PrimitiveType::C128: {
xla::XlaOp k_seed = XlaHelpers::ScalarValue<xla::uint64>(
17, XlaHelpers::TypeOfXlaOp(rng_seed), rng_seed.builder());
xla::Shape rng_shape(shape);
rng_shape.set_element_type(xla::PrimitiveType::F32);
xla::XlaOp rng_real =
xla::NormalFloatingPointDistribution(rng_seed, initial_state,
GetBitGenerator(), rng_shape)
Expand All @@ -137,10 +149,12 @@ xla::XlaOp RngNormal(xla::XlaOp seed, const xla::Shape& shape, xla::XlaOp mean,
GetBitGenerator(), rng_shape)
.value;
xla::XlaOp rng = xla::Complex(rng_real, rng_imag);
if (shape.element_type() == xla::PrimitiveType::C128) {
rng = xla::ConvertElementType(rng, xla::PrimitiveType::C128);
}
return XlaHelpers::PromotedAdd(mean, XlaHelpers::PromotedMul(rng, std));
// Variance for normal distribution of the real and imaginary values is
// half of the input variance.
xla::XlaOp sqrtTwo = XlaHelpers::ScalarValue(
std::sqrt(2), XlaHelpers::TypeOfXlaOp(std), rng_seed.builder());
return XlaHelpers::PromotedAdd(
mean, XlaHelpers::PromotedMul(rng, std / sqrtTwo));
}
default:
XLA_ERROR() << "RngNormal not implemented for type "
Expand Down