diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index aa59f13d06da..acfa358a796e 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -26,6 +26,7 @@
 #include "../softmax-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
+#include "../../tensor/broadcast_reduce_op.h"
 
 #if MXNET_USE_MKLDNN == 1
 namespace mxnet {
@@ -38,11 +39,13 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto input_mem = in_data.GetMKLDNNData();
   mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
   mkldnn::memory::desc data_md = data_mpd.desc();
+  int axis = CheckAxis(param.axis, in_data.shape().ndim());
+
   auto cpu_engine = data_mpd.get_engine();
   auto prop = ctx.is_train
     ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
   mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
-      data_md, param.axis);
+      data_md, axis);
   mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);
 
   auto output_memory = out_data.GetMKLDNNData();
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index f8cc6fee9a22..e9b104f12868 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -38,10 +38,8 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs) {
-  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
   // It seems MKLDNN softmax doesn't support training.
-  // and it only supports non-negative axis.
-  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && param.axis >= 0) {
+  if (SupportMKLDNN(inputs[0]) && !ctx.is_train) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index f287c1919636..674266934363 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4098,7 +4098,7 @@ def test_new_softmax():
     for ndim in range(1, 5):
         for _ in range(5):
            shape = np.random.randint(1, 5, size=ndim)
-           axis = np.random.randint(0, ndim)
+           axis = np.random.randint(-ndim, ndim)
            data = np.random.uniform(-2, 2, size=shape)
            sym = mx.sym.softmax(axis=axis)
            check_symbolic_forward(sym, [data], [np_softmax(data, axis=axis)])
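
The substance of the patch is the new CheckAxis call: CheckAxis (declared in broadcast_reduce_op.h, hence the added include) folds a negative axis into the range [0, ndim) before the value reaches the MKLDNN softmax primitive, which only accepts non-negative axes. That in turn lets SoftmaxComputeExCPU drop its `param.axis >= 0` guard, and the unit test widens its random axis to cover negative values. Below is a minimal Python sketch of the same normalization, checked against a numerically stable reference softmax; the `np_softmax` here is an illustrative stand-in, not the test file's exact helper.

import numpy as np

def check_axis(axis, ndim):
    # Mirrors mxnet's CheckAxis: accept axis in [-ndim, ndim) and
    # return the equivalent non-negative axis in [0, ndim).
    if not -ndim <= axis < ndim:
        raise ValueError("axis %d out of bounds for %d-d input" % (axis, ndim))
    return axis + ndim if axis < 0 else axis

def np_softmax(data, axis=-1):
    # Reference softmax along `axis`; subtracting the max first keeps
    # exp() from overflowing (illustrative stand-in for the test helper).
    shifted = data - np.max(data, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=axis, keepdims=True)

x = np.random.uniform(-2, 2, size=(2, 3, 4))
assert check_axis(-1, x.ndim) == 2      # -1 normalizes to the last axis
np.testing.assert_allclose(np_softmax(x, axis=-1), np_softmax(x, axis=2))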