Use numpy RandomState objects in tests instead of drawing from numpy's global RNG #4247

@NicolasHug

We're using np.random in a few places in our tests:

(pt) ➜  vision git:(master) ✗ git grep np.random
test/common_utils.py:    np.random.seed(seed)
test/test_datasets.py:        labels = np.random.randint(0, self._VERSION_CONFIG["num_categories"], size=num_images).tolist()
test/test_image.py:        pixels = np.random.rand(*shape) > 0.5
test/test_image.py:        pixels = np.random.rand(*shape) > 0.5
test/test_transforms.py:        ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
test/test_transforms.py:        ndarray = np.random.rand(height, width, channels).astype(np.float32)
test/test_transforms.py:            trans(np.random.rand(1, height, width).tolist())
test/test_transforms.py:            trans(np.random.rand(height))
test/test_transforms.py:            trans(np.random.rand(1, 1, height, width))
test/test_transforms.py:        np_arr = np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8)
test/test_transforms.py:        input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
test/test_transforms.py:        input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32))
test/test_transforms.py:            trans(np.random.rand(1, height, width).tolist())
test/test_transforms.py:            trans(np.random.rand(1, height, width))
test/test_transforms.py:    x_np = np.random.randint(0, 256, x_shape, np.uint8)
test/test_transforms.py:    x_np = np.random.randint(0, 256, x_shape, np.uint8)
test/test_transforms_video.py:            trans(np.random.rand(numFrames, height, width, 1).tolist())

This uses the global numpy RNG, which isn't ideal. We should instead instantiate a RandomState object in each test so that the RNG is contained in a single local object. For example, instead of calling np.random.randint(...) we should do:

np_rng = np.random.RandomState(0)  # or whatever seed you might prefer
np_rng.randint(...)

The same applies to all calls to np.random.xyz(). The local np_rng object only lives for the duration of the function call, so it neither leaks into nor alters the global numpy RNG.
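
To make the "doesn't alter the global RNG" point concrete, here is a minimal sketch. The helper name make_test_image and its default shape are hypothetical and only for illustration; they are not taken from the test suite. It assumes nothing beyond numpy itself, and compares the global RNG state before and after drawing from a local RandomState:

```python
import numpy as np


def make_test_image(height=4, width=5, channels=3, seed=0):
    # Hypothetical helper: all randomness comes from a local RandomState,
    # so the global numpy RNG is never consulted.
    np_rng = np.random.RandomState(seed)
    return np_rng.randint(low=0, high=256, size=(height, width, channels)).astype(np.uint8)


# Snapshot the global RNG state, draw twice from local RNGs, snapshot again.
state_before = np.random.get_state()
img_a = make_test_image()
img_b = make_test_image()
state_after = np.random.get_state()

# Same seed gives the same data, independent of the global RNG.
same_output = np.array_equal(img_a, img_b)

# The global state tuple is (name, key array, pos, has_gauss, cached_gaussian).
global_untouched = (
    state_before[0] == state_after[0]
    and np.array_equal(state_before[1], state_after[1])
    and state_before[2:] == state_after[2:]
)
```

Under these assumptions, both same_output and global_untouched should be True: the test data is reproducible from the local seed alone, and whatever other code does with np.random before or after the test is unaffected.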

Regarding the np.random.seed(seed) line in test/common_utils.set_rng_seed, let's see if we can simply remove it. It shouldn't have any impact: I don't think numpy's RNG is being used in the tests that use set_rng_seed().

CC @vmoens, would you like to work on this?

cc @pmeier
