-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
MetalPrepackOpRegister.cpp
140 lines (128 loc) · 3.93 KB
/
MetalPrepackOpRegister.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/metal/MetalPrepackOpContext.h>

#include <cstring>

#if defined(C10_IOS)
#import <ATen/native/metal/MetalUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
#endif
namespace at {
namespace native {
namespace metal {
// Deserializes a Conv2dOpContext from its packed state (used as the
// torchbind __setstate__ hook below). On iOS the weights are permuted into
// the layout expected by the MPSCNN kernels; on any other platform this
// raises, since Metal prepacking is iOS-only.
//
// Takes all tensor/vector arguments by rvalue reference because the caller
// (__setstate__) hands over ownership of the unpickled state.
c10::intrusive_ptr<Conv2dOpContext> unpack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
#if defined(C10_IOS)
  // Repack the serialized weights into the Metal-friendly layout. The packed
  // buffer has the same element count as the original weight, just permuted.
  const Tensor weightContig = weight.contiguous();
  const auto ws = weightContig.sizes();
  auto packed_buffer = permuteWeights(weightContig.data_ptr<float>(), ws.vec());
  // at::empty defaults to float, matching the float packed buffer copied below.
  auto packedWeight = at::empty(ws);
  int64_t size_bytes = at::prod_intlist(ws) * sizeof(float);
  memcpy(packedWeight.data_ptr(), packed_buffer.data(), size_bytes);
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(packedWeight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
#else
  // Fail loudly off-iOS. The trailing return is never reached but keeps
  // compilers satisfied when TORCH_CHECK is not treated as noreturn.
  TORCH_CHECK(false, "unpack can only be invoked on iOS");
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
#endif
}
// Registers the `metal` library: binds the Conv2dOpContext custom class with
// pickle support, and declares the copy_to_host schema (implemented in the
// TORCH_LIBRARY_IMPL(metal, Metal, ...) block further down).
TORCH_LIBRARY(metal, m) {
  m.class_<Conv2dOpContext>("Conv2dOpContext")
      .def_pickle(
          // __getstate__: flatten the context into its serializable tuple.
          [](const c10::intrusive_ptr<Conv2dOpContext>& context)
              -> SerializationTypeConv2dPrePack {
            return context->pack();
          },
          // __setstate__: rebuild the context from the serialized tuple,
          // handing each field over to unpack() by move.
          [](SerializationTypeConv2dPrePack serialized)
              -> c10::intrusive_ptr<Conv2dOpContext> {
            return unpack(
                std::move(std::get<0>(serialized)),
                std::move(std::get<1>(serialized)),
                std::move(std::get<2>(serialized)),
                std::move(std::get<3>(serialized)),
                std::move(std::get<4>(serialized)),
                std::move(std::get<5>(serialized)),
                std::move(std::get<6>(serialized)),
                std::move(std::get<7>(serialized)));
          });
  m.def("copy_to_host(Tensor X) -> Tensor Y");
}
// Declares the schemas for the metal_prepack ops. Implementations are bound
// separately in the TORCH_LIBRARY_IMPL blocks below (prepack on CPU, run on
// the Metal dispatch key).
TORCH_LIBRARY(metal_prepack, m) {
  // Packs conv weights + hyper-parameters into a Conv2dOpContext.
  m.def(
      "conv2d_prepack(Tensor W, Tensor? B, int[2] stride, "
      "int[2] padding, int[2] dilation, int groups, "
      "Scalar? output_min=None, Scalar? output_max=None) "
      "-> __torch__.torch.classes.metal.Conv2dOpContext");
  // Runs the convolution using a previously prepacked context.
  m.def(
      "conv2d_run(Tensor X, "
      "__torch__.torch.classes.metal.Conv2dOpContext W_prepack) -> Tensor Y");
}
// Creates a Conv2dOpContext holding the conv weights and hyper-parameters.
// Registered as metal_prepack::conv2d_prepack on CPU; the Metal-specific
// weight permutation happens later (see unpack / mpscnn::conv2d), not here.
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  // Conv2d weights must be 4-D. Include a message so a schema-level misuse
  // produces a diagnosable error instead of a bare check failure.
  TORCH_CHECK(
      weight.dim() == 4,
      "conv2d_prepack: expected a 4-dimensional weight tensor, got ",
      weight.dim(),
      " dimension(s)");
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}
// Executes the convolution described by a prepacked Conv2dOpContext on the
// given input. iOS-only: forwards to the MPSCNN conv2d kernel; on other
// platforms it raises via TORCH_CHECK.
Tensor conv2d_prepack_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
#if defined(C10_IOS)
  return mpscnn::conv2d(input, *op_context);
#else
  TORCH_CHECK(false, "conv2d_prepack_run can only be invoked on iOS");
  return input; // unreachable; keeps non-iOS compilers satisfied
#endif
}
// Copies a tensor back to host-accessible memory via mpscnn::copy_to_host.
// iOS-only; raises via TORCH_CHECK on other platforms.
Tensor copy_to_host(const Tensor& input) {
#if defined(C10_IOS)
  return mpscnn::copy_to_host(input);
#else
  TORCH_CHECK(false, "copy_to_host can only be invoked on iOS");
  return input; // unreachable; keeps non-iOS compilers satisfied
#endif
}
// Bind the CPU implementation of conv2d_prepack: packing runs host-side,
// before the tensor ever reaches the Metal device.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
  m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack));
}
// Bind conv2d_run on the Metal dispatch key, so it fires for Metal tensors.
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl("conv2d_run", conv2d_prepack_run);
}
// Bind copy_to_host on the Metal dispatch key (schema declared in
// TORCH_LIBRARY(metal, ...) above).
TORCH_LIBRARY_IMPL(metal, Metal, m) {
  m.impl("copy_to_host", copy_to_host);
}
} // namespace metal
} // namespace native
} // namespace at