Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Cleanup forward compatibility of random ops #47896

Merged
merged 1 commit into from
Apr 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 31 additions & 102 deletions tensorflow/compiler/tests/stateful_random_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from tensorflow.compiler.tests import xla_test
from tensorflow.python.client import device_lib
from tensorflow.python.compat import compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import dtypes
Expand Down Expand Up @@ -156,98 +155,29 @@ def testPhilox4x32(self):
[0xa4093822, 0x299f31d0],
[0xd16cfe09, 0x94fdcceb, 0x5001e420, 0x24126ea1])

def testNewStateThreeFry(self):
"""Tests that the RNG state after generation is correct (for ThreeFry).

ThreeFry state is a [counter, key] pair. The expected counter advances
below encode how many counter increments a draw of `size` outputs
consumes for each dtype.
"""
# Old-codepath-only test: skipped once the forward-compat window opens,
# because the new kernels advance the state differently (checked by
# testXLAEqualsCPU instead).
if compat.forward_compatible(2020, 10, 25):
self.skipTest("The expected values in this test is inconsistent with "
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
"new states for the new version.")
with ops.device(xla_device_name()):
counter = 57
key = 0x1234
size = 46
state = [counter, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_THREEFRY)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
# uint32 draws: counter advances by ceil(size/2) — presumably two
# 32-bit outputs per counter increment (TODO confirm against kernel).
self.assertAllEqual([counter+(size+1)//2, key], gen.state.read_value())
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
# uint64 draws: counter advances by exactly `size`.
self.assertAllEqual([counter+size, key], gen.state.read_value())

def testNewStatePhilox(self):
"""Tests that the RNG state after generation is correct (for Philox).

Philox state is [counter_low, counter_high, key]. The expected counter
advances below encode how many counter increments a draw of `size`
outputs consumes for each dtype, and that carries propagate from
counter_low into counter_high on overflow.
"""
# Old-codepath-only test: skipped once the forward-compat window opens,
# because the new kernels advance the state differently (checked by
# testXLAEqualsCPU instead).
if compat.forward_compatible(2020, 10, 25):
self.skipTest("The expected values in this test is inconsistent with "
"CPU/GPU. testXLAEqualsCPU has the correct checks of the "
"new states for the new version.")
with ops.device(xla_device_name()):
counter_low = 57
counter_high = 283
key = 0x1234
size = 47
state = [counter_low, counter_high, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
# uint32 draws: counter_low advances by ceil(size/4); counter_high
# is untouched when no carry occurs.
self.assertAllEqual([counter_low+(size+3)//4, counter_high, key],
gen.state.read_value())
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
# uint64 draws: counter_low advances by ceil(size/2).
self.assertAllEqual([counter_low+(size+1)//2, counter_high, key],
gen.state.read_value())
# Tests that a large counter_low correctly overflows into counter_high.
counter_low = -1  # same as 0xffffffffffffffff
counter_high = 283
size = 47
state = [counter_low, counter_high, key]
gen = random.Generator(state=state, alg=random.RNG_ALG_PHILOX)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint32)
# counter_low wraps past 2**64, so counter_high gains the carry (+1).
self.assertAllEqual([(size+3)//4-1, counter_high+1, key],
gen.state.read_value())
gen.reset(state)
gen.uniform_full_int(shape=(size,), dtype=dtypes.uint64)
self.assertAllEqual([(size+1)//2-1, counter_high+1, key],
gen.state.read_value())

@parameterized.parameters(INTS)
def testXLAEqualsCPU(self, dtype):
"""Tests that XLA and CPU kernels generate the same integers."""
seed = 1234
shape = [315, 49]
if compat.forward_compatible(2020, 10, 25):
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the state after
# number-generation are the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)
else:
# The old version doesn't guarantee that CPU and XLA are in the same state
# after number-generation, which is a bug.
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the state after
# number-generation are the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
cpu = (
random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
shape=shape, dtype=dtype))
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla = (
random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX).uniform_full_int(
shape=shape, dtype=dtype))
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)

def testXLAEqualsCPUAroundCounterOverflow(self):
"""Tests XLA and CPU kernels generate the same integers in overflow case.
Expand All @@ -258,26 +188,25 @@ def testXLAEqualsCPUAroundCounterOverflow(self):
dtype = dtypes.uint64
seed = 2**64 - 10
shape = [315, 49]
if compat.forward_compatible(2020, 10, 25):
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the state after
# number-generation are the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
cpu_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla_gen = random.Generator.from_seed(
seed=seed, alg=random.RNG_ALG_PHILOX)
# Repeat multiple times to make sure that the state after
# number-generation are the same between CPU and XLA.
for _ in range(5):
with ops.device("/device:CPU:0"):
# Test both number-generation and skip
cpu = cpu_gen.uniform_full_int(shape=shape, dtype=dtype)
cpu_gen.skip(100)
with ops.device(xla_device_name()):
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)
xla = xla_gen.uniform_full_int(shape=shape, dtype=dtype)
xla_gen.skip(100)
self.assertAllEqual(cpu, xla)
self.assertAllEqual(cpu_gen.state, xla_gen.state)
self.assertAllEqual(cpu, xla)

def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
Expand Down
109 changes: 42 additions & 67 deletions tensorflow/python/ops/stateful_random_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import six

from tensorflow.python.compat import compat
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import values_util
from tensorflow.python.eager import context
Expand Down Expand Up @@ -541,12 +540,9 @@ def algorithm(self):
return self._alg

def _standard_normal(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_normal_v2(
shape, key=key, counter=counter, dtype=dtype, alg=self.algorithm)
return gen_stateful_random_ops.stateful_standard_normal_v2(
self.state.handle, self.algorithm, shape, dtype=dtype)

@property
def key(self):
Expand All @@ -571,8 +567,12 @@ def key(self):
else:
raise ValueError("Unsupported algorithm id: %s" % alg)

# TODO(wangpeng): Add "Returns" section to docstring once new version kicks in
# pylint: disable=g-doc-return-or-yield
def _skip_single_var(self, var, delta):
"""Advances the counter stored in `var` by `delta`, returning the old value.

Delegates to the fused `RngReadAndSkip` op, which reads the current state
and skips in one step. `alg` is cast to int32 and `delta` to uint64 as
required by the op's signature.
"""
# TODO(wangpeng): Cache the cast algorithm instead of casting every time.
return gen_stateful_random_ops.rng_read_and_skip(
var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
delta=math_ops.cast(delta, dtypes.uint64))

def skip(self, delta):
"""Advance the counter of a counter-based RNG.

Expand All @@ -581,21 +581,10 @@ def skip(self, delta):
`skip(n)` will be the same as that after `normal([n])`
(or any other distribution). The actual increment added to the
counter is an unspecified implementation detail.
"""
if compat.forward_compatible(2020, 10, 25):
return self._skip(delta)
gen_stateful_random_ops.rng_skip(
self.state.handle, math_ops.cast(self.algorithm, dtypes.int64),
math_ops.cast(delta, dtypes.int64))
# pylint: enable=g-doc-return-or-yield

def _skip_single_var(self, var, delta):
# TODO(wangpeng): Cache the cast algorithm instead of casting everytime.
return gen_stateful_random_ops.rng_read_and_skip(
var.handle, alg=math_ops.cast(self.algorithm, dtypes.int32),
delta=math_ops.cast(delta, dtypes.uint64))

def _skip(self, delta):
Returns:
A `Tensor` of type `int64`.
"""
def update_fn(v):
return self._skip_single_var(v, delta)
# TODO(b/170515001): Always call strategy.extended.update after calling it
Expand Down Expand Up @@ -666,16 +655,13 @@ def normal(self, shape, mean=0.0, stddev=1.0, dtype=dtypes.float32,
return math_ops.add(rnd * stddev, mean, name=name)

def _truncated_normal(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)
return gen_stateful_random_ops.stateful_truncated_normal(
self.state.handle, self.algorithm, shape, dtype=dtype)
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_truncated_normal_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)

def truncated_normal(self, shape,
mean=0.0,
Expand Down Expand Up @@ -712,30 +698,23 @@ def truncated_normal(self, shape,
return math_ops.add(mul, mean_tensor, name=name)

def _uniform(self, shape, dtype):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)
return gen_stateful_random_ops.stateful_uniform(
self.state.handle, self.algorithm, shape=shape, dtype=dtype)
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm)

def _uniform_full_int(self, shape, dtype, name=None):
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm,
name=name)
return gen_stateful_random_ops.stateful_uniform_full_int(
self.state.handle, self.algorithm, shape=shape,
dtype=dtype, name=name)
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_full_int_v2(
shape=shape,
key=key,
counter=counter,
dtype=dtype,
alg=self.algorithm,
name=name)

def uniform(self, shape, minval=0, maxval=None,
dtype=dtypes.float32, name=None):
Expand Down Expand Up @@ -796,19 +775,15 @@ def uniform(self, shape, minval=0, maxval=None,
minval = ops.convert_to_tensor(minval, dtype=dtype, name="min")
maxval = ops.convert_to_tensor(maxval, dtype=dtype, name="max")
if dtype.is_integer:
if compat.forward_compatible(2020, 10, 25):
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
shape=shape,
key=key,
counter=counter,
minval=minval,
maxval=maxval,
alg=self.algorithm,
name=name)
return gen_stateful_random_ops.stateful_uniform_int(
self.state.handle, self.algorithm, shape=shape,
minval=minval, maxval=maxval, name=name)
key, counter = self._prepare_key_counter(shape)
return gen_stateless_random_ops_v2.stateless_random_uniform_int_v2(
shape=shape,
key=key,
counter=counter,
minval=minval,
maxval=maxval,
alg=self.algorithm,
name=name)
else:
rnd = self._uniform(shape=shape, dtype=dtype)
return math_ops.add(rnd * (maxval - minval), minval, name=name)
Expand Down
Loading