Skip to content

Commit

Permalink
Update locally_connected_op (pytorch#2113)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaomengy committed Mar 2, 2018
1 parent ec3c299 commit f76fc6f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 108 deletions.
41 changes: 8 additions & 33 deletions caffe2/operators/locally_connected_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "caffe2/core/operator.h"
#include "caffe2/operators/conv_op_shared.h"
#include "caffe2/operators/conv_pool_op_base.h"
#include "caffe2/operators/locally_connected_op_util.h"

namespace caffe2 {

Expand All @@ -46,21 +47,8 @@ class LocallyConnectedOp final : public ConvPoolOpBase<Context> {
bool RunOnDeviceWithOrderNHWC() override;

private:
struct ShapeParams {
int N;
int C;
int M;
int input_image_size;
int output_image_size;
int kernel_dim;
std::vector<int> input_image_dims;
std::vector<int> column_dims;
std::vector<int> column_transposed_dims;
std::vector<int> Y_transposed_dims;
};

void RunOnDeviceWithOrderNCHWImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
Expand All @@ -70,7 +58,7 @@ class LocallyConnectedOp final : public ConvPoolOpBase<Context> {
Tensor<Context>* output_buffer);

void RunOnDeviceWithOrderNHWCImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
Expand Down Expand Up @@ -99,7 +87,7 @@ class LocallyConnectedOp final : public ConvPoolOpBase<Context> {
Tensor<Context> Y_transposed_buffer_;

// Dims devices.
Tensor<Context> input_dims_device_;
Tensor<Context> X_dims_device_;
Tensor<Context> column_dims_device_;
Tensor<Context> column_transposed_dims_device_;
Tensor<Context> column_axes_device_;
Expand Down Expand Up @@ -134,21 +122,8 @@ class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
bool RunOnDeviceWithOrderNHWC() override;

private:
struct ShapeParams {
int N;
int C;
int M;
int input_image_size;
int output_image_size;
int kernel_dim;
std::vector<int> input_image_dims;
std::vector<int> column_dims;
std::vector<int> column_transposed_dims;
std::vector<int> dY_transposed_dims;
};

void RunOnDeviceWithOrderNCHWImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* dY_data,
Expand All @@ -160,7 +135,7 @@ class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
Tensor<Context>* dY_transposed_buffer);

void RunOnDeviceWithOrderNHWCImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* dY_data,
Expand All @@ -183,7 +158,7 @@ class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
const std::vector<int>& dY_dims,
std::vector<int>* dY_transposed_dims);

bool no_bias_;
const bool no_bias_;

Tensor<Context> bias_multiplier_;

Expand All @@ -193,7 +168,7 @@ class LocallyConnectedGradientOp final : public ConvPoolOpBase<Context> {
Tensor<Context> dY_transposed_buffer_;

// Dims devices.
Tensor<Context> input_dims_device_;
Tensor<Context> X_dims_device_;
Tensor<Context> column_dims_device_;
Tensor<Context> column_transposed_dims_device_;
Tensor<Context> column_axes_device_;
Expand Down
97 changes: 22 additions & 75 deletions caffe2/operators/locally_connected_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,65 +32,14 @@

namespace caffe2 {

namespace {

void SetColumnBufferShapeImpl(
const int N,
const int C,
const int kernel_dim,
const StorageOrder order,
const std::vector<int>& output_image_dims,
std::vector<int>* column_dims,
std::vector<int>* column_transposed_dims,
std::vector<int>* column_axes,
std::vector<int>* column_transposed_axes) {
const int n_column_dims = output_image_dims.size() + 2;
column_dims->resize(n_column_dims);
column_transposed_dims->resize(n_column_dims);
column_axes->resize(n_column_dims);
if (order == StorageOrder::NCHW) {
for (int i = 0; i < n_column_dims - 2; ++i) {
column_dims->at(i + 2) = output_image_dims[i];
column_transposed_dims->at(i) = output_image_dims[i];
column_axes->at(i) = i + 2;
}
column_dims->at(0) = N;
column_dims->at(1) = kernel_dim;
column_transposed_dims->at(n_column_dims - 2) = kernel_dim;
column_transposed_dims->at(n_column_dims - 1) = N;
column_axes->at(n_column_dims - 1) = 0;
column_axes->at(n_column_dims - 2) = 1;
} else {
for (int i = 0; i < n_column_dims - 2; ++i) {
column_dims->at(i + 1) = output_image_dims[i];
column_transposed_dims->at(i) = output_image_dims[i];
column_axes->at(i) = i + 1;
}
column_dims->at(0) = N;
column_dims->at(n_column_dims - 1) = kernel_dim;
column_transposed_dims->at(n_column_dims - 2) = N;
column_transposed_dims->at(n_column_dims - 1) = kernel_dim;
column_axes->at(n_column_dims - 2) = 0;
column_axes->at(n_column_dims - 1) = n_column_dims - 1;
}
if (column_transposed_axes != nullptr) {
column_transposed_axes->resize(n_column_dims);
for (int i = 0; i < n_column_dims; ++i) {
column_transposed_axes->at(column_axes->at(i)) = i;
}
}
}

} // namespace

template <typename T, class Context>
bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {
const auto& X = Input(INPUT);
const auto& filter = Input(FILTER);
auto* Y = Output(0);
const int image_ndim = X.ndim() - 2;
CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
ShapeParams shape;
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(1);
shape.M = filter.dim32(image_ndim);
Expand Down Expand Up @@ -123,7 +72,7 @@ bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHW() {

shape.input_image_dims = GetDims(X);
const std::vector<int> input_dims(X.dims().cbegin() + 1, X.dims().cend());
SetDeviceTensor(input_dims, &input_dims_device_);
SetDeviceTensor(input_dims, &X_dims_device_);
shape.kernel_dim = shape.C / group_ * kernel_dims_size;

const std::vector<int> Y_dims(Y->dims().cbegin(), Y->dims().cend());
Expand Down Expand Up @@ -176,7 +125,7 @@ bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {
"Only 2d locally connected op is supported for NHWC storage type.");
const int image_ndim = X.ndim() - 2;
CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
ShapeParams shape;
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(3);
shape.input_image_dims = {X.dim32(1), X.dim32(2)};
Expand Down Expand Up @@ -235,7 +184,7 @@ bool LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWC() {

template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
Expand Down Expand Up @@ -274,7 +223,7 @@ void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
} else {
math::Im2colNd<T, Context, StorageOrder::NCHW>(
X_data + group_id * input_stride,
input_dims_device_.template data<int>(),
X_dims_device_.template data<int>(),
column_dims_device_.template data<int>() + 1,
shape.C * shape.input_image_size,
column_stride,
Expand Down Expand Up @@ -339,7 +288,7 @@ void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(

template <typename T, class Context>
void LocallyConnectedOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* bias_data,
Expand Down Expand Up @@ -429,9 +378,8 @@ void LocallyConnectedOp<T, Context>::SetColumnBufferShape(
std::vector<int>* column_dims,
std::vector<int>* column_transposed_dims) {
std::vector<int> column_axes;
SetColumnBufferShapeImpl(
lc_op_util::SetColumnBufferShapeImpl(
N,
C,
kernel_dim,
order_,
output_image_dims,
Expand Down Expand Up @@ -484,7 +432,7 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
const int image_ndim = X.ndim() - 2;
CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());

ShapeParams shape;
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(1);
shape.M = filter.dim32(image_ndim);
Expand All @@ -506,8 +454,8 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
kernel_dims_size *= kernel_[i];
}

const std::vector<int> input_dims(X.dims().cbegin() + 1, X.dims().cend());
SetDeviceTensor(input_dims, &input_dims_device_);
const std::vector<int> X_dims(X.dims().cbegin() + 1, X.dims().cend());
SetDeviceTensor(X_dims, &X_dims_device_);
shape.kernel_dim = shape.C / group_ * kernel_dims_size;

const std::vector<int> dY_dims(dY.dims().cbegin(), dY.dims().cend());
Expand All @@ -518,7 +466,7 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHW() {
output_image_dims,
&shape.column_dims,
&shape.column_transposed_dims);
SetDYTranposedBufferShape(dY_dims, &shape.dY_transposed_dims);
SetDYTranposedBufferShape(dY_dims, &shape.Y_transposed_dims);

dfilter->ResizeLike(filter);
const T* X_data = X.template data<T>();
Expand Down Expand Up @@ -568,7 +516,7 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
"Only 2d locally connected op is supported for NHWC storage type.");
const int image_ndim = X.ndim() - 2;
CAFFE_ENFORCE_EQ(X.ndim() + image_ndim, filter.ndim());
ShapeParams shape;
lc_op_util::ShapeParams shape;
shape.N = X.dim32(0);
shape.C = X.dim32(3);
shape.input_image_dims = {X.dim32(1), X.dim32(2)};
Expand All @@ -594,7 +542,7 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {
output_image_dims,
&shape.column_dims,
&shape.column_transposed_dims);
SetDYTranposedBufferShape(dY_dims, &shape.dY_transposed_dims);
SetDYTranposedBufferShape(dY_dims, &shape.Y_transposed_dims);

dfilter->ResizeLike(filter);
const T* X_data = X.template data<T>();
Expand Down Expand Up @@ -634,7 +582,7 @@ bool LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWC() {

template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* dY_data,
Expand All @@ -648,7 +596,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
const int column_stride = shape.kernel_dim * shape.output_image_size;
column_buffer->Resize(shape.column_dims);
column_transposed_buffer->Resize(shape.column_transposed_dims);
dY_transposed_buffer->Resize(shape.dY_transposed_dims);
dY_transposed_buffer->Resize(shape.Y_transposed_dims);
T* column_buffer_data = column_buffer->template mutable_data<T>();
T* dY_transposed_buffer_data =
dY_transposed_buffer->template mutable_data<T>();
Expand Down Expand Up @@ -676,7 +624,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
} else {
math::Im2colNd<T, Context, StorageOrder::NCHW>(
X_data + group_id * input_stride,
input_dims_device_.template data<int>(),
X_dims_device_.template data<int>(),
column_dims_device_.template data<int>() + 1,
shape.C * shape.input_image_size,
column_stride,
Expand All @@ -703,7 +651,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
&context_);

math::Transpose(
shape.dY_transposed_dims.size(),
shape.Y_transposed_dims.size(),
dY_dims_device_.template data<int>(),
dY_transposed_dims_device_.template data<int>(),
dY_axes_device_.template data<int>(),
Expand Down Expand Up @@ -789,7 +737,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(
} else {
math::Col2imNd<T, Context, StorageOrder::NCHW>(
const_column_buffer_data + group_id * column_stride,
input_dims_device_.template data<int>(),
X_dims_device_.template data<int>(),
column_dims_device_.template data<int>() + 1,
shape.C * shape.input_image_size,
column_stride,
Expand All @@ -810,7 +758,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNCHWImpl(

template <typename T, class Context>
void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
const ShapeParams& shape,
const lc_op_util::ShapeParams& shape,
const T* X_data,
const T* filter_data,
const T* dY_data,
Expand All @@ -824,7 +772,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
const int column_stride = shape.kernel_dim * shape.output_image_size;
column_buffer->Resize(shape.column_dims);
column_transposed_buffer->Resize(shape.column_transposed_dims);
dY_transposed_buffer->Resize(shape.dY_transposed_dims);
dY_transposed_buffer->Resize(shape.Y_transposed_dims);
T* column_buffer_data = column_buffer->template mutable_data<T>();
T* dY_transposed_buffer_data =
dY_transposed_buffer->template mutable_data<T>();
Expand Down Expand Up @@ -858,7 +806,7 @@ void LocallyConnectedGradientOp<T, Context>::RunOnDeviceWithOrderNHWCImpl(
&context_);

math::Transpose(
shape.dY_transposed_dims.size(),
shape.Y_transposed_dims.size(),
dY_dims_device_.template data<int>(),
dY_transposed_dims_device_.template data<int>(),
dY_axes_device_.template data<int>(),
Expand Down Expand Up @@ -955,9 +903,8 @@ void LocallyConnectedGradientOp<T, Context>::SetColumnBufferShape(
std::vector<int>* column_transposed_dims) {
std::vector<int> column_axes;
std::vector<int> column_transposed_axes;
SetColumnBufferShapeImpl(
lc_op_util::SetColumnBufferShapeImpl(
N,
C,
kernel_dim,
order_,
output_image_dims,
Expand Down

0 comments on commit f76fc6f

Please sign in to comment.