Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Metal] Enable optimize_for_mobile on Linux
Currently, the optimize_for_mobile binary only works on macOS, which is inconvenient. This diff introduces a new buck target that separates out the Objective-C code. The goal is to be able to export models for Metal on Linux machines. Differential Revision: [D24322017](https://our.internmc.facebook.com/intern/diff/D24322017/) **NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D24322017/)! [ghstack-poisoned]
- Loading branch information
Showing
7 changed files
with
158 additions
and
164 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
#include <ATen/core/op_registration/op_registration.h> | ||
#include <ATen/native/metal/MetalPrepackOpContext.h> | ||
|
||
#if defined(C10_IOS) | ||
#import <ATen/native/metal/MetalUtils.h> | ||
#import <ATen/native/metal/mpscnn/MPSCNNOps.h> | ||
#endif | ||
|
||
namespace at { | ||
namespace native { | ||
namespace metal { | ||
|
||
// Rebuilds a Conv2dOpContext from its serialized fields. It is used as the
// __setstate__ half of the Conv2dOpContext pickle registration in this file.
//
// On iOS the weight is made contiguous and handed to permuteWeights() as a
// float buffer together with its sizes. The buffer that comes back is copied
// into a freshly allocated tensor with the same sizes, and that tensor becomes
// the weight of the returned context. Every other argument is passed through
// to the Conv2dOpContext constructor as-is.
//
// On any other platform the call fails through TORCH_CHECK.
c10::intrusive_ptr<Conv2dOpContext> unpack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
#if defined(C10_IOS)
  const Tensor weightContig = weight.contiguous();
  const auto ws = weightContig.sizes();
  auto packed_buffer = permuteWeights(weightContig.data_ptr<float>(), ws.vec());
  auto packedWeight = at::empty(ws);
  // NOTE(review): the copy below assumes packed_buffer holds at least
  // prod(ws) floats and that at::empty(ws) yields a float tensor -- confirm.
  const int64_t size_bytes = at::prod_intlist(ws) * sizeof(float);
  memcpy(packedWeight.data_ptr(), packed_buffer.data(), size_bytes);
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(packedWeight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
#else
  // Terminated with a semicolon like the equivalent guards in
  // conv2d_prepack_run and copy_to_host.
  TORCH_CHECK(false, "unpack can only be invoked on iOS");
  // The check above always fails, so this return is not expected to execute;
  // it keeps every control path returning a value.
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
#endif
}
|
||
// Declares the `metal` library: the Conv2dOpContext custom class with its
// pickle hooks, plus the schema of the copy_to_host operator.
TORCH_LIBRARY(metal, m) {
  m.class_<Conv2dOpContext>("Conv2dOpContext")
      .def_pickle(
          // __getstate__: serialize the context through its own pack().
          [](const c10::intrusive_ptr<Conv2dOpContext>& ctx)
              -> SerializationTypeConv2dPrePack { return ctx->pack(); },
          // __setstate__: hand each tuple field, in order, to unpack().
          [](SerializationTypeConv2dPrePack packed)
              -> c10::intrusive_ptr<Conv2dOpContext> {
            auto& weight = std::get<0>(packed);
            auto& bias = std::get<1>(packed);
            auto& stride = std::get<2>(packed);
            auto& padding = std::get<3>(packed);
            auto& dilation = std::get<4>(packed);
            auto& groups = std::get<5>(packed);
            auto& output_min = std::get<6>(packed);
            auto& output_max = std::get<7>(packed);
            return unpack(
                std::move(weight),
                std::move(bias),
                std::move(stride),
                std::move(padding),
                std::move(dilation),
                std::move(groups),
                std::move(output_min),
                std::move(output_max));
          });
  m.def("copy_to_host(Tensor X) -> Tensor Y");
}
|
||
// Declares the operator schemas of the `metal_prepack` library. The kernels
// are attached separately by the TORCH_LIBRARY_IMPL blocks in this file.
TORCH_LIBRARY(metal_prepack, m) {
  // Builds a Conv2dOpContext from the convolution parameters.
  m.def(
      "conv2d_prepack("
      "Tensor W, Tensor? B, "
      "int[2] stride, int[2] padding, int[2] dilation, "
      "int groups, "
      "Scalar? output_min=None, Scalar? output_max=None"
      ") -> __torch__.torch.classes.metal.Conv2dOpContext");
  // Runs a convolution on X with a previously built context.
  m.def(
      "conv2d_run("
      "Tensor X, "
      "__torch__.torch.classes.metal.Conv2dOpContext W_prepack"
      ") -> Tensor Y");
}
|
||
// Wraps the convolution parameters in a Conv2dOpContext without transforming
// them. Registered in this file as the CPU kernel for
// metal_prepack::conv2d_prepack.
//
// Fails through TORCH_CHECK when `weight` is not 4-dimensional.
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  // Report the offending rank rather than relying on the generic
  // "Expected ... to be true" text of a message-less TORCH_CHECK.
  TORCH_CHECK(
      weight.dim() == 4,
      "conv2d_prepack expects a 4-dimensional weight tensor, but got ",
      weight.dim(),
      " dimension(s)");
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}
|
||
// Kernel behind metal_prepack::conv2d_run. On iOS it forwards the input and
// the prepacked context to mpscnn::conv2d; on every other platform it fails
// through TORCH_CHECK.
Tensor conv2d_prepack_run(
    const Tensor& input,
    const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
#if !defined(C10_IOS)
  TORCH_CHECK(false, "conv2d_prepack_run can only be invoked on iOS");
  // Not expected to execute; keeps every control path returning a value.
  return input;
#else
  return mpscnn::conv2d(input, *op_context);
#endif
}
|
||
// Kernel behind metal::copy_to_host. On iOS it delegates to
// mpscnn::copy_to_host; on every other platform it fails through TORCH_CHECK.
Tensor copy_to_host(const Tensor& input) {
#if !defined(C10_IOS)
  TORCH_CHECK(false, "copy_to_host can only be invoked on iOS");
  // Not expected to execute; keeps every control path returning a value.
  return input;
#else
  return mpscnn::copy_to_host(input);
#endif
}
|
||
// Kernel registrations. Every kernel is wrapped in TORCH_FN so that all three
// registrations use the same form; previously only the CPU one did.

// conv2d_prepack is registered under the CPU dispatch key.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
  m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack));
}

// conv2d_run is registered under the Metal dispatch key.
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
  m.impl("conv2d_run", TORCH_FN(conv2d_prepack_run));
}

// copy_to_host is registered under the Metal dispatch key.
TORCH_LIBRARY_IMPL(metal, Metal, m) {
  m.impl("copy_to_host", TORCH_FN(copy_to_host));
}
|
||
} // namespace metal | ||
} // namespace native | ||
} // namespace at |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters