Commit 398e8ba

ppwwyyxx authored and wat3rBro committed
Include two caffe2 ops in v1.4.0 (#31716)
* move AliasWithNameOp to caffe2/operators

  Summary: Pull Request resolved: #31281
  Reviewed By: houseroad
  Differential Revision: D19053453
  fbshipit-source-id: 350bfd5c001db9c17916dcae7ade8f56db1e9841

* move BatchPermutationOp to caffe2/operators

  Summary: Pull Request resolved: #31350
  Reviewed By: houseroad
  Differential Revision: D19053527
  fbshipit-source-id: 50d11f137d0f5c07e8ad899a3a84d56a042bbc32

Co-authored-by: wat3rBro <wangyanghan6@gmail.com>
1 parent 074b30c commit 398e8ba

12 files changed: 711 additions, 296 deletions
caffe2/operators/alias_with_name.cc

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
#include "caffe2/operators/alias_with_name.h"

namespace caffe2 {

REGISTER_CPU_OPERATOR(AliasWithName, AliasWithNameOp<CPUContext>);

OPERATOR_SCHEMA(AliasWithName)
    .NumInputs(1)
    .NumOutputs(1)
    .AllowInplace({{0, 0}})
    .IdenticalTypeAndShape()
    .SetDoc(R"DOC(
Similar to AliasOp, but stores the alias name as an operator argument.
)DOC")
    .Arg("name", "name of the alias")
    .Arg("is_backward", "whether to alias the forward or the backward tensor")
    .Input(0, "input", "Input tensor whose storage will be shared.")
    .Output(0, "output", "Tensor of same shape as input, sharing its storage.");

} // namespace caffe2

C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    AliasWithName,
    "_caffe2::AliasWithName(Tensor input, str name, bool is_backward = False) -> (Tensor output)",
    caffe2::AliasWithNameOp<caffe2::CPUContext>);
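For illustration, a minimal sketch of driving the newly registered CPU op from C++. This is not part of the commit: the blob names "X"/"Y" and the alias string "my_alias" are made up, and error handling is omitted.

#include <string>
#include <vector>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"
#include "caffe2/utils/proto_utils.h"

int main() {
  caffe2::Workspace ws;

  // Create and allocate an input blob (names are hypothetical).
  auto* t = BlobGetMutableTensor(ws.CreateBlob("X"), caffe2::CPU);
  t->Resize(2, 3);
  t->mutable_data<float>();

  // "name" is required by the schema above; "is_backward" defaults to false.
  auto def = caffe2::CreateOperatorDef(
      "AliasWithName",
      "",
      std::vector<std::string>{"X"},
      std::vector<std::string>{"Y"},
      std::vector<caffe2::Argument>{
          caffe2::MakeArgument<std::string>("name", "my_alias")});
  auto op = caffe2::CreateOperator(def, &ws);
  op->Run();

  // "Y" now shares storage with "X" (see AliasWithNameOp::RunOnDevice).
  return 0;
}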
caffe2/operators/alias_with_name.cu

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/alias_with_name.h"

namespace caffe2 {

REGISTER_CUDA_OPERATOR(AliasWithName, AliasWithNameOp<CUDAContext>);

} // namespace caffe2

C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(
    AliasWithName,
    caffe2::AliasWithNameOp<caffe2::CUDAContext>);

caffe2/operators/alias_with_name.h

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
#ifndef ALIAS_WITH_NAME_OP_H_
#define ALIAS_WITH_NAME_OP_H_

#include "caffe2/core/context.h"
#include "caffe2/core/export_caffe2_op_to_c10.h"
#include "caffe2/core/operator.h"

C10_DECLARE_EXPORT_CAFFE2_OP_TO_C10(AliasWithName)

namespace caffe2 {

template <class Context>
class AliasWithNameOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  template <class... Args>
  explicit AliasWithNameOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        name_(this->template GetSingleArgument<std::string>(
            "name",
            "invalid_name")),
        is_backward_(
            this->template GetSingleArgument<bool>("is_backward", false)) {
    CAFFE_ENFORCE(
        OperatorBase::HasArgument("name"), "You have to specify argument name");
  }

  bool RunOnDevice() override {
    auto& input = Input(0);
    CAFFE_ENFORCE_GE(input.numel(), 0, "Tensor is not initialized");

    // This doesn't work anymore as this is a "newstyle" operator:
    // OutputTensorAlias(0, input);

    OperatorBase::SetOutputTensor(0, input.Alias());
    return true;
  }

 protected:
  std::string name_;
  bool is_backward_;
};

} // namespace caffe2

#endif // ALIAS_WITH_NAME_OP_H_
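The interesting part of RunOnDevice is `OperatorBase::SetOutputTensor(0, input.Alias())`: `Tensor::Alias()` hands back a tensor that shares the input's storage instead of copying it, which is what makes this op a zero-copy alias. A standalone sketch of that contract (illustration only, assuming the caffe2 C++ tensor API):

#include <cassert>
#include <vector>

#include "caffe2/core/tensor.h"

void alias_shares_storage() {
  caffe2::Tensor t(std::vector<int64_t>{2, 3}, caffe2::CPU);
  float* data = t.mutable_data<float>();

  caffe2::Tensor view = t.Alias();
  // Same underlying buffer: writes through `t` are visible through `view`.
  assert(view.data<float>() == data);
}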
caffe2/operators/batch_permutation_op.cc

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
#include "caffe2/operators/batch_permutation_op.h"

#include <cstring>
#include <vector>

#ifdef CAFFE2_USE_MKLDNN
#include <caffe2/ideep/operators/operator_fallback_ideep.h>
#include <caffe2/ideep/utils/ideep_operator.h>
#endif

namespace caffe2 {

template <bool forwards>
void batch_permutation_loop(
    const int N,
    const int K,
    const float* src,
    const int* indices,
    float* dst) {
  long numBytes = K * sizeof(float);
  if (forwards) {
#ifdef _OPENMP
#if (_OPENMP >= 201307)
#pragma omp parallel for simd
#else
#pragma omp parallel for
#endif
#endif
    for (int n = 0; n < N; n++) {
      int origIdx = n * K;
      int permuteIdx = indices[n] * K;
      std::memcpy(dst + origIdx, src + permuteIdx, numBytes);
    }
  } else {
    std::vector<int> backward_indices(N);
    for (int i = 0; i < N; ++i) {
      backward_indices[indices[i]] = i;
    }
    for (int n = 0; n < N; n++) {
      int permuteIdx = n * K;
      int origIdx = backward_indices[n] * K;
      std::memcpy(dst + permuteIdx, src + origIdx, numBytes);
    }
  }
}

template <>
bool BatchPermutationOp<float, CPUContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(X.dim32(0), 0);
  batch_permutation_loop<true>(
      X.dim32(0),
      X.numel() / X.dim32(0),
      X.data<float>(),
      indices.data<int>(),
      Y->mutable_data<float>());
  return true;
}

template <>
bool BatchPermutationGradientOp<float, CPUContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);

  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(dY.dim32(0), 0);
  batch_permutation_loop<false>(
      dY.dim32(0),
      dY.numel() / dY.dim32(0),
      dY.data<float>(),
      indices.data<int>(),
      dX->mutable_data<float>());
  return true;
}

#ifdef CAFFE2_USE_MKLDNN
REGISTER_IDEEP_OPERATOR(
    BatchPermutation,
    IDEEPFallbackOp<BatchPermutationOp<float, CPUContext>>);
#endif

REGISTER_CPU_OPERATOR(BatchPermutation, BatchPermutationOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CPUContext>);

// Input: X, indices; Output: Y
OPERATOR_SCHEMA(BatchPermutation)
    .NumInputs(2)
    .NumOutputs(1)
    .SetDoc(R"DOC(
Batch permutation of an input tensor X given input indices. First dimension of
X equals batch size N. The indices input stores a permutation of the batch
indices 0, ..., N-1. The output Y is a tensor of the same shape as X, with data
re-ordered within the batch according to the indices.

Example of batch permutation on a 2-D tensor with batch size 4:
X = [
  [1, 5, 2, 3, 4, 6, 0],
  [4, 3, 3, 5, 2, 3, 1],
  [2, 2, 3, 6, 0, 0, 1],
  [0, 0, 1, 1, 2, 2, 3]
]
indices = [2, 0, 1, 3]
Y = [
  [2, 2, 3, 6, 0, 0, 1],
  [1, 5, 2, 3, 4, 6, 0],
  [4, 3, 3, 5, 2, 3, 1],
  [0, 0, 1, 1, 2, 2, 3]
]

Example of batch permutation on a 3-D tensor with batch size 4:
X = [
  [[1, 5, 2], [3, 4, 6, 0]],
  [[4, 3, 3], [5, 2, 3, 1]],
  [[2, 2, 3], [6, 0, 0, 1]],
  [[0, 0, 1], [1, 2, 2, 3]]
]
indices = [2, 0, 1, 3]
Y = [
  [[2, 2, 3], [6, 0, 0, 1]],
  [[1, 5, 2], [3, 4, 6, 0]],
  [[4, 3, 3], [5, 2, 3, 1]],
  [[0, 0, 1], [1, 2, 2, 3]]
]
)DOC")
    .Input(0, "X", "Input tensor, where the 1st dimension equals the batch size")
    .Input(1, "indices", "Input indices of the batch to permute")
    .Output(0, "Y", "Output permuted tensor");
// Input: indices, dY (aka "gradOutput"); Output: dX (aka "gradInput")
OPERATOR_SCHEMA(BatchPermutationGradient).NumInputs(2).NumOutputs(1);

class GetBatchPermutationGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "BatchPermutationGradient",
        "",
        vector<string>{I(1), GO(0)},
        vector<string>{GI(0)});
  }
};

REGISTER_GRADIENT(BatchPermutation, GetBatchPermutationGradient);

} // namespace caffe2

using BatchPermutationOpFloatCPU =
    caffe2::BatchPermutationOp<float, caffe2::CPUContext>;

C10_EXPORT_CAFFE2_OP_TO_C10_CPU(
    BatchPermutation,
    "_caffe2::BatchPermutation(Tensor X, Tensor indices) -> Tensor",
    BatchPermutationOpFloatCPU);
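To make the loop above concrete: in the forward direction, row n of the output is row indices[n] of the input. A host-side sketch (the helper name forward_permute is made up) that reproduces the schema's 2-D example when K is the row width:

#include <vector>

std::vector<float> forward_permute(
    const std::vector<float>& X,
    const std::vector<int>& indices,
    int K) {
  const int N = static_cast<int>(indices.size());
  std::vector<float> Y(N * K);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      Y[n * K + k] = X[indices[n] * K + k]; // gather row indices[n]
    }
  }
  return Y;
}

// With K = 7 and indices = {2, 0, 1, 3}, the rows of Y come out as
// X[2], X[0], X[1], X[3], matching the first example in the schema doc.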
caffe2/operators/batch_permutation_op.cu

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
#include "caffe2/core/context_gpu.h"
#include "caffe2/operators/batch_permutation_op.h"

namespace caffe2 {

namespace {
template <bool forward>
__global__ void BatchPermutationKernel(
    int N,
    int K,
    const float* src,
    const int* indices,
    float* dst) {
  if (forward) {
    CUDA_1D_KERNEL_LOOP(index, N * K) {
      int k = index % K;
      int n = index / K;
      int idx = indices[n];
      CUDA_KERNEL_ASSERT(idx >= 0);
      CUDA_KERNEL_ASSERT(idx < N);
      dst[index] = src[idx * K + k];
    }
  } else {
    CUDA_1D_KERNEL_LOOP(index, N * K) {
      int k = index % K;
      int n = index / K;

      // NOTE: an alternative implementation if we want to align the index with
      // the output tensor (rather than the input tensor):
      // int idx = -1;
      // for (size_t i = 0; i < N; ++i) {
      //   if (indices[i] == n) {
      //     idx = i;
      //   }
      // }
      // CUDA_KERNEL_ASSERT(idx >= 0);
      // CUDA_KERNEL_ASSERT(idx < N);
      // dst[index] = src[idx * K + k];

      int idx = indices[n];
      CUDA_KERNEL_ASSERT(idx >= 0);
      CUDA_KERNEL_ASSERT(idx < N);
      dst[idx * K + k] = src[index];
    }
  }
}
} // namespace

template <>
bool BatchPermutationOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);
  auto& indices = Input(1);

  CAFFE_ENFORCE(indices.dim() == 1, "indices must be 1-d");
  CAFFE_ENFORCE(
      X.dim32(0) == indices.dim32(0),
      "X.dim32(0) must be equal to indices.dim32(0)",
      "(",
      X.dim32(0),
      " vs. ",
      indices.dim32(0),
      ")");

  auto* Y = Output(0, X.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(X.dim32(0), 0);
  BatchPermutationKernel<true>
      <<<CAFFE_GET_BLOCKS(X.numel()),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          X.dim32(0),
          X.numel() / X.dim32(0),
          X.data<float>(),
          indices.data<int>(),
          Y->mutable_data<float>());

  return true;
}

template <>
bool BatchPermutationGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& indices = Input(0);
  auto& dY = Input(1);
  auto* dX = Output(0, dY.sizes(), at::dtype<float>());

  CAFFE_ENFORCE_GT(dY.dim32(0), 0);
  BatchPermutationKernel<false>
      <<<CAFFE_GET_BLOCKS(dY.numel()),
         CAFFE_CUDA_NUM_THREADS,
         0,
         context_.cuda_stream()>>>(
          dY.dim32(0),
          dY.numel() / dY.dim32(0),
          dY.data<float>(),
          indices.data<int>(),
          dX->mutable_data<float>());

  return true;
}

REGISTER_CUDA_OPERATOR(
    BatchPermutation,
    BatchPermutationOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(
    BatchPermutationGradient,
    BatchPermutationGradientOp<float, CUDAContext>);
} // namespace caffe2

using BatchPermutationOpFloatCUDA =
    caffe2::BatchPermutationOp<float, caffe2::CUDAContext>;

C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(BatchPermutation, BatchPermutationOpFloatCUDA);
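One design note: the CPU gradient path first materializes the inverse permutation (`backward_indices`), while the CUDA backward kernel avoids that by scattering each gradient row directly to position `indices[n]` (`dst[idx * K + k] = src[index]`). A host-side sketch of the equivalent scatter formulation (the helper name backward_scatter is made up):

#include <vector>

std::vector<float> backward_scatter(
    const std::vector<float>& dY,
    const std::vector<int>& indices,
    int K) {
  const int N = static_cast<int>(indices.size());
  std::vector<float> dX(N * K);
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K; ++k) {
      dX[indices[n] * K + k] = dY[n * K + k]; // scatter: no inverse table
    }
  }
  return dX;
}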
