Skip to content

Commit

Permalink
[OSS][Metal] Support Resnet models
Browse files Browse the repository at this point in the history
Summary:
This diff adds the missing ops needed to run the ResNet models from Torchvision. Moving the tensors to the GPU can significantly improve performance, as shown below (iPhone 11).

Time running on CPU (ms):

```
forward took: 166.115
forward took: 150.722
forward took: 150.383
forward took: 150.345
forward took: 150.761
forward took: 150.533
forward took: 150.588
forward took: 150.812
forward took: 150.925
forward took: 150.25
```

Time running on GPU (ms):

```
forward took: 39.9355
forward took: 41.3531
forward took: 41.798
forward took: 40.4744
forward took: 39.5181
forward took: 42.6464
forward took: 41.2658
forward took: 40.0862
forward took: 42.3533
forward took: 41.9348
```

Discrepancy in result

```
GPU:
    "(623, 4.6211)",
    "(111, 3.8809)",
    "(499, 3.8555)",
    "(596, 3.8047)",
    "(473, 3.7422)",
    "(846, 3.5762)",
    "(892, 3.5449)",
    "(813, 3.5098)",
    "(446, 3.5020)",
    "(902, 3.4980)"
CPU:
    "(623, 4.4229)",
    "(499, 3.8321)",
    "(596, 3.6192)",
    "(111, 3.5295)",
    "(813, 3.4848)",
    "(584, 3.3979)",
    "(418, 3.3357)",
    "(473, 3.2760)",
    "(846, 3.2745)",
    "(902, 3.2376)"
```

Test Plan: {F340824316}

Reviewed By: IvanKobzarev

Differential Revision: D24416294

fbshipit-source-id: 12c9199ade0b76a7aa8a3838eddc4c19c79b6f37
  • Loading branch information
xta0 authored and facebook-github-bot committed Oct 22, 2020
1 parent 9371944 commit b63ddd6
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 0 deletions.
24 changes: 24 additions & 0 deletions aten/src/ATen/native/metal/MetalAten.mm
Expand Up @@ -160,6 +160,11 @@ Tensor relu(const Tensor& input) {
return mpscnn::relu(input);
}

// In-place ReLU entry point for the Metal dispatch key.
// Rejects non-Metal tensors, then defers to the MPSCNN implementation and
// hands back the reference it returns.
Tensor& relu_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  Tensor& activated = mpscnn::relu_(input);
  return activated;
}

Tensor sigmoid(const Tensor& input) {
TORCH_CHECK(input.is_metal());
return mpscnn::sigmoid(input);
Expand Down Expand Up @@ -192,6 +197,14 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal());
}

// In-place tensor addition for the Metal dispatch key (`add_.Tensor`).
//
// input1 - destination and left operand; must already be a Metal tensor.
// input2 - right operand; moved to Metal via `.metal()` when it is not
//          already there.
// alpha  - scale for input2 in the ATen signature. mpscnn::add_ takes no
//          scale argument, so any value other than 1 is rejected instead of
//          being silently dropped.
//
// Returns the reference produced by mpscnn::add_.
Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
  TORCH_CHECK(input1.is_metal());
  // The underlying kernel computes input1 + input2 with no scaling; fail
  // loudly rather than return a wrong result for alpha != 1.
  TORCH_CHECK(
      alpha.to<double>() == 1.0,
      "Metal add_ only supports alpha == 1");
  TORCH_CHECK(input1.dim() == input2.dim());
  // NOTE(review): the two checks below index dims 2 and 3, so they presume
  // inputs with at least 4 dimensions - confirm against callers.
  TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
  TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
  return mpscnn::add_(input1, input2.is_metal() ? input2 : input2.metal());
}

Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
TORCH_CHECK(input1.is_metal());
TORCH_CHECK(input1.dim() == input2.dim());
Expand Down Expand Up @@ -223,23 +236,34 @@ Tensor reshape(const Tensor& input, IntArrayRef shape) {
return mpscnn::reshape(input, shape);
}

// Metal dispatch entry point for `flatten.using_ints`.
// Requires a Metal tensor and forwards the dimension range unchanged to the
// MPSCNN implementation.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  TORCH_CHECK(input.is_metal());
  Tensor flattened = mpscnn::flatten_using_ints(input, start_dim, end_dim);
  return flattened;
}

// Registers the Metal-backend kernels defined in this file for the `aten`
// operator library. Each entry binds one operator name to one function; the
// entries are independent of each other and are grouped by category only for
// readability.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
  // Tensor allocation.
  m.impl_UNBOXED("empty.memory_format", empty);
  m.impl("empty_strided", TORCH_FN(empty_strided));

  // Convolution, linear, pooling and resampling.
  m.impl("conv2d", TORCH_FN(conv2d));
  m.impl("addmm", TORCH_FN(addmm));
  m.impl("max_pool2d", TORCH_FN(max_pool2d));
  m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
  m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec));

  // Element-wise binary arithmetic.
  m.impl("add.Tensor", TORCH_FN(add_Tensor));
  m.impl("add_.Tensor", TORCH_FN(add__Tensor));
  m.impl("sub.Tensor", TORCH_FN(sub_Tensor));
  m.impl("mul.Tensor", TORCH_FN(mul_Tensor));

  // Activations and normalizing functions.
  m.impl("relu", TORCH_FN(relu));
  m.impl("relu_", TORCH_FN(relu_));
  m.impl("sigmoid", TORCH_FN(sigmoid));
  m.impl("hardtanh_", TORCH_FN(hardtanh_));
  m.impl("log_softmax.int", TORCH_FN(log_softmax_int));

  // Shape manipulation.
  m.impl("view", TORCH_FN(view));
  m.impl("reshape", TORCH_FN(reshape));
  m.impl("flatten.using_ints", TORCH_FN(flatten_using_ints));
}

} // namespace metal
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h
Expand Up @@ -30,6 +30,8 @@ Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size);

// Applies the ReLU neuron to `input` and returns the result.
Tensor relu(const Tensor& input);

// In-place ReLU: the result is written back into `input`, which is returned.
Tensor& relu_(Tensor& input);

// Applies the sigmoid neuron to `input` and returns the result.
Tensor sigmoid(const Tensor& input);

// In-place hardtanh taking `min_val` / `max_val` bounds.
// NOTE(review): definition not visible in this excerpt - confirm semantics.
Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val);
Expand All @@ -44,6 +46,8 @@ Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight);

// Element-wise addition of two tensors (runs the "elementwise_add" kernels).
Tensor add(const Tensor& input1, const Tensor& input2);

// In-place element-wise addition: the sum is written back into `input1`,
// which is returned.
Tensor& add_(Tensor& input1, const Tensor& input2);

// Binary element-wise op on two tensors; presumably subtraction - the kernel
// names are cut off in this excerpt.
Tensor sub(const Tensor& input1, const Tensor& input2);

// Presumably element-wise multiplication; definition not shown in this
// excerpt - confirm.
Tensor mul(const Tensor& input1, const Tensor& input2);
Expand All @@ -55,6 +59,8 @@ Tensor upsample_nearest2d_vec(
c10::optional<IntArrayRef> output_size,
c10::optional<ArrayRef<double>> scale_factors);

// Collapses dimensions `start_dim`..`end_dim` (inclusive, negative values are
// wrapped) into a single dimension by reshaping `input`.
Tensor flatten_using_ints(const Tensor & input, int64_t start_dim, int64_t end_dim);

// NOTE(review): presumably copies a Metal tensor back to host memory; only
// the first lines of the definition are visible here - confirm.
Tensor copy_to_host(const Tensor& input);

} // namespace mpscnn
Expand Down
92 changes: 92 additions & 0 deletions aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm
Expand Up @@ -216,11 +216,36 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
return output;
}

// In-place variant of the neuron (activation) kernel.
//
// Encodes `neuron` from the input's image into a temporary image on the
// input's command buffer, then copies that temporary back into the texture
// owned by `input`.
//
// input  - Metal tensor to transform; its texture is overwritten.
// neuron - MPSCNN neuron kernel to encode (e.g. ReLU).
//
// Returns `input`.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
  MPSImage* X = imageFromTensor(input);
  const std::vector<int64_t> outputSize = input.sizes().vec();
  // Rank-2 tensors are padded to four entries (N, C, 1, 1) for the image.
  std::vector<int64_t> textureSize = outputSize;
  if (input.dim() == 2) {
    textureSize = {outputSize[0], outputSize[1], 1, 1};
  }
  MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
  // Size the scratch image from the padded shape. Previously the raw tensor
  // sizes were passed here and `textureSize` was computed but never used.
  // NOTE(review): assumes temporaryImageFromSize expects the padded 4-entry
  // shape - confirm against its implementation.
  MPSImage* Y = [MPSImage temporaryImageFromSize:textureSize
                                   commandBuffer:commandBuffer];
  [neuron encodeToCommandBuffer:commandBuffer.buffer
                    sourceImage:X
               destinationImage:Y];
  // Write the result back into the tensor's own texture to make the op
  // in-place from the caller's point of view.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input;
}

// Out-of-place ReLU: runs the shared ReLU neuron through neuronKernel and
// returns the tensor it produces.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor relu(const Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel(input, reluNeuron);
}

// In-place ReLU: runs the shared ReLU neuron through neuronKernel_, which
// overwrites the texture of `input` and returns `input`.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& relu_(Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel_(input, reluNeuron);
}

API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sigmoid(const Tensor& input) {
return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
Expand Down Expand Up @@ -356,12 +381,50 @@ Tensor binaryElementwiseKernel(
return output;
}

// In-place binary element-wise kernel.
//
// Dispatches a compute kernel over the images of `input1` and `input2`,
// writing into a temporary image, then copies that temporary back into the
// texture owned by `input1`.
//
// input1         - left operand and destination; its texture is overwritten.
// input2         - right operand.
// arrayKernel    - kernel name passed to kernelFor() as the array variant.
// nonarrayKernal - kernel name passed to kernelFor() as the non-array variant.
//
// Both inputs must share the same command buffer; otherwise TORCH_CHECK fails.
// Returns `input1`.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& binaryElementwiseKernel_(
    Tensor& input1,
    const Tensor& input2,
    NSString* arrayKernel,
    NSString* nonarrayKernal) {
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  std::vector<int64_t> outputSize = input1.sizes().vec();
  MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
  MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
  // The message must be a C string: TORCH_CHECK formats its arguments through
  // a C++ stream, which cannot render an NSString literal as text.
  TORCH_CHECK([cb1 isEqual:cb2], "inputs have different command buffer");
  // Scratch destination with the same sizes as input1.
  MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1];
  id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
      pipelineState:kernelFor(X1, arrayKernel, nonarrayKernal)];
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  // Texture slots: 0 = left operand, 1 = right operand, 2 = output.
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  [X1 markRead];
  [X2 markRead];
  // Write the result back into input1's own texture to make the op in-place.
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input1;
}

// Out-of-place element-wise addition, implemented by handing the two "add"
// kernel names to the shared binary element-wise driver.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor add(const Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel(
      input1, input2, arrayKernelName, nonarrayKernelName);
}

// In-place element-wise addition: same kernels as add(), but routed through
// the in-place driver so the sum lands in `input1`, which is returned.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& add_(Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel_(
      input1, input2, arrayKernelName, nonarrayKernelName);
}

API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sub(const Tensor& input1, const Tensor& input2) {
return binaryElementwiseKernel(
Expand Down Expand Up @@ -510,6 +573,35 @@ Tensor upsample_nearest2d_vec(
return output;
}

// Flattens dimensions [start_dim, end_dim] (inclusive) of `input` into one.
//
// Both indices are wrapped with maybe_wrap_dim, so negative values are
// accepted. Fails via TORCH_CHECK when start_dim comes after end_dim.
// A 0-dim tensor is reshaped to {1}; when the range covers a single
// dimension the input is returned unchanged. Otherwise the result is
// `input.reshape(...)` with the selected sizes replaced by their product.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  const int64_t rank = input.dim();
  start_dim = maybe_wrap_dim(start_dim, rank);
  end_dim = maybe_wrap_dim(end_dim, rank);
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");

  if (rank == 0) {
    return input.reshape({1});
  }
  if (start_dim == end_dim) {
    return input;
  }

  const auto sizes = input.sizes();
  const auto mergedExtent =
      prod_intlist(sizes.slice(start_dim, end_dim - start_dim + 1));

  // Leading dims, then the merged extent, then trailing dims.
  std::vector<int64_t> flatShape(sizes.begin(), sizes.begin() + start_dim);
  flatShape.push_back(mergedExtent);
  flatShape.insert(flatShape.end(), sizes.begin() + end_dim + 1, sizes.end());
  return input.reshape(flatShape);
}

Tensor copy_to_host(const Tensor& input) {
MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
Expand Down

0 comments on commit b63ddd6

Please sign in to comment.