Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] Enable optimize_for_mobile on Linux #46384

Closed
wants to merge 4 commits
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 1 addition & 4 deletions aten/src/ATen/native/metal/MetalConvolution.h
@@ -1,6 +1,5 @@
#import <ATen/native/metal/MetalPrepackOpContext.h>
#import <ATen/native/metal/MetalUtils.h>
#import <ATen/native/metal/mpscnn/MPSCNNOp.h>

#include <torch/script.h>

namespace at {
Expand Down Expand Up @@ -49,8 +48,6 @@ struct Conv2DParams final {

NeuronType neuronType(const Conv2dOpContext& context);

Tensor conv2d_prepack_run_impl(Conv2dOpContext& context, const Tensor& input);

} // namespace metal
} // namespace native
} // namespace at
4 changes: 0 additions & 4 deletions aten/src/ATen/native/metal/MetalConvolution.mm
Expand Up @@ -60,10 +60,6 @@ NeuronType neuronType(const Conv2dOpContext& context) {
}
}

Tensor conv2d_prepack_run_impl(Conv2dOpContext& context, const Tensor& input) {
return mpscnn::conv2d(input, context);
}

} // namespace metal
} // namespace native
} // namespace at
60 changes: 33 additions & 27 deletions aten/src/ATen/native/metal/MetalPrepackOpContext.h
@@ -1,4 +1,4 @@
#import <Foundation/Foundation.h>
#pragma once

#include <ATen/Tensor.h>
#include <torch/custom_class.h>
Expand Down Expand Up @@ -49,6 +49,13 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder {
output_min(output_min),
output_max(output_max) {}

// Frees the backend conv object held through the type-erased `conv2dOp`
// pointer by invoking `releaseCallback` on it, then clears the pointer.
// Overrides CustomClassHolder::release_resources so the object is torn
// down when the TorchScript object graph is destroyed.
// NOTE(review): the callback is presumably installed by the Metal runtime
// when `conv2dOp` is created — confirm at the call site. The void*/callback
// indirection keeps Objective-C types out of this header.
void release_resources() override {
if (releaseCallback) {
releaseCallback(conv2dOp);
conv2dOp = nullptr;
}
}

Tensor weight;
c10::optional<Tensor> bias;
std::vector<int64_t> stride;
Expand All @@ -57,34 +64,33 @@ class Conv2dOpContext : public torch::jit::CustomClassHolder {
int64_t groups;
c10::optional<Scalar> output_min;
c10::optional<Scalar> output_max;
id extra = nil;
void* conv2dOp = nullptr; // reserved to hold MPSCNNConv2dOp objects
std::function<void(void*)> releaseCallback = nullptr;
};

// Rebuilds a Conv2dOpContext from its serialized constructor arguments
// (used as the custom class's __setstate__); the definition re-permutes
// the weight into the MPSCNNConvolution layout before constructing the
// context.
c10::intrusive_ptr<Conv2dOpContext> unpack(
Tensor&& weight,
c10::optional<Tensor>&& bias,
std::vector<int64_t>&& stride,
std::vector<int64_t>&& padding,
std::vector<int64_t>&& dilation,
int64_t groups,
c10::optional<Scalar> output_min,
c10::optional<Scalar> output_max);

// Wraps weight/bias and conv hyper-parameters into a Conv2dOpContext.
// The weight is stored as passed (no permutation here).
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
Tensor&& weight,
c10::optional<Tensor>&& bias,
std::vector<int64_t>&& stride,
std::vector<int64_t>&& padding,
std::vector<int64_t>&& dilation,
int64_t groups,
c10::optional<Scalar> output_min,
c10::optional<Scalar> output_max);

// Runs the prepacked convolution on `input`. Functional only in iOS
// builds (C10_IOS); elsewhere the definition fails with TORCH_CHECK.
Tensor conv2d_prepack_run(
const Tensor& input,
const c10::intrusive_ptr<Conv2dOpContext>& op_context);

// Copies a Metal-backend tensor back to host memory (iOS-only at runtime).
Tensor copy_to_host(const Tensor& input);
// The MPSCNNConvolution class takes weights in the order
// [outputChannels][kernelHeight][kernelWidth][inputChannels/groups].
// Permutes a contiguous OIHW float weight buffer into that layout.
//
// `src`   — pointer to sizes[0]*sizes[1]*sizes[2]*sizes[3] floats.
// `sizes` — {outputChannels, inputChannels/groups, kernelHeight,
//           kernelWidth}; must have at least 4 entries.
// Returns the permuted weights in a freshly allocated vector.
static std::vector<float> permuteWeights(
    const float* src,
    const std::vector<int64_t>& sizes) {
  const int64_t M = sizes[0];
  const int64_t Cf = sizes[1];
  const int64_t kH = sizes[2];
  const int64_t kW = sizes[3];
  std::vector<float> packedWeights(M * kH * kW * Cf);
  // int64_t induction variables: `auto m = 0` would deduce `int` and
  // silently narrow/overflow against the int64_t bounds for very large
  // weight tensors.
  for (int64_t m = 0; m < M; ++m) {
    // Hoist the per-output-channel base offsets out of the inner loops.
    const float* srcPlane = src + m * Cf * kH * kW;
    float* dstPlane = packedWeights.data() + m * kH * kW * Cf;
    for (int64_t c = 0; c < Cf; ++c) {
      for (int64_t kh = 0; kh < kH; ++kh) {
        for (int64_t kw = 0; kw < kW; ++kw) {
          dstPlane[kh * kW * Cf + kw * Cf + c] =
              srcPlane[c * kH * kW + kh * kW + kw];
        }
      }
    }
  }
  return packedWeights;
}

} // namespace metal
} // namespace native
Expand Down
71 changes: 0 additions & 71 deletions aten/src/ATen/native/metal/MetalPrepackOpContext.mm

This file was deleted.

127 changes: 127 additions & 0 deletions aten/src/ATen/native/metal/MetalPrepackOpRegister.cpp
@@ -0,0 +1,127 @@
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/native/metal/MetalPrepackOpContext.h>
#include <torch/script.h>

#if defined(C10_IOS)
#import <ATen/native/metal/mpscnn/MPSCNNOps.h>
#endif

namespace at {
namespace native {
namespace metal {

// Rebuilds a Conv2dOpContext from its serialized constructor arguments
// (the custom class's __setstate__ path). The deserialized weight is
// re-permuted into the MPSCNNConvolution layout before the context is
// constructed, so the context always holds packed weights.
c10::intrusive_ptr<Conv2dOpContext> unpack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  const Tensor contiguousWeight = weight.contiguous();
  const auto weightSizes = contiguousWeight.sizes();
  // Permute into [O][kH][kW][I/groups] order, then copy the packed bytes
  // into a tensor of the original shape (layout differs, byte count same).
  const std::vector<float> packed =
      permuteWeights(contiguousWeight.data_ptr<float>(), weightSizes.vec());
  auto packedWeight = at::empty(weightSizes);
  const int64_t nbytes = at::prod_intlist(weightSizes) * sizeof(float);
  memcpy(packedWeight.data_ptr(), packed.data(), nbytes);
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(packedWeight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}

// Registers the TorchScript custom class
// __torch__.torch.classes.metal.Conv2dOpContext with pickle support so
// prepacked conv contexts survive module save/load, and declares the
// metal::copy_to_host op schema (backend impl registered separately).
TORCH_LIBRARY(metal, m) {
m.class_<Conv2dOpContext>("Conv2dOpContext")
.def_pickle(
// __getstate__: flatten the context into the tuple of its
// constructor arguments (SerializationTypeConv2dPrePack).
[](const c10::intrusive_ptr<Conv2dOpContext>& op_context)
-> SerializationTypeConv2dPrePack { // __getstate__
return op_context->pack();
},
// __setstate__: rebuild the context from that tuple; unpack()
// re-permutes the weight into the MPSCNN layout.
[](SerializationTypeConv2dPrePack state)
-> c10::intrusive_ptr<Conv2dOpContext> { // __setstate__
return unpack(
std::move(std::get<0>(state)),
std::move(std::get<1>(state)),
std::move(std::get<2>(state)),
std::move(std::get<3>(state)),
std::move(std::get<4>(state)),
std::move(std::get<5>(state)),
std::move(std::get<6>(state)),
std::move(std::get<7>(state)));
});
m.def("copy_to_host(Tensor X) -> Tensor Y");
}

// Declares the schemas for the metal prepack/run ops.
// NOTE(review): these appear to be the ops inserted by the
// optimize_for_mobile metal rewrite pass (see PR title) — confirm against
// the pass. Implementations are registered per dispatch key below.
TORCH_LIBRARY(metal_prepack, m) {
m.def(
"conv2d_prepack(Tensor W, Tensor? B, int[2] stride, "
"int[2] padding, int[2] dilation, int groups, "
"Scalar? output_min=None, Scalar? output_max=None) "
"-> __torch__.torch.classes.metal.Conv2dOpContext");
m.def(
"conv2d_run(Tensor X, "
"__torch__.torch.classes.metal.Conv2dOpContext W_prepack) -> Tensor Y");
}

// Wraps weight/bias and conv hyper-parameters into a Conv2dOpContext.
// CPU implementation of metal_prepack::conv2d_prepack; the weight is
// stored as passed — permutation into the MPSCNN layout happens elsewhere
// (e.g. in unpack()).
c10::intrusive_ptr<Conv2dOpContext> conv2d_prepack(
    Tensor&& weight,
    c10::optional<Tensor>&& bias,
    std::vector<int64_t>&& stride,
    std::vector<int64_t>&& padding,
    std::vector<int64_t>&& dilation,
    int64_t groups,
    c10::optional<Scalar> output_min,
    c10::optional<Scalar> output_max) {
  // Give the check a message so a failure is diagnosable from Python
  // instead of a bare "weight.dim() == 4" assertion string.
  TORCH_CHECK(
      weight.dim() == 4,
      "conv2d_prepack: expected a 4-dimensional weight tensor (OIHW), got ",
      weight.dim(),
      " dimension(s)");
  return c10::make_intrusive<Conv2dOpContext>(
      std::move(weight),
      std::move(bias),
      stride,
      padding,
      dilation,
      groups,
      output_min,
      output_max);
}

// Executes the convolution described by `op_context` on `input`.
// On iOS builds (C10_IOS) this dispatches to the MPSCNN implementation;
// on every other platform the symbol exists only so the op schema links,
// and invoking it fails with TORCH_CHECK. This is what lets
// optimize_for_mobile run on Linux without the Metal runtime.
Tensor conv2d_prepack_run(
const Tensor& input,
const c10::intrusive_ptr<Conv2dOpContext>& op_context) {
#if defined(C10_IOS)
return mpscnn::conv2d(input, *op_context);
#else
TORCH_CHECK(false, "conv2d_prepack_run can only be invoked on iOS");
// Unreachable after the check; keeps the non-iOS path well-formed.
return input;
#endif
}

// Copies a Metal-backend tensor back to host (CPU) memory.
// iOS-only at runtime: non-iOS builds compile the stub below so the op
// registration links, but calling it fails with TORCH_CHECK.
Tensor copy_to_host(const Tensor& input) {
#if defined(C10_IOS)
return mpscnn::copy_to_host(input);
#else
TORCH_CHECK(false, "copy_to_host can only be invoked on iOS");
// Unreachable after the check; keeps the non-iOS path well-formed.
return input;
#endif
}

// conv2d_prepack is registered for the CPU dispatch key: it runs on host
// tensors before anything touches the GPU.
TORCH_LIBRARY_IMPL(metal_prepack, CPU, m) {
m.impl("conv2d_prepack", TORCH_FN(conv2d_prepack));
}

// conv2d_run is registered for the Metal dispatch key: it consumes
// tensors already placed on the Metal backend.
TORCH_LIBRARY_IMPL(metal_prepack, Metal, m) {
m.impl("conv2d_run", conv2d_prepack_run);
}

// copy_to_host is likewise keyed on Metal — its input lives on the GPU.
TORCH_LIBRARY_IMPL(metal, Metal, m) {
m.impl("copy_to_host", copy_to_host);
}

} // namespace metal
} // namespace native
} // namespace at
55 changes: 0 additions & 55 deletions aten/src/ATen/native/metal/MetalPrepackOpRegister.mm

This file was deleted.

6 changes: 1 addition & 5 deletions aten/src/ATen/native/metal/MetalUtils.h
Expand Up @@ -12,11 +12,7 @@ std::vector<float> NCHW_to_NC4(
std::vector<float> NC4_to_NCHW(
const float* src,
const std::vector<int64_t>& sizes);
// The MPSCNNConvolution class takes weights in the order
// [outputChannels][kernelHeight][kernelWidth][inputChannels/groups].
std::vector<float> permuteWeights(
const float* src,
const std::vector<int64_t>& sizes);


} // namespace metal
} // namespace native
Expand Down