-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
MetalPrepackOpRegister.cpp
127 lines (115 loc) · 3.64 KB
/
MetalPrepackOpRegister.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/metal/MetalPrepackOpContext.h>
#include <torch/script.h>
#if defined(C10_IOS)
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
#endif
namespace at {
namespace native {
namespace metal {
// Rebuilds a Conv2dOpContext from its serialized state (__setstate__ path):
// re-permutes the deserialized weight into the layout the Metal backend
// consumes and bundles it with the remaining convolution parameters.
c10::intrusive_ptr<Conv2dOpContext> unpack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  // permuteWeights reads raw float data, so the weight must be contiguous.
  const Tensor contiguousWeight = weight.contiguous();
  const auto sizes = contiguousWeight.sizes();
  auto reordered =
      permuteWeights(contiguousWeight.data_ptr<float>(), sizes.vec());
  // Copy the reordered data into a fresh tensor of the same shape; the
  // context takes ownership of this packed copy, not of the caller's weight.
  auto packedWeight = at::empty(sizes);
  const int64_t nbytes = at::prod_intlist(sizes) * sizeof(float);
  memcpy(packedWeight.data_ptr(), reordered.data(), nbytes);
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(packedWeight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}
// Registers the torchbind class `__torch__.torch.classes.metal.Conv2dOpContext`
// with pickle support (so prepacked conv contexts survive TorchScript
// save/load), plus the schema for `metal::copy_to_host`.
TORCH_LIBRARY(metal, m) {
m.class_<Conv2dOpContext>("Conv2dOpContext")
.def_pickle(
// __getstate__: flatten the context back into the original prepack
// argument tuple via Conv2dOpContext::pack().
[](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
-> SerializationTypeConv2dPrePack { // __getstate__
return op_context->pack();
},
// __setstate__: rebuild the context from the tuple. Field order must
// match unpack()'s parameter list exactly.
[](SerializationTypeConv2dPrePack state)
-> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
return unpack(
std::move(std::get<0>(state)), // weight
std::move(std::get<1>(state)), // bias
std::move(std::get<2>(state)), // stride
std::move(std::get<3>(state)), // padding
std::move(std::get<4>(state)), // dilation
std::move(std::get<5>(state)), // groups
std::move(std::get<6>(state)), // output_min
std::move(std::get<7>(state))); // output_max
});
m.def("copy_to_host(Tensor X) -> Tensor Y");
}
// Declares the schemas for the prepack namespace:
//   - conv2d_prepack: packs weight/bias + conv parameters into a
//     Conv2dOpContext (implemented for CPU below),
//   - conv2d_run: applies a prepacked context to an input tensor
//     (implemented for the Metal dispatch key below).
TORCH_LIBRARY(metal_prepack, m) {
m.def(
"conv2d_prepack(Tensor W, Tensor? B, int[2] stride, "
"int[2] padding, int[2] dilation, int groups, "
"Scalar? output_min=None, Scalar? output_max=None) "
"-> __torch__.torch.classes.metal.Conv2dOpContext");
m.def(
"conv2d_run(Tensor X, "
"__torch__.torch.classes.metal.Conv2dOpContext W_prepack) -> Tensor Y");
}
// CPU implementation of `metal_prepack::conv2d_prepack`: wraps the weight,
// optional bias, and convolution parameters in a Conv2dOpContext. Note that,
// unlike unpack(), no weight permutation happens here — the weight is stored
// as passed in.
//
// Fix: TORCH_CHECK previously had no message, so a bad weight produced an
// opaque "weight.dim() == 4" failure; now the error states what was expected
// and what was received.
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  TORCH_CHECK(
      weight.dim() == 4,
      "conv2d_prepack: expected a 4-dimensional weight tensor, but got ",
      weight.dim(),
      " dimension(s)");
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}
// Metal implementation of `metal_prepack::conv2d_run`: applies the prepacked
// convolution in `op_context` to `input`. Only functional on iOS builds
// (C10_IOS), where it dispatches to the MPSCNN backend; on any other build
// it fails loudly at runtime.
Tensor conv2d_prepack_run(
const Tensor& input,
const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
#if defined(C10_IOS)
return mpscnn::conv2d(input, *op_context);
#else
TORCH_CHECK(false, "conv2d_prepack_run can only be invoked on iOS");
// Unreachable — TORCH_CHECK(false, ...) always throws; the return only
// satisfies the compiler's control-flow analysis.
return input;
#endif
}
// Implementation of `metal::copy_to_host`: moves a Metal tensor back to host
// memory via MPSCNN. Like conv2d_prepack_run, it is only functional on iOS
// builds (C10_IOS) and fails loudly at runtime elsewhere.
Tensor copy_to_host(const Tensor& input) {
#if defined(C10_IOS)
return mpscnn::copy_to_host(input);
#else
TORCH_CHECK(false, "copy_to_host can only be invoked on iOS");
// Unreachable — TORCH_CHECK(false, ...) always throws.
return input;
#endif
}
// Binds conv2d_prepack to the CPU dispatch key: packing happens on the host
// before the context is handed to the Metal backend. TORCH_FN wraps the
// function pointer as a compile-time constant.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack));
}
// Binds conv2d_run to the Metal dispatch key.
//
// Fix: register via TORCH_FN (compile-time function-pointer wrapper) instead
// of a plain function pointer, matching the conv2d_prepack registration
// style in this file and enabling the unboxed calling path.
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl("conv2d_run", TORCH_FN(conv2d_prepack_run));
}
// Binds copy_to_host to the Metal dispatch key.
//
// Fix: register via TORCH_FN (compile-time function-pointer wrapper) instead
// of a plain function pointer, matching the conv2d_prepack registration
// style in this file and enabling the unboxed calling path.
TORCH_LIBRARY_IMPL(metal, Metal, m) {
  m.impl("copy_to_host", TORCH_FN(copy_to_host));
}
} // namespace metal
} // namespace native
} // namespace at