[DataLoader] Add generate_state for NumPy seeding (#56797)
Summary:
Pull Request resolved: #56797

After adding a default seeding strategy for the NumPy random module within each DataLoader worker in #56488, two concerns were raised:
- We dropped support for NumPy < 1.17, since `SeedSequence` is only available from 1.17 onward.
- To keep supporting seeding for NumPy < 1.17, how can we provide a seed for `numpy.random`?
  - The first option is to set the same seed as `random`. The problem is that `numpy.random` and `random` share the same underlying algorithm (Mersenne Twister), so with the same seed they produce exactly the same state sequence. Thanks to rkern, we noticed this so-called [bad thing](Lightning-AI/pytorch-lightning#6960 (comment)); see the sketch after this list.
  - Since most users are not aware of this problem, we instead provide a better default seed for `numpy.random`, computed with the same `SeedSequence` algorithm NumPy uses. This is just a workaround: a hard-coded function that generates an array of four int32 values to use as the seed.
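
For illustration, a minimal sketch of that first option (`naive_worker_seed` is a hypothetical helper, not what this PR implements):

```python
import random

import numpy as np


def naive_worker_seed(base_seed, worker_id):
    # Reuse one seed for both RNGs. Both `random` and legacy `numpy.random`
    # are built on the Mersenne Twister, so sharing a seed risks correlated
    # or identical streams.
    seed = base_seed + worker_id
    random.seed(seed)
    np.random.seed(seed % 2**32)  # legacy NumPy seeding only accepts 32-bit ints
```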

To cope with this problem more generally (many third-party libraries, not just NumPy, have their own random modules), we may eventually need to implement a `SeedSequence`-like object within the `torch.random` module, so that users can `spawn` a new `SeedSequence` for each library.
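
As a rough sketch of that direction, using NumPy's own `SeedSequence` as a stand-in (`worker_id` and `base_seed` are example values, and `other_ss` stands for an arbitrary third-party library):

```python
import numpy as np

worker_id, base_seed = 3, 13434589827475259383  # example values

# One parent sequence per worker; spawn an independent child for each
# library that keeps its own RNG state.
parent = np.random.SeedSequence(entropy=[worker_id, base_seed])
numpy_ss, other_ss = parent.spawn(2)
np.random.seed(numpy_ss.generate_state(4))       # seed NumPy's legacy RNG
other_seed = int(other_ss.generate_state(1)[0])  # 32-bit seed for another library
```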

Test Plan: Imported from OSS

Reviewed By: H-Huang

Differential Revision: D28000619

Pulled By: ejguan

fbshipit-source-id: 5701c8124a38ea5ded69eb8eee70f9680877ffa6
ejguan authored and facebook-github-bot committed Apr 27, 2021
1 parent 759cfb7 commit 3b977a0
Showing 2 changed files with 110 additions and 2 deletions.
28 changes: 28 additions & 0 deletions test/test_dataloader.py
@@ -1529,6 +1529,34 @@ def __len__(self):
        self.assertIsInstance(batch, torch.DoubleTensor)
        self.assertEqual(batch.size(), torch.Size([12, 2, 3, 4]))

    @unittest.skipIf(not TEST_NUMPY, "numpy unavailable")
    def test_numpy_gen_state(self):
        from torch.utils.data._utils.worker import _generate_state
        # Use NumPy-generated states as the reference to check that
        # `_generate_state` produces the same result.
        # Test case: ((worker_id, base_seed), expected_state)
        test_cases = [
            ((4, 13434589827475259383), (2884386318, 1088094898, 3523808998, 3860348662)),
            ((1, 15014285634777110771), (1934848465, 763213760, 2959016433, 179751970)),
            ((10, 978296274032934101), (1759791917, 3550927336, 1225977135, 1036538043)),
            ((12, 11868770762134256968), (3974661794, 3331131333, 3630387033, 2885815368)),
            ((9, 15378787925219019706), (3815056996, 3162224466, 2735102421, 3190253477)),
            ((5, 9055612723125076328), (3522565701, 3368424109, 959377806, 621878693)),
            ((15, 14617792358407278405), (3402479508, 1588702753, 1169536393, 3675067356)),
            ((9, 17363320784006640087), (957989458, 2518334477, 1421725660, 3086155459)),
            ((12, 480002904169484764), (2732851467, 1762620729, 4055801988, 1277640511)),
            ((15, 16803975943592702950), (3479415043, 4022359553, 295994005, 3358606349)),
            ((9, 11704776406047813044), (1968928009, 710113752, 2442656196, 1587420279)),
            ((10, 16357891985431864516), (1271733898, 4197047399, 3727213786, 2338547348)),
            ((2, 17423369006318065007), (544294336, 1911284083, 3299147734, 3231058347)),
            ((2, 2889492011444113593), (3721591783, 2595811276, 2212881745, 977682627)),
            ((0, 8979703111668486195), (4276723937, 2556068849, 2962827292, 233130238)),
            ((6, 6269787272229682235), (2548857855, 1216457374, 1012973562, 2999759647))
        ]

        for (worker_id, base_seed), exp in test_cases:
            self.assertEqual(exp, _generate_state(base_seed, worker_id))

    def test_error(self):
        self._test_error(self._get_data_loader(ErrorDataset(100), batch_size=2, shuffle=True))

84 changes: 82 additions & 2 deletions torch/utils/data/_utils/worker.py
@@ -119,6 +119,86 @@ class _IterableDatasetStopIteration(object):
class _ResumeIteration(object):
    pass

# The function `_generate_state` is adapted from `numpy.random.SeedSequence`
# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx
# It's MIT licensed; here is the copyright notice:

# Copyright (c) 2015 Melissa E. O'Neill
# Copyright (c) 2019 NumPy Developers
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# This function generates an array of int32 values to seed `numpy.random`,
# in order to prevent a state collision caused by the `numpy.random` and
# `random` modules sharing the same seed and algorithm.
# TODO: Implement `SeedSequence` like object for `torch.random`
def _generate_state(base_seed, worker_id):
    INIT_A = 0x43b0d7e5
    MULT_A = 0x931e8875
    INIT_B = 0x8b51f9dd
    MULT_B = 0x58f38ded
    MIX_MULT_L = 0xca01f9dd
    MIX_MULT_R = 0x4973f715
    XSHIFT = 4 * 8 // 2
    MASK32 = 0xFFFFFFFF

    entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0]
    pool = [0] * 4

    hash_const_A = INIT_A

    def hash(value):
        nonlocal hash_const_A
        value = (value ^ hash_const_A) & MASK32
        hash_const_A = (hash_const_A * MULT_A) & MASK32
        value = (value * hash_const_A) & MASK32
        value = (value ^ (value >> XSHIFT)) & MASK32
        return value

    def mix(x, y):
        result_x = (MIX_MULT_L * x) & MASK32
        result_y = (MIX_MULT_R * y) & MASK32
        result = (result_x - result_y) & MASK32
        result = (result ^ (result >> XSHIFT)) & MASK32
        return result

    # Add in the entropy to the pool.
    for i in range(len(pool)):
        pool[i] = hash(entropy[i])

    # Mix all bits together so late bits can affect earlier bits.
    for i_src in range(len(pool)):
        for i_dst in range(len(pool)):
            if i_src != i_dst:
                pool[i_dst] = mix(pool[i_dst], hash(pool[i_src]))

    # Hash the mixed pool into the final 4-word state, cycling a second
    # hash constant so each output word differs.
    hash_const_B = INIT_B
    state = []
    for i_dst in range(4):
        data_val = pool[i_dst]
        data_val = (data_val ^ hash_const_B) & MASK32
        hash_const_B = (hash_const_B * MULT_B) & MASK32
        data_val = (data_val * hash_const_B) & MASK32
        data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32
        state.append(data_val)
    return state
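
# For reference (assuming NumPy >= 1.17 is available), `_generate_state`
# reproduces what NumPy itself would generate, which is what the new unit
# test in test/test_dataloader.py checks against precomputed values:
#
#   ss = np.random.SeedSequence([worker_id, base_seed])
#   assert list(ss.generate_state(4)) == _generate_state(base_seed, worker_id)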

def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                 num_workers, persistent_workers):
@@ -138,9 +218,9 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
        random.seed(seed)
        torch.manual_seed(seed)
        if HAS_NUMPY:
+           np_seed = _generate_state(base_seed, worker_id)
            import numpy as np
-           ss = np.random.SeedSequence([worker_id, base_seed])
-           np.random.seed(ss.generate_state(4))
+           np.random.seed(np_seed)

        global _worker_info
        _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
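
For context, a usage sketch (not part of the diff): with this change, each worker's `numpy.random` is seeded automatically, so a dataset that draws NumPy random numbers gets a distinct stream per worker without a custom `worker_init_fn`. `NoiseDataset` is a hypothetical example:

```python
import numpy as np
from torch.utils.data import DataLoader, Dataset


class NoiseDataset(Dataset):  # hypothetical example dataset
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Each worker draws from its own, distinctly seeded NumPy RNG.
        return np.random.random()


if __name__ == "__main__":
    loader = DataLoader(NoiseDataset(), num_workers=2)
    print(list(loader))
```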
