From 15a99daf6e454e4209f2171a3325a7eb5fc467d8 Mon Sep 17 00:00:00 2001 From: Hang Zhang <8041160+zhanghang1989@users.noreply.github.com> Date: Wed, 7 Mar 2018 17:44:24 -0800 Subject: [PATCH] Fix bug for Dropout with axes, also adding unit test (#10030) * fix bug * add test for dropout with axes --- src/operator/nn/dropout-inl.h | 2 +- tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index b57ab45891e9..1af4798d1cee 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -259,7 +259,7 @@ class DropoutOp { return; } // initialize the mask - LaunchRNG(s, pgen, out.Size(), + LaunchRNG(s, pgen, mask.Size(), mask.dptr(), this->pkeep_); // broadcast mul diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1ee14b6e5a41..91b8faa49c12 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4645,6 +4645,27 @@ def check_dropout_ratio(ratio, shape): exe.backward([mx.nd.ones(shape)], is_train=False) assert (exe.grad_arrays[0].asnumpy() == exe.outputs[0].asnumpy()).all() + def get_slice(x, axis, idx): + ix = () + for i in range(x.ndim): + if i == axis: + ix += (idx,) + else: + ix += (slice(None, None, None),) + return x[ix] + + def check_dropout_axes(ratio, shape, axes): + compactshape = list(shape) + for axis in axes: + compactshape[axis] = 1 + compactx = mx.random.uniform(shape=tuple(compactshape)) + broadcastx = compactx.broadcast_to(shape) + dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes) + for axis in axes: + target = get_slice(dropouty, axis, 0).asnumpy() + for i in range(1, shape[axis]): + assert(get_slice(dropouty, axis, i).asnumpy() == target).all() + shape = (100, 100) check_dropout_ratio(0.5, shape) check_dropout_ratio(0.0, shape) @@ -4652,6 +4673,21 @@ def check_dropout_ratio(ratio, shape): check_dropout_ratio(0.75, shape) check_dropout_ratio(0.25, shape) + nshape = (10, 10, 10, 10) + check_dropout_axes(0.25, nshape, axes = (0,)) + check_dropout_axes(0.25, nshape, axes = (1,)) + check_dropout_axes(0.25, nshape, axes = (2,)) + check_dropout_axes(0.25, nshape, axes = (3,)) + check_dropout_axes(0.25, nshape, axes = (0, 1)) + check_dropout_axes(0.25, nshape, axes = (0, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2)) + check_dropout_axes(0.25, nshape, axes = (1, 3)) + check_dropout_axes(0.25, nshape, axes = (2, 3)) + check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) + @with_seed() def test_scatter_gather_nd():