[MXNET-514] Add clip_global_norm(row_sparse_grad). Fix row_sparse_param.save(). Fix trainer init_kvstore (apache#11266)

* clip sparse grad. fix _reduce for rowsparse param

* fix kvstore init for local kv

* trigger
eric-haibin-lin authored and zheng-da committed Jun 28, 2018
1 parent 0f5546d commit 3868af6
Showing 5 changed files with 66 additions and 49 deletions.
python/mxnet/gluon/parameter.py: 10 changes (6 additions & 4 deletions)
@@ -310,14 +310,16 @@ def _init_grad(self):
                                 self._grad, self.grad_req)
 
     def _reduce(self):
-        """Reduce data from multiple context."""
+        """Reduce data from multiple context to cpu."""
+        ctx = context.cpu()
         if self._stype == 'default':
             block = self.list_data()
-            data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+            data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
         else:
             # fetch all rows for 'row_sparse' param
-            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
-            data = self.row_sparse_data(all_row_ids)
+            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
+            data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
+            self._trainer._row_sparse_pull(self, data, all_row_ids)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
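
The effect of this change, as a minimal sketch (the names and the 10x10 shape are illustrative, mirroring test_paramdict in the tests below): saving a 'row_sparse' parameter now materializes the full weight by pulling every row through the attached trainer's kvstore, instead of calling row_sparse_data() on the parameter directly.

import mxnet as mx
from mxnet import gluon

# Saving a row_sparse parameter goes through Parameter._reduce(), which after
# this change pulls all rows from the trainer's kvstore into a fresh
# row_sparse NDArray on cpu before writing it out.
params = gluon.ParameterDict('net_')
params.get('w1', shape=(10, 10), stype='row_sparse')
params.initialize(ctx=mx.cpu())
trainer = gluon.Trainer(params, 'sgd')  # _reduce() relies on this attached trainer
params.save('net.params')
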
python/mxnet/gluon/trainer.py: 2 changes (1 addition & 1 deletion)
@@ -152,7 +152,6 @@ def _reset_kvstore(self):
 
     def _init_kvstore(self):
        """Create kvstore."""
-        arg_arrays = {}
         config = self._kvstore_params
         if self._contains_sparse:
             kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
@@ -162,6 +161,7 @@ def _init_kvstore(self):
                                    "gradients and/or sparse weights are present for "
                                    "Parameter '%s'."%param.name)
         else:
+            arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
             kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
                                                          arg_arrays)
         if config['update_on_kvstore'] is not None:
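
Why the arg_arrays dict moved into the dense branch, sketched under an assumption worth flagging: Gluon's Parameter.data() is disallowed for 'row_sparse' storage (the API directs you to row_sparse_data() instead), so the comprehension may only run when no sparse weights are present; passing the real parameter arrays rather than the previous hard-coded empty dict also gives _create_kvstore actual arrays to base its 'local' kvstore decisions on.

import mxnet as mx
from mxnet import gluon

# Sketch of the failure the move avoids (assumes data() raises RuntimeError
# for non-default storage types, as the Gluon Parameter API indicates).
params = gluon.ParameterDict()
w = params.get('w', shape=(10, 10), stype='row_sparse')
params.initialize(ctx=mx.cpu())
try:
    w.data(mx.cpu())  # disallowed for row_sparse; use row_sparse_data() instead
except RuntimeError as err:
    print('expected:', err)
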
python/mxnet/gluon/utils.py: 8 changes (6 additions & 2 deletions)
@@ -118,10 +118,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
 def clip_global_norm(arrays, max_norm):
     """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
     """
+    def _norm(array):
+        if array.stype == 'default':
+            x = array.reshape((-1,))
+            return ndarray.dot(x, x)
+        return array.norm().square()
     assert len(arrays) > 0
     ctx = arrays[0].context
-    total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx)
-                                 for x in (arr.reshape((-1,)) for arr in arrays)])
+    total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
     total_norm = ndarray.sqrt(total_norm).asscalar()
     if not np.isfinite(total_norm):
         warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'),
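
With the new _norm helper, clip_global_norm accepts any mix of dense and row_sparse NDArrays. The global norm is sqrt(sum of each array's squared 2-norm), so for a 3x3 and a 4x4 array of ones it is sqrt(9 + 16) = 5. A short usage sketch, with values taken from test_global_norm_clip below:

import mxnet as mx
from mxnet import gluon

x1 = mx.nd.ones((3, 3))                        # dense: squared 2-norm = 9
x2 = mx.nd.ones((4, 4)).tostype('row_sparse')  # sparse: squared 2-norm = 16
norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
assert norm == 5.0  # returns the pre-clipping global norm
# x1 and x2 were rescaled in place by roughly max_norm / norm = 1/5
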
tests/python/unittest/test_gluon.py: 47 changes (26 additions & 21 deletions)
@@ -90,15 +90,16 @@ def test_parameter_invalid_access():
 
 @with_seed()
 def test_paramdict():
+    ctx = mx.cpu(1)
     params0 = gluon.ParameterDict('net_')
     params0.get('w0', shape=(10, 10))
     params0.get('w1', shape=(10, 10), stype='row_sparse')
-    all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu())
+    all_row_ids = mx.nd.arange(0, 10, ctx=ctx)
     # check param names
     assert list(params0.keys()) == ['net_w0', 'net_w1']
-    params0.initialize(ctx=mx.cpu())
+    params0.initialize(ctx=ctx)
     trainer0 = mx.gluon.Trainer(params0, 'sgd')
-    prev_w0 = params0.get('w0').data(mx.cpu())
+    prev_w0 = params0.get('w0').data(ctx)
     prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
     # save params
     params0.save('test_paramdict.params')
@@ -107,11 +108,11 @@ def test_paramdict():
     params1 = gluon.ParameterDict('net_')
     params1.get('w0', shape=(10, 10))
     params1.get('w1', shape=(10, 10), stype='row_sparse')
-    params1.load('test_paramdict.params', mx.cpu())
+    params1.load('test_paramdict.params', ctx)
     trainer1 = mx.gluon.Trainer(params1, 'sgd')
 
     # compare the values before and after save/load
-    cur_w0 = params1.get('w0').data(mx.cpu())
+    cur_w0 = params1.get('w0').data(ctx)
     cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
     mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
     mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
@@ -121,11 +122,11 @@ def test_paramdict():
     params2 = gluon.ParameterDict('net_')
     params2.get('w0', shape=(10, 10))
     params2.get('w1', shape=(10, 10))
-    params2.load('test_paramdict.params', mx.cpu())
+    params2.load('test_paramdict.params', ctx)
 
     # compare the values before and after save/load
-    cur_w0 = params2.get('w0').data(mx.cpu())
-    cur_w1 = params2.get('w1').data(mx.cpu())
+    cur_w0 = params2.get('w0').data(ctx)
+    cur_w1 = params2.get('w1').data(ctx)
     mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
     mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
 
@@ -731,19 +732,23 @@ def test_sequential_warning():
 
 @with_seed()
 def test_global_norm_clip():
-    x1 = mx.nd.ones((3,3))
-    x2 = mx.nd.ones((4,4))
-    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
-    assert norm == 5.0
-    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
-    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
-
-    x3 = mx.nd.array([1.0, 2.0, float('nan')])
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        gluon.utils.clip_global_norm([x1, x3], 2.0)
-        assert len(w) == 1
+    stypes = ['default', 'row_sparse']
+    def check_global_norm_clip(stype):
+        x1 = mx.nd.ones((3,3)).tostype(stype)
+        x2 = mx.nd.ones((4,4)).tostype(stype)
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+        assert norm == 5.0
+        assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
+        assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+
+        x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            gluon.utils.clip_global_norm([x1, x3], 2.0)
+            assert len(w) == 1
+
+    for stype in stypes:
+        check_global_norm_clip(stype)
 
 @with_seed()
 def test_embedding():
tests/python/unittest/test_gluon_trainer.py: 48 changes (27 additions & 21 deletions)
@@ -177,24 +177,30 @@ def test_trainer_save_load():
 
 @with_seed()
 def test_trainer_reset_kv():
-    params = gluon.ParameterDict()
-    x = params.get('x', shape=(10,), lr_mult=1.0)
-    params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
-    params.save('test_trainer_reset_kv.params')
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    # load would reset kvstore
-    params.load('test_trainer_reset_kv.params')
-    assert trainer._kvstore is None
-    assert trainer._kv_initialized is False
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    # the updated parameter should be based on the loaded checkpoint
-    assert (x.data(mx.cpu()) == -0.2).asnumpy().all()
+    def check_trainer_reset_kv(kv):
+        params = gluon.ParameterDict()
+        x = params.get('x', shape=(10,), lr_mult=1.0)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        params.save('test_trainer_reset_kv.params')
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        assert trainer._kvstore.type == kv
+        # load would reset kvstore
+        params.load('test_trainer_reset_kv.params')
+        assert trainer._kvstore is None
+        assert trainer._kv_initialized is False
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        # the updated parameter should be based on the loaded checkpoint
+        assert (x.data(mx.cpu()) == -0.2).asnumpy().all()
+
+    kvs = ['local', 'device']
+    for kv in kvs:
+        check_trainer_reset_kv(kv)
