[MXNET-92] Support float16 in L2Normalization operator (apache#10078)
* enable other dtype in l2 normalization

* Get rid of older code

* address code reviews: get rid of unnecessary checks

* address code reviews

* fix buggy InferType in L2Normalization

* address code review: change atol
haojin2 authored and piiswrong committed Mar 20, 2018
1 parent ae4dddc commit d240f9e
Showing 4 changed files with 109 additions and 68 deletions.
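For orientation, the operator computes out = data / sqrt(sum(data^2) + eps), with the reduction axes chosen by `mode`: 'instance' reduces over everything except the batch axis, 'channel' over axis 1, and 'spatial' over the remaining axes. The NumPy sketch below is an illustrative reference for that forward pass, not code from this commit; casting the float16 input up to float32 for the reference is done here purely for illustration.

```python
import numpy as np

def l2_normalize_ref(x, mode, eps=1e-10):
    # Reference forward pass: out = x / sqrt(sum(x^2 over axes) + eps)
    if mode == 'instance':            # reduce over all axes except the batch axis
        axes = tuple(range(1, x.ndim))
    elif mode == 'channel':           # reduce over the channel axis
        axes = (1,)
    elif mode == 'spatial':           # reduce over the spatial axes
        axes = tuple(range(2, x.ndim))
    else:
        raise ValueError('unexpected mode in l2 normalization')
    norm = np.sqrt(np.sum(np.square(x), axis=axes, keepdims=True) + eps)
    return x / norm

x = np.random.uniform(-1, 1, (4, 3, 5, 7)).astype(np.float16)
ref = l2_normalize_ref(x.astype(np.float32), 'channel').astype(np.float16)
```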
135 changes: 83 additions & 52 deletions src/operator/l2_normalization-inl.h
@@ -66,7 +66,7 @@ struct L2NormalizationParam : public dmlc::Parameter<L2NormalizationParam> {
* \brief This is the implementation of l2 normalization operator.
* \tparam xpu The device that the op will be executed on.
*/
template<typename xpu>
template<typename xpu, typename DType>
class L2NormalizationOp : public Operator {
public:
explicit L2NormalizationOp(L2NormalizationParam p) {
@@ -89,41 +89,53 @@ class L2NormalizationOp : public Operator {
if (param_.mode == l2_normalization::kInstance) {
Shape<2> dshape = Shape2(orig_shape[0],
orig_shape.ProdShape(1, orig_shape.ndim()));
Tensor<xpu, 2> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 2, real_t>(dshape, s);
Tensor<xpu, 2> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 2, real_t>(dshape, s);
Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
Tensor<xpu, 2, DType> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 2, DType>(dshape, s);
Tensor<xpu, 2, DType> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 2, DType>(dshape, s);
Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
norm = sumall_except_dim<0>(F<mxnet::op::mshadow_op::square>(data));
norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
s, norm.size(0), norm.dptr_, norm.dptr_, DType(param_.eps));
});
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast<0>(norm, out.shape_);
} else if (param_.mode == l2_normalization::kChannel) {
CHECK_GE(orig_shape.ndim(), 3U);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, real_t>(norm_shape, s);
Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, DType>(norm_shape, s);
norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 1);
norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
});
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
} else if (param_.mode == l2_normalization::kSpatial) {
CHECK_GE(orig_shape.ndim(), 3U);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> out = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, real_t>(norm_shape, s);
Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, DType>(norm_shape, s);
norm = reduce_with_axis<red::sum, false>(F<mxnet::op::mshadow_op::square>(data), 2);
norm = F<mxnet::op::mshadow_op::square_root>(norm + param_.eps);
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet::op::mshadow_op::plus, Req>, xpu>::Launch(
s, norm.size(0) * norm.size(1), norm.dptr_, norm.dptr_, DType(param_.eps));
});
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast_with_axis(norm, 1, dshape[2]);
} else {
LOG(FATAL) << "Unexpected mode in l2 normalization";
@@ -148,15 +160,15 @@ class L2NormalizationOp : public Operator {
if (param_.mode == l2_normalization::kInstance) {
Shape<2> dshape = Shape2(orig_shape[0],
orig_shape.ProdShape(1, orig_shape.ndim()));
Tensor<xpu, 2> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 2, real_t>(dshape, s);
Tensor<xpu, 2> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 2, real_t>(dshape, s);
Tensor<xpu, 2> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 2, real_t>(dshape, s);
Tensor<xpu, 1> norm = out_data[l2_normalization::kNorm].get<xpu, 1, real_t>(s);
Tensor<xpu, 1> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space<xpu>(mshadow::Shape1(data.shape_[0]), s);
Tensor<xpu, 2, DType> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 2, DType>(dshape, s);
Tensor<xpu, 2, DType> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 2, DType>(dshape, s);
Tensor<xpu, 2, DType> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 2, DType>(dshape, s);
Tensor<xpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<xpu, 1, DType>(s);
Tensor<xpu, 1, DType> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space_typed<xpu, 1, DType>(mshadow::Shape1(data.shape_[0]), s);
temp = sumall_except_dim<0>(grad_out * data);
Assign(grad_in, req[l2_normalization::kData],
(grad_out - data * broadcast<0>(temp, data.shape_)) /
@@ -165,17 +177,17 @@ class L2NormalizationOp : public Operator {
CHECK_GE(orig_shape.ndim(), 3U);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, real_t>(norm_shape, s);
Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, DType>(norm_shape, s);
Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[2]), s);
temp = reduce_with_axis<red::sum, false>(grad_out * data, 1);
Assign(grad_in, req[l2_normalization::kData],
(grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
@@ -184,17 +196,17 @@ class L2NormalizationOp : public Operator {
CHECK_GE(orig_shape.ndim(), 3U);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 3, real_t>(dshape, s);
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> grad_in = in_grad[l2_normalization::kData]
.get_with_shape<xpu, 3, DType>(dshape, s);
Tensor<xpu, 3, DType> grad_out = out_grad[l2_normalization::kOut]
.get_with_shape<xpu, 3, DType>(dshape, s);
Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
Tensor<xpu, 2> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, real_t>(norm_shape, s);
Tensor<xpu, 2> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space<xpu>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
Tensor<xpu, 2, DType> norm = out_data[l2_normalization::kNorm]
.get_with_shape<xpu, 2, DType>(norm_shape, s);
Tensor<xpu, 2, DType> temp = ctx.requested[l2_normalization::kTempSpace]
.get_space_typed<xpu, 2, DType>(mshadow::Shape2(data.shape_[0], data.shape_[1]), s);
temp = reduce_with_axis<red::sum, false>(grad_out * data, 2);
Assign(grad_in, req[l2_normalization::kData],
(grad_out - data * broadcast_with_axis(temp, 1, dshape[2])) /
@@ -210,7 +222,7 @@ class L2NormalizationOp : public Operator {

// Declare Factory function, used for dispatch specialization
template<typename xpu>
Operator* CreateOp(L2NormalizationParam param);
Operator* CreateOp(L2NormalizationParam param, int dtype);

#if DMLC_USE_CXX11
class L2NormalizationProp : public OperatorProperty {
@@ -235,6 +247,19 @@ class L2NormalizationProp : public OperatorProperty {
return param_.__DICT__();
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
int dtype = (*in_type)[0];
type_assign(&dtype, (*out_type)[0]);
type_assign(&dtype, (*out_type)[1]);

TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
return dtype != -1;
}

bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
@@ -294,7 +319,13 @@ class L2NormalizationProp : public OperatorProperty {
return {ResourceRequest::kTempSpace};
}

Operator* CreateOperator(Context ctx) const override;
Operator* CreateOperator(Context ctx) const override {
LOG(FATAL) << "Not Implemented.";
return NULL;
}

Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;

private:
L2NormalizationParam param_;
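The InferType override added above ties the data input, the normalized output, and the saved norm to a single dtype: `type_assign` propagates a known type into an unknown (-1) slot, and `TYPE_ASSIGN_CHECK` rejects conflicts. The Python sketch below only approximates that behavior for readers unfamiliar with the macros; it is not a transcription of the C++.

```python
def infer_l2norm_type(in_type, out_type):
    # Approximate semantics of the InferType above: one shared dtype for
    # data (in_type[0]), out (out_type[0]) and norm (out_type[1]); -1 = unknown.
    dtype = in_type[0]
    for t in (out_type[0], out_type[1]):
        if dtype == -1:
            dtype = t                 # adopt the first known dtype
        elif t != -1 and t != dtype:
            raise TypeError('incompatible dtypes for L2Normalization')
    in_type[0] = out_type[0] = out_type[1] = dtype
    return dtype != -1                # inference succeeds only once a dtype is known
```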
13 changes: 9 additions & 4 deletions src/operator/l2_normalization.cc
@@ -26,13 +26,18 @@
namespace mxnet {
namespace op {
template<>
Operator* CreateOp<cpu>(L2NormalizationParam param) {
return new L2NormalizationOp<cpu>(param);
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new L2NormalizationOp<cpu, DType>(param);
});
return op;
}

// DO_BIND_DISPATCH comes from static_operator_common.h
Operator* L2NormalizationProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}

DMLC_REGISTER_PARAMETER(L2NormalizationParam);
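With CreateOperatorEx dispatching on the inferred dtype, MSHADOW_REAL_TYPE_SWITCH instantiates the operator for float32, float64, and now float16 (mshadow's half type). A quick imperative sanity check, sketched under the assumption of the MXNet 1.x NDArray front end, which exposes the registered op as `mx.nd.L2Normalization`:

```python
import numpy as np
import mxnet as mx

# A float16 input should now run and keep its dtype through the operator.
x = mx.nd.array(np.random.uniform(-1, 1, (2, 3, 4)), dtype='float16')
y = mx.nd.L2Normalization(x, mode='instance', eps=1e-10)
print(y.dtype)  # expected: numpy.float16
```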
8 changes: 6 additions & 2 deletions src/operator/l2_normalization.cu
@@ -26,8 +26,12 @@
namespace mxnet {
namespace op {
template<>
Operator* CreateOp<gpu>(L2NormalizationParam param) {
return new L2NormalizationOp<gpu>(param);
Operator* CreateOp<gpu>(L2NormalizationParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new L2NormalizationOp<gpu, DType>(param);
});
return op;
}
} // namespace op
} // namespace mxnet
21 changes: 11 additions & 10 deletions tests/python/unittest/test_operator.py
@@ -2391,11 +2391,11 @@ def test_instance_normalization():
check_instance_norm_with_shape((3,3,2,3,2,1,1), default_context())


def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
def check_l2_normalization(in_shape, mode, dtype, norm_eps=1e-10):
ctx = default_context()
data = mx.symbol.Variable('data')
out = mx.symbol.L2Normalization(data=data, mode=mode, eps=norm_eps)
in_data = np.random.uniform(-1, 1, in_shape)
in_data = np.random.uniform(-1, 1, in_shape).astype(dtype)
# calculate numpy results
if mode == 'channel':
assert in_data.ndim > 2
@@ -2419,21 +2419,22 @@ def check_l2_normalization(in_shape, mode, norm_eps=1e-10):
exe = out.simple_bind(ctx=ctx, data=in_data.shape)
output = exe.forward(is_train=True, data=in_data)
# compare numpy + mxnet
assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-5)
assert_almost_equal(exe.outputs[0].asnumpy(), np_out, rtol=1e-2 if dtype == 'float16' else 1e-5, atol=1e-5)
# check gradient
check_numeric_gradient(out, [in_data], numeric_eps=1e-3, rtol=1e-2, atol=1e-3)


# TODO(szha): Seeding this masks failures. We need to do a deep dive for failures without this seed.
@with_seed(1234)
def test_l2_normalization():
for mode in ['channel', 'spatial', 'instance']:
for nbatch in [1, 4]:
for nchannel in [3, 5]:
for height in [4, 6]:
check_l2_normalization((nbatch, nchannel, height), mode)
for width in [5, 7]:
check_l2_normalization((nbatch, nchannel, height, width), mode)
for dtype in ['float16', 'float32', 'float64']:
for mode in ['channel', 'spatial', 'instance']:
for nbatch in [1, 4]:
for nchannel in [3, 5]:
for height in [4, 6]:
check_l2_normalization((nbatch, nchannel, height), mode, dtype)
for width in [5, 7]:
check_l2_normalization((nbatch, nchannel, height, width), mode, dtype)


def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3):
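For completeness, a usage sketch in the Symbol API style of the test above. Setting `dtype='float16'` on the variable is an assumption made here for illustration (the test binds by shape only), and the widened rtol=1e-2 reflects half precision's roughly three significant decimal digits.

```python
import numpy as np
import mxnet as mx

data = mx.sym.Variable('data', dtype='float16')
out = mx.sym.L2Normalization(data=data, mode='channel', eps=1e-10)

x = np.random.uniform(-1, 1, (4, 3, 5, 7)).astype(np.float16)
exe = out.simple_bind(ctx=mx.cpu(), data=x.shape)
y = exe.forward(is_train=False, data=x)[0].asnumpy()

# float32 reference for 'channel' mode, compared with the float16 tolerance
ref = x.astype(np.float32)
ref = ref / np.sqrt(np.sum(np.square(ref), axis=1, keepdims=True) + 1e-10)
assert np.allclose(y, ref, rtol=1e-2, atol=1e-5)
```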
