[DataLoader] Add Numpy seeding to worker of DataLoader (#56488)
Summary:
Pull Request resolved: #56488

Given the number of requests for this feature, this PR introduces NumPy seeding as the default within each DataLoader worker.

## BC-breaking Note:
- By introducing a default `numpy.random` seeding strategy for DataLoader workers, users no longer need to seed workers manually via `worker_init_fn`. This PR does not affect users who already use `worker_init_fn` to set a customized seed for each worker (a sketch of that pattern follows this list).
- DataLoader now preserves reproducibility for users who call `numpy.random` inside a `Dataset`.
- Multiprocessing (without a `worker_init_fn` that seeds NumPy):
  - `spawn` start method: each worker now receives an explicit NumPy seed, instead of the seed NumPy generates at import time, which previously made the DataLoader lose reproducibility.
  - `fork` start method: each worker gets the same benefit as with `spawn`, and in addition receives a different NumPy seed by default instead of inheriting the same seed from the parent process.
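
For users who prefer explicit control, the existing `worker_init_fn` route still works. Below is a minimal sketch of that pattern; the helper name `seed_worker` is illustrative and not part of this PR, and it assumes NumPy is installed:

```py
import random

import numpy as np
import torch


def seed_worker(worker_id):
    # Derive a per-worker seed from the seed PyTorch already assigned to this
    # worker process, then seed NumPy and the stdlib RNG with it.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# Passed to the loader, e.g.:
# dl = DataLoader(ds, batch_size=2, num_workers=4, worker_init_fn=seed_worker)
```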

Using the following Dataset and script as an example:
```py
import multiprocessing as mp

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    def __getitem__(self, ind):
        item = [ind, np.random.randint(1, 10000)]
        return item

    def __len__(self):
        return 20


if __name__ == '__main__':
    ctx = mp.get_context('fork')
    ds = RandomDataset()
    g = torch.Generator()
    g.manual_seed(0)
    dl = DataLoader(ds, 2, shuffle=False, num_workers=4, multiprocessing_context=ctx, generator=g)

    epochs = 2
    for _ in range(epochs):
        for batch in dl:
            print(batch)
        print("====" * 10)
```

### 1.8.1:
All workers generate the same random results in each iteration, and the seeds are reset to the same values at every epoch.
```py
tensor([[   0, 7449],
        [   1, 1519]])
tensor([[   2, 7449],
        [   3, 1519]])
tensor([[   4, 9645],
        [   5, 2387]])
tensor([[   6, 9645],
        [   7, 2387]])
tensor([[   8, 3118],
        [   9, 4552]])
=========================
tensor([[   0, 7449],
        [   1, 1519]])
tensor([[   2, 7449],
        [   3, 1519]])
tensor([[   4, 9645],
        [   5, 2387]])
tensor([[   6, 9645],
        [   7, 2387]])
tensor([[   8, 3118],
        [   9, 4552]])
=========================
```

### This PR:
Each worker now has a different seed at the beginning and is re-seeded for each epoch.
```py
tensor([[   0, 8715],
        [   1, 5555]])
tensor([[   2, 6379],
        [   3, 1432]])
tensor([[   4, 3271],
        [   5, 5132]])
tensor([[   6, 4287],
        [   7, 1104]])
tensor([[   8, 8682],
        [   9, 1699]])
=========================
tensor([[   0, 1374],
        [   1,  996]])
tensor([[   2,  143],
        [   3, 3507]])
tensor([[   4, 5887],
        [   5, 4730]])
tensor([[   6, 7274],
        [   7,  738]])
tensor([[   8, 6374],
        [   9, 1572]])
=========================
```
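
For reference, here is a standalone sketch of the seeding scheme the worker loop now applies (see the `worker.py` diff below): each worker mixes its `worker_id` with the loader's `base_seed` through `numpy.random.SeedSequence`, so sibling workers, and successive epochs with a fresh base seed, get well-separated NumPy states. The loop and base seed value here are illustrative only:

```py
import numpy as np

base_seed = 0    # illustrative; the real base seed comes from the DataLoader's generator
for worker_id in range(4):
    # Mirror the worker loop: mix worker_id and base_seed, then expand the
    # sequence into four 32-bit words to seed the legacy global NumPy RNG.
    ss = np.random.SeedSequence([worker_id, base_seed])
    np.random.seed(ss.generate_state(4))
    print(worker_id, np.random.randint(1, 10000, size=2))
```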

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D27908486

Pulled By: ejguan

fbshipit-source-id: 5f313a30563bedeb88be214fa4beca0cefe9e4f4
ejguan authored and facebook-github-bot committed Apr 22, 2021
1 parent bc3d892 commit aec83ff
Showing 3 changed files with 16 additions and 4 deletions.
7 changes: 7 additions & 0 deletions torch/utils/data/_utils/__init__.py
```diff
@@ -35,6 +35,13 @@
 """
 
 
+try:
+    import numpy
+    HAS_NUMPY = True
+except ModuleNotFoundError:
+    HAS_NUMPY = False
+
+
 def _set_python_exit_flag():
     global python_exit_status
     python_exit_status = True
```
11 changes: 8 additions & 3 deletions torch/utils/data/_utils/worker.py
```diff
@@ -11,7 +11,7 @@
 from dataclasses import dataclass
 from torch._utils import ExceptionWrapper
 from typing import Union
-from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS
+from . import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS, HAS_NUMPY
 
 if IS_WINDOWS:
     import ctypes
@@ -104,7 +104,7 @@ def get_worker_info():
     set up each worker process differently, for instance, using ``worker_id``
     to configure the ``dataset`` object to only read a specific fraction of a
     sharded dataset, or use ``seed`` to seed other libraries used in dataset
-    code (e.g., NumPy).
+    code.
     """
     return _worker_info
 
@@ -120,7 +120,7 @@ class _ResumeIteration(object):
     pass
 
 def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
-                 auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
+                 auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                  num_workers, persistent_workers):
     # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
     # logic of this function.
@@ -134,8 +134,13 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
         signal_handling._set_worker_signal_handlers()
 
         torch.set_num_threads(1)
+        seed = base_seed + worker_id
         random.seed(seed)
         torch.manual_seed(seed)
+        if HAS_NUMPY:
+            import numpy as np
+            ss = np.random.SeedSequence([worker_id, base_seed])
+            np.random.seed(ss.generate_state(4))
 
         global _worker_info
         _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
```
2 changes: 1 addition & 1 deletion torch/utils/data/dataloader.py
```diff
@@ -903,7 +903,7 @@ def __init__(self, loader):
                 args=(self._dataset_kind, self._dataset, index_queue,
                       self._worker_result_queue, self._workers_done_event,
                       self._auto_collation, self._collate_fn, self._drop_last,
-                      self._base_seed + i, self._worker_init_fn, i, self._num_workers,
+                      self._base_seed, self._worker_init_fn, i, self._num_workers,
                       self._persistent_workers))
             w.daemon = True
             # NB: Process.start() actually take some time as it needs to
```
