We're using `np.random` in a few places in our tests:
```
(pt) ➜ vision git:(master) ✗ git grep np.random
test/common_utils.py: np.random.seed(seed)
test/test_datasets.py: labels = np.random.randint(0, self._VERSION_CONFIG["num_categories"], size=num_images).tolist()
test/test_image.py: pixels = np.random.rand(*shape) > 0.5
test/test_image.py: pixels = np.random.rand(*shape) > 0.5
test/test_transforms.py: ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
test/test_transforms.py: ndarray = np.random.rand(height, width, channels).astype(np.float32)
test/test_transforms.py: trans(np.random.rand(1, height, width).tolist())
test/test_transforms.py: trans(np.random.rand(height))
test/test_transforms.py: trans(np.random.rand(1, 1, height, width))
test/test_transforms.py: np_arr = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
test/test_transforms.py: input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
test/test_transforms.py: input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
test/test_transforms.py: trans(np.random.rand(1, height, width).tolist())
test/test_transforms.py: trans(np.random.rand(1, height, width))
test/test_transforms.py: x_np = np.random.randint(0, 256, x_shape, np.uint8)
test/test_transforms.py: x_np = np.random.randint(0, 256, x_shape, np.uint8)
test/test_transforms_video.py: trans(np.random.rand(numFrames, height, width, 1).tolist())
```
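This uses the global NumPy RNG, which isn't ideal: its state is shared across the whole test session, so the data a test draws depends on which tests ran before it and whether they consumed or reseeded that state. We should instead instantiate a `RandomState` object in each test to confine the RNG to a single local object. For example, instead of calling `np.random.randint(...)`, we should do: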
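```python
import numpy as np
np_rng = np.random.RandomState(0)  # or whatever seed you might prefer
np_rng.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)  # instead of np.random.randint(...)
```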
The same applies to every other `np.random.xyz()` call. This local `np_rng` object only lives for the duration of the function call, so it neither leaks into nor alters the global NumPy RNG.
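A minimal sketch of the pattern applied to one of the call sites listed above (`x_np = np.random.randint(0, 256, x_shape, np.uint8)` in test/test_transforms.py). The helper names `make_uint8_input` and `global_state_equal` and the seed value are hypothetical, chosen only for illustration. The snippet checks two things: reusing the same seed reproduces the same data, and the global NumPy RNG state (read via `np.random.get_state()`) is identical before and after the local draws.

```python
import numpy as np


def make_uint8_input(x_shape, seed=0):
    # Hypothetical helper: a local RandomState replaces the global np.random.randint call.
    np_rng = np.random.RandomState(seed)
    return np_rng.randint(0, 256, x_shape, np.uint8)


def global_state_equal(a, b):
    # Legacy state tuple: ('MT19937', key_array, pos, has_gauss, cached_gaussian).
    return a[0] == b[0] and np.array_equal(a[1], b[1]) and a[2:] == b[2:]


before = np.random.get_state()  # snapshot of the global RNG (read only)
x1 = make_uint8_input((32, 32, 3))
x2 = make_uint8_input((32, 32, 3))
after = np.random.get_state()

print(np.array_equal(x1, x2))             # True: same seed, same data
print(global_state_equal(before, after))  # True: global RNG neither consumed nor reseeded
```

Because each call builds its own `RandomState`, the generated data does not depend on test ordering; the global state is only read here, never set.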
Regarding the `np.random.seed(seed)` line in `test/common_utils.set_rng_seed`, let's see if we can simply remove it. It shouldn't have any impact: I don't think NumPy's RNG is used in the tests that call `set_rng_seed()`.
CC @vmoens, would you like to work on this?
cc @pmeier