Skip to content

Commit

Permalink
support for any datatype for correlation operator (apache#10125)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and piiswrong committed Mar 16, 2018
1 parent 9293655 commit 1e0b1ec
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 36 deletions.
58 changes: 40 additions & 18 deletions src/operator/correlation-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct CorrelationParam : public dmlc::Parameter<CorrelationParam> {
.describe("operation type is either multiplication or subduction");
}
};
template<typename xpu>
template<typename xpu, typename DType>
class CorrelationOp : public Operator {
public:
explicit CorrelationOp(CorrelationParam param) {
Expand All @@ -79,14 +79,14 @@ class CorrelationOp : public Operator {
CHECK_EQ(in_data.size(), 2U);
CHECK_EQ(out_data.size(), 3U);
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> data1 = in_data[Correlation::kData1].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> data2 = in_data[Correlation::kData2].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> out = out_data[Correlation::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, real_t>(s);
tmp1 = 0.0f;
tmp2 = 0.0f;
out = 0.0f;
Tensor<xpu, 4, DType> data1 = in_data[Correlation::kData1].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> data2 = in_data[Correlation::kData2].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out = out_data[Correlation::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, DType>(s);
tmp1 = DType(0.0f);
tmp2 = DType(0.0f);
out = DType(0.0f);
CHECK_EQ(data1.CheckContiguous(), true);
CHECK_EQ(data2.CheckContiguous(), true);
CHECK_EQ(out.CheckContiguous(), true);
Expand Down Expand Up @@ -124,13 +124,13 @@ class CorrelationOp : public Operator {
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();
Tensor<xpu, 4> grad_data1 = in_grad[Correlation::kData1].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> grad_data2 = in_grad[Correlation::kData2].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> out_g = out_grad[Correlation::kOut].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, real_t>(s);
Tensor<xpu, 4> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, real_t>(s);
if (req[0] != kAddTo) grad_data1 = 0.0f;
if (req[1] != kAddTo) grad_data2 = 0.0f;
Tensor<xpu, 4, DType> grad_data1 = in_grad[Correlation::kData1].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> grad_data2 = in_grad[Correlation::kData2].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> out_g = out_grad[Correlation::kOut].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> tmp1 = out_data[Correlation::kTemp1].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> tmp2 = out_data[Correlation::kTemp2].get<xpu, 4, DType>(s);
if (req[0] != kAddTo) grad_data1 = DType(0.0f);
if (req[1] != kAddTo) grad_data2 = DType(0.0f);
CHECK_EQ(grad_data1.CheckContiguous(), true);
CHECK_EQ(grad_data2.CheckContiguous(), true);
CHECK_EQ(out_g.CheckContiguous(), true);
Expand Down Expand Up @@ -163,7 +163,7 @@ class CorrelationOp : public Operator {
}; // class CorrelationOp
// Decalre Factory function
template<typename xpu>
Operator* CreateOp(CorrelationParam param);
Operator* CreateOp(CorrelationParam param, int dtype);
#if DMLC_USE_CXX11
class CorrelationProp : public OperatorProperty {
public:
Expand Down Expand Up @@ -228,6 +228,22 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) overr
out_shape->push_back(Shape4(dshape1[0], paddedbottomheight, paddedbottomwidth, dshape1[1]));
return true;
}
bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
int dtype = (*in_type)[0];
type_assign(&(*in_type)[1], dtype);
type_assign(&(*out_type)[0], dtype);
type_assign(&(*out_type)[1], dtype);
type_assign(&(*out_type)[2], dtype);

TYPE_ASSIGN_CHECK(*in_type, 0, dtype);
TYPE_ASSIGN_CHECK(*in_type, 1, dtype);
TYPE_ASSIGN_CHECK(*out_type, 0, dtype);
TYPE_ASSIGN_CHECK(*out_type, 1, dtype);
TYPE_ASSIGN_CHECK(*out_type, 2, dtype);
return dtype != -1;
}
OperatorProperty* Copy() const override {
CorrelationProp* Correlation_sym = new CorrelationProp();
Correlation_sym->param_ = this->param_;
Expand All @@ -244,7 +260,13 @@ void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) overr
return {out_grad[Correlation::kOut],
out_data[Correlation::kTemp1], out_data[Correlation::kTemp2]};
}
Operator* CreateOperator(Context ctx) const override;
Operator* CreateOperator(Context ctx) const override {
LOG(FATAL) << "Not Implemented.";
return NULL;
}

Operator* CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const override;

private:
CorrelationParam param_;
Expand Down
13 changes: 9 additions & 4 deletions src/operator/correlation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,16 @@ inline void CorrelationBackward(const Tensor<cpu, 4, Dtype> &out_grad,
namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(CorrelationParam param) {
return new CorrelationOp<cpu>(param);
Operator *CreateOp<cpu>(CorrelationParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new CorrelationOp<cpu, DType>(param);
});
return op;
}
Operator* CorrelationProp::CreateOperator(Context ctx) const {
DO_BIND_DISPATCH(CreateOp, param_);
Operator* CorrelationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
std::vector<int> *in_type) const {
DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
}
DMLC_REGISTER_PARAMETER(CorrelationParam);
MXNET_REGISTER_OP_PROPERTY(Correlation, CorrelationProp)
Expand Down
8 changes: 6 additions & 2 deletions src/operator/correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,12 @@ inline void CorrelationBackward(const Tensor<gpu, 4, Dtype> &out_grad,
namespace mxnet {
namespace op {
template<>
Operator* CreateOp<gpu>(CorrelationParam param) {
return new CorrelationOp<gpu>(param);
Operator* CreateOp<gpu>(CorrelationParam param, int dtype) {
Operator* op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new CorrelationOp<gpu, DType>(param);
});
return op;
}
} // namespace op
} // namespace mxnet
25 changes: 13 additions & 12 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2161,12 +2161,12 @@ def correlation_backward(out_grad,tmp1,tmp2,data1,data2,pad_size,kernel_size,str
return tmp1_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]],tmp2_grad[:,:,pad_size:pad_size+data1.shape[2],pad_size:pad_size+data1.shape[3]],


def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply):
def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply,dtype):

img1 = np.random.random(data_shape)
img1 = img1.astype(np.float32)
img1 = img1.astype(dtype)
img2 = np.random.random(data_shape)
img2 = img2.astype(np.float32)
img2 = img2.astype(dtype)

net1 = get_correlation(img1,img2,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply)
net2 = get_correlation(img1,img2,kernel_size,max_displacement,stride1,stride2,pad_size,is_multiply )
Expand Down Expand Up @@ -2198,15 +2198,16 @@ def unittest_correlation(data_shape,kernel_size,max_displacement,stride1,stride2

@with_seed()
def test_correlation():
unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 10,stride1 = 1,stride2 = 2,pad_size = 10,is_multiply = True)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = True)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = True)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False)
unittest_correlation((5,1,6,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False)
unittest_correlation((5,1,11,11), kernel_size = 5,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = False)
for dtype in ['float16', 'float32', 'float64']:
unittest_correlation((1,3,10,10), kernel_size = 1,max_displacement = 4,stride1 = 1,stride2 = 1,pad_size = 4,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 5,stride1 = 1,stride2 = 1,pad_size = 5,is_multiply = True, dtype = dtype)
unittest_correlation((5,1,15,15), kernel_size = 1,max_displacement = 10,stride1 = 1,stride2 = 2,pad_size = 10,is_multiply = True, dtype = dtype)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = True, dtype = dtype)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = True, dtype = dtype)
unittest_correlation((5,1,4,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,6,4), kernel_size = 3,max_displacement = 1,stride1 = 2,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)
unittest_correlation((5,1,11,11), kernel_size = 5,max_displacement = 1,stride1 = 1,stride2 = 1,pad_size = 2,is_multiply = False, dtype = dtype)


@with_seed()
Expand Down

0 comments on commit 1e0b1ec

Please sign in to comment.