Skip to content

Commit

Permalink
[MXNET-145] Remove the dependences of mx.io and mx.initializer on the…
Browse files Browse the repository at this point in the history
… numpy's global random number generator (apache#10260)

* Remove the dependence of mx.io on the global random number generator of numpy

* Remove the dependence of mx.initializer on the global random number generator of numpy
  • Loading branch information
asitstands authored and zheng-da committed Jun 28, 2018
1 parent 5a35106 commit 54ddd53
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/initializer.py
Expand Up @@ -530,9 +530,9 @@ def _init_weight(self, _, arr):
nout = arr.shape[0]
nin = np.prod(arr.shape[1:])
if self.rand_type == "uniform":
tmp = np.random.uniform(-1.0, 1.0, (nout, nin))
tmp = random.uniform(-1.0, 1.0, shape=(nout, nin)).asnumpy()
elif self.rand_type == "normal":
tmp = np.random.normal(0.0, 1.0, (nout, nin))
tmp = random.normal(0.0, 1.0, shape=(nout, nin)).asnumpy()
u, _, v = np.linalg.svd(tmp, full_matrices=False) # pylint: disable=invalid-name
if u.shape == tmp.shape:
res = u
Expand Down
8 changes: 6 additions & 2 deletions python/mxnet/io.py
Expand Up @@ -39,6 +39,8 @@
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate
from .ndarray import arange
from .ndarray.random import shuffle as random_shuffle

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
Expand Down Expand Up @@ -651,12 +653,14 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
" with `last_batch_handle` set to `discard`.")

self.idx = np.arange(self.data[0][1].shape[0])
# shuffle data
if shuffle:
np.random.shuffle(self.idx)
tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
self.data = _shuffle(self.data, self.idx)
self.label = _shuffle(self.label, self.idx)
else:
self.idx = np.arange(self.data[0][1].shape[0])

# batching
if last_batch_handle == 'discard':
Expand Down

0 comments on commit 54ddd53

Please sign in to comment.