Skip to content

Commit

Permalink
Update on "[quant][graphmode][fx] Add support for additional_{fusion/…
Browse files Browse the repository at this point in the history
…quant}_pattern"

Summary:
Allow user to provide additional fusion/quant patterns for fx graph mode

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D24317437](https://our.internmc.facebook.com/intern/diff/D24317437)

[ghstack-poisoned]
  • Loading branch information
jerryzh168 committed Oct 22, 2020
2 parents d62ccc8 + 658f90a commit 3824fc7
Show file tree
Hide file tree
Showing 104 changed files with 7,305 additions and 245 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ jobs:
- name: Ensure canonical include
run: |
(! git grep -I -l $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above files have include with quotes; please convert them to #include <xxxx>"; false))
# note that this next step depends on a clean shallow checkout;
# if you run it locally in a deep checkout then it will complain
# about android/libs/fbjni/gradlew (in a submodule),
# as well as all the generated files in torch/test
- name: Ensure C++ source files are not executable
run: |
(! find . \( -path ./third_party -o -path ./.git -o -path ./torch/bin -o -path ./build \) -prune -o -type f -executable -regextype posix-egrep -not -regex '.+(\.(bash|sh|py|so)|git-pre-commit|git-clang-format)$' -print | grep . || (echo 'The above files have executable permission; please remove their executable permission by using `chmod -x`'; false))
Expand Down
13 changes: 8 additions & 5 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,26 @@ Copyright (c) 2016-present, Facebook Inc. All rights reserved.

All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.

All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.

All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.


All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain

All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.

Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
Expand Down
3 changes: 3 additions & 0 deletions NOTICE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.

All contributions by Kakao Brain:
Copyright 2019-2020 Kakao Brain

All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Expand Down
32 changes: 32 additions & 0 deletions aten/src/ATen/BatchingRegistrations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,34 @@ Tensor to_dtype_layout_batching_rule(
return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Batching rule for Tensor.new_zeros.
//
// Converts `self` to its physical view, maps the requested logical `size` to
// the corresponding physical shape, allocates the zero tensor on the physical
// tensor with the requested options, and wraps the result back up as a
// logical tensor.
Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto batched_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto batched_shape = batched_view.getPhysicalShape(size);

  // Gather the individually supplied optional arguments into one
  // TensorOptions value for the new_zeros call.
  auto tensor_options = TensorOptions()
                            .dtype(dtype)
                            .layout(layout)
                            .device(device)
                            .pinned_memory(pin_memory);

  auto zeros = batched_view.tensor().new_zeros(batched_shape, tensor_options);
  return batched_view.newLogicalFromPhysical(zeros);
}

// Batching rule for Tensor.new_empty.
//
// Same flow as the other new_* rule in this file: go to the physical view of
// `self`, translate the logical `size` into a physical shape, call new_empty
// on the physical tensor with the caller's `options`, and convert the result
// back to a logical tensor.
Tensor new_empty_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    const TensorOptions& options) {
  auto batched_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto batched_shape = batched_view.getPhysicalShape(size);
  auto uninitialized = batched_view.tensor().new_empty(batched_shape, options);
  return batched_view.newLogicalFromPhysical(uninitialized);
}

// Installs batchedTensorForLoopFallback, wrapped as a boxed function, as the
// fallback kernel for this Batched library-impl block.
TORCH_LIBRARY_IMPL(_, Batched, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
Expand Down Expand Up @@ -669,6 +697,10 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("select_backward", select_backward_batching_rule);
m.impl("slice_backward", slice_backward_batching_rule);
m.impl("diagonal_backward", diagonal_backward_batching_rule);

// Tensor.new_* operators
m.impl_UNBOXED("new_empty", new_empty_batching_rule);
m.impl("new_zeros", new_zeros_batching_rule);
}

} // namespace at
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ auto ConvParams::is_output_padding_neg() const -> bool {
// Returns true if, for any index of output_padding, the padding value is at
// least as large as the value it is compared against below.
auto ConvParams::is_output_padding_big() const -> bool {
bool is_big = false;
for (size_t i = 0; i < output_padding.size(); i++) {
// NOTE(review): the next two statements look like the removed and the added
// line of a diff shown together. The second condition is already implied by
// the first, so as written the dilation comparison still takes effect.
// Confirm which of the two statements is meant to remain.
is_big |= (output_padding[i] >= stride[i] || output_padding[i] >= dilation[i]);
is_big |= (output_padding[i] >= stride[i]);
}
return is_big;
}
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Im2Col.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace at {
namespace native {
namespace {

static void im2col_out_cpu_template(
Tensor& output,
const Tensor& input_,
Expand Down
13 changes: 9 additions & 4 deletions aten/src/ATen/native/im2col_shape_check.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ static inline void col2im_shape_check(
dilation_width);

int64_t ndim = input.ndimension();
// Allow a size of 0 only in the batch dimension.
TORCH_CHECK(
input.numel() != 0 && (ndim == 2 || ndim == 3),
"Expected non-empty 2D or 3D input tensor, but got input of sizes",
(ndim == 2 && input.size(0) != 0 && input.size(1) != 0) ||
(ndim == 3 && input.size(1) != 0 && input.size(2) != 0),
"Expected 2D or 3D (batch mode) tensor for input with possibly 0 batch size and non-zero dimensions for input, but got: ",
input.sizes());

int64_t batch_dim = (ndim == 3) ? 0 : -1;
Expand Down Expand Up @@ -155,9 +157,12 @@ static inline void im2col_shape_check(

int64_t ndim = input.ndimension();

// Allow a size of 0 only in the batch dimension.
bool valid_dims = input.size(1) != 0 && input.size(2) != 0;
TORCH_CHECK(
input.numel() != 0 && (ndim == 3 || ndim == 4),
"Expected non-empty 3D or 4D input tensor, but got input of size ",
(ndim == 3 && input.size(0) && valid_dims) ||
(ndim == 4 && valid_dims && input.size(3) != 0),
"Expected 3D or 4D (batch mode) tensor with possibly 0 batch size and other non-zero dimensions for input, but got: ",
input.sizes());

int64_t dim_batch = 0;
Expand Down
24 changes: 24 additions & 0 deletions aten/src/ATen/native/metal/MetalAten.mm
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ Tensor relu(const Tensor& input) {
return mpscnn::relu(input);
}

// In-place ReLU entry point for the Metal backend: verifies that `input` is a
// Metal tensor and hands it to the MPSCNN implementation, returning the
// reference that implementation produces.
Tensor& relu_(Tensor& input) {
  TORCH_CHECK(input.is_metal());
  Tensor& activated = mpscnn::relu_(input);
  return activated;
}

Tensor sigmoid(const Tensor& input) {
TORCH_CHECK(input.is_metal());
return mpscnn::sigmoid(input);
Expand Down Expand Up @@ -192,6 +197,14 @@ Tensor add_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
return mpscnn::add(input1, input2.is_metal() ? input2 : input2.metal());
}

// In-place add entry point for the Metal backend.
//
// Requires `input1` to be a Metal tensor, both inputs to have the same number
// of dimensions, and sizes at index 2 and 3 to match. `input2` is converted
// with .metal() when it is not already a Metal tensor.
// NOTE(review): `alpha` is not read anywhere in this body.
Tensor& add__Tensor(Tensor& input1, const Tensor& input2, Scalar alpha) {
  TORCH_CHECK(input1.is_metal());
  TORCH_CHECK(input1.dim() == input2.dim());
  TORCH_CHECK(input1.sizes()[2] == input2.sizes()[2]);
  TORCH_CHECK(input1.sizes()[3] == input2.sizes()[3]);
  if (input2.is_metal()) {
    return mpscnn::add_(input1, input2);
  }
  return mpscnn::add_(input1, input2.metal());
}

Tensor sub_Tensor(const Tensor& input1, const Tensor& input2, Scalar alpha) {
TORCH_CHECK(input1.is_metal());
TORCH_CHECK(input1.dim() == input2.dim());
Expand Down Expand Up @@ -223,23 +236,34 @@ Tensor reshape(const Tensor& input, IntArrayRef shape) {
return mpscnn::reshape(input, shape);
}

// flatten.using_ints entry point for the Metal backend: checks that `input`
// is a Metal tensor and forwards all arguments to the MPSCNN implementation.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  TORCH_CHECK(input.is_metal());
  Tensor flattened = mpscnn::flatten_using_ints(input, start_dim, end_dim);
  return flattened;
}

// Registers the Metal kernels under their aten operator names. Each entry
// maps an operator schema name (including its overload suffix, where one is
// given) to a function passed by name.
TORCH_LIBRARY_IMPL(aten, Metal, m) {
m.impl("conv2d", TORCH_FN(conv2d));
m.impl("add.Tensor", TORCH_FN(add_Tensor));
m.impl("add_.Tensor", TORCH_FN(add__Tensor));
m.impl("addmm", TORCH_FN(addmm));
// Unlike the other entries, this one is registered through impl_UNBOXED and
// without the TORCH_FN wrapper.
m.impl_UNBOXED("empty.memory_format", empty);
m.impl("empty_strided", TORCH_FN(empty_strided));
m.impl("log_softmax.int", TORCH_FN(log_softmax_int));
m.impl("max_pool2d", TORCH_FN(max_pool2d));
m.impl("mul.Tensor", TORCH_FN(mul_Tensor));
m.impl("relu", TORCH_FN(relu));
m.impl("relu_", TORCH_FN(relu_));
m.impl("sigmoid", TORCH_FN(sigmoid));
m.impl("sub.Tensor", TORCH_FN(sub_Tensor));
m.impl("upsample_nearest2d.vec", TORCH_FN(upsample_nearest2d_vec));
m.impl("view", TORCH_FN(view));
m.impl("adaptive_avg_pool2d", TORCH_FN(adaptive_avg_pool2d));
m.impl("hardtanh_", TORCH_FN(hardtanh_));
m.impl("reshape", TORCH_FN(reshape));
m.impl("flatten.using_ints", TORCH_FN(flatten_using_ints));
}

} // namespace metal
Expand Down
6 changes: 6 additions & 0 deletions aten/src/ATen/native/metal/mpscnn/MPSCNNOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Tensor global_avg_pool2d(const Tensor& input, IntArrayRef output_size);

Tensor relu(const Tensor& input);

Tensor& relu_(Tensor& input);

Tensor sigmoid(const Tensor& input);

Tensor& hardtanh_(Tensor& input, Scalar min_val, Scalar max_val);
Expand All @@ -44,6 +46,8 @@ Tensor addmm(const Tensor& bias, const Tensor& input, const Tensor& weight);

Tensor add(const Tensor& input1, const Tensor& input2);

Tensor& add_(Tensor& input1, const Tensor& input2);

Tensor sub(const Tensor& input1, const Tensor& input2);

Tensor mul(const Tensor& input1, const Tensor& input2);
Expand All @@ -55,6 +59,8 @@ Tensor upsample_nearest2d_vec(
c10::optional<IntArrayRef> output_size,
c10::optional<ArrayRef<double>> scale_factors);

Tensor flatten_using_ints(const Tensor & input, int64_t start_dim, int64_t end_dim);

Tensor copy_to_host(const Tensor& input);

} // namespace mpscnn
Expand Down
92 changes: 92 additions & 0 deletions aten/src/ATen/native/metal/mpscnn/MPSCNNOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,36 @@ Tensor neuronKernel(const Tensor& input, MPSCNNNeuron* neuron) {
return output;
}

// Applies `neuron` to `input` in place.
//
// The neuron is encoded on the command buffer associated with `input`,
// reading from the input's image and writing into a temporary image. The
// temporary image is then copied back into the texture owned by `input`, and
// `input` itself is returned.
//
// A 2-D input is described to the temporary image as {d0, d1, 1, 1}; any other
// rank is passed through unchanged.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& neuronKernel_(Tensor& input, MPSCNNNeuron* neuron) {
  MPSImage* X = imageFromTensor(input);
  const std::vector<int64_t> outputSize = input.sizes().vec();
  std::vector<int64_t> textureSize = outputSize;
  if (input.dim() == 2) {
    textureSize = {outputSize[0], outputSize[1], 1, 1};
  }
  MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
  // Size the temporary image from textureSize. Previously textureSize was
  // computed but never used, so the 2-D padding above had no effect and the
  // raw input sizes were passed instead.
  MPSImage* Y = [MPSImage temporaryImageFromSize:textureSize
                                   commandBuffer:commandBuffer];
  [neuron encodeToCommandBuffer:commandBuffer.buffer
                    sourceImage:X
               destinationImage:Y];
  // Write the result back into the storage that backs `input`.
  MetalTensorImpl* impl = (MetalTensorImpl*)input.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input;
}

// Out-of-place ReLU: runs the shared neuron kernel with the ReLU neuron
// obtained from MPSCNNNeuronOp.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor relu(const Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel(input, reluNeuron);
}

// In-place ReLU: runs the in-place neuron kernel with the ReLU neuron
// obtained from MPSCNNNeuronOp and returns the reference it yields.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& relu_(Tensor& input) {
  MPSCNNNeuron* reluNeuron = [MPSCNNNeuronOp relu];
  return neuronKernel_(input, reluNeuron);
}

API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sigmoid(const Tensor& input) {
return neuronKernel(input, [MPSCNNNeuronOp sigmoid]);
Expand Down Expand Up @@ -356,12 +381,50 @@ Tensor binaryElementwiseKernel(
return output;
}

// Runs a binary elementwise compute kernel on `input1` and `input2` and
// stores the result in `input1`.
//
// The pipeline state is looked up via kernelFor(X1, arrayKernel,
// nonarrayKernal). The kernel reads the two input textures (indices 0 and 1)
// and writes into a temporary image (index 2) sized like `input1`; that
// temporary image is then copied back into the texture owned by `input1`.
//
// Both inputs must be associated with the same command buffer; otherwise the
// TORCH_CHECK below fails. Returns `input1`.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& binaryElementwiseKernel_(
    Tensor& input1,
    const Tensor& input2,
    NSString* arrayKernel,
    NSString* nonarrayKernal) {
  MPSImage* X1 = imageFromTensor(input1);
  MPSImage* X2 = imageFromTensor(input2);
  std::vector<int64_t> outputSize = input1.sizes().vec();
  MetalCommandBuffer* cb1 = commandBufferFromInputTensor(input1);
  MetalCommandBuffer* cb2 = commandBufferFromInputTensor(input2);
  // The message is a C string literal rather than an NSString (@"...") so the
  // C++ message formatter behind TORCH_CHECK can print it as text.
  TORCH_CHECK([cb1 isEqual:cb2], "inputs have different command buffer");
  MPSImage* Y = [MPSImage temporaryImageFromSize:outputSize commandBuffer:cb1];
  id<MTLComputePipelineState> state = [[MPSCNNContext sharedInstance]
      pipelineState:kernelFor(X1, arrayKernel, nonarrayKernal)];
  id<MTLComputeCommandEncoder> encoder = [cb1.buffer computeCommandEncoder];
  [encoder setComputePipelineState:state];
  [encoder setTexture:[X1 texture] atIndex:0];
  [encoder setTexture:[X2 texture] atIndex:1];
  [encoder setTexture:[Y texture] atIndex:2];
  const auto& launchParams = spatialPointwiseKernelLaunchParams(state, Y);
  [encoder dispatchThreadgroups:launchParams.threadgroupsPerGrid
          threadsPerThreadgroup:launchParams.threadsPerThreadgroup];
  [encoder endEncoding];
  [X1 markRead];
  [X2 markRead];
  // Write the result back into the storage that backs `input1`.
  MetalTensorImpl* impl = (MetalTensorImpl*)input1.unsafeGetTensorImpl();
  MetalTensor& metalTensor = impl->unsafe_opaque_handle();
  metalTensor.texture()->copyFromTexture(Y);
  return input1;
}

// Out-of-place elementwise add: dispatches the shared binary kernel with the
// add kernel names for the array and non-array cases.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor add(const Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel(
      input1, input2, arrayKernelName, nonarrayKernelName);
}

// In-place elementwise add: dispatches the in-place binary kernel with the
// add kernel names for the array and non-array cases and returns the
// reference it yields.
API_AVAILABLE(ios(10.0), macos(10.13))
Tensor& add_(Tensor& input1, const Tensor& input2) {
  NSString* const arrayKernelName = @"elementwise_add";
  NSString* const nonarrayKernelName = @"elementwise_add_nonarray";
  return binaryElementwiseKernel_(
      input1, input2, arrayKernelName, nonarrayKernelName);
}

API_AVAILABLE(ios(10.0), macos(10.13))
Tensor sub(const Tensor& input1, const Tensor& input2) {
return binaryElementwiseKernel(
Expand Down Expand Up @@ -510,6 +573,35 @@ Tensor upsample_nearest2d_vec(
return output;
}

// Flattens the dimensions start_dim..end_dim (inclusive, after wrapping
// against input.dim()) of `input` into a single dimension via reshape.
//
// - Fails the TORCH_CHECK if start_dim ends up after end_dim.
// - A 0-dim input is reshaped to {1}.
// - If start_dim equals end_dim, `input` is returned as is.
Tensor flatten_using_ints(
    const Tensor& input,
    int64_t start_dim,
    int64_t end_dim) {
  start_dim = maybe_wrap_dim(start_dim, input.dim());
  end_dim = maybe_wrap_dim(end_dim, input.dim());
  TORCH_CHECK(
      start_dim <= end_dim,
      "flatten() has invalid args: start_dim cannot come after end_dim");

  if (input.dim() == 0) {
    return input.reshape({1});
  }
  if (start_dim == end_dim) {
    return input;
  }

  // Number of elements covered by the dimensions being merged.
  const int64_t merged_count = end_dim - start_dim + 1;
  auto merged_numel = prod_intlist(input.sizes().slice(start_dim, merged_count));

  std::vector<int64_t> flat_shape;
  flat_shape.reserve(input.dim() - end_dim + start_dim);
  // Leading dimensions are kept unchanged.
  for (int64_t d = 0; d < start_dim; ++d) {
    flat_shape.push_back(input.size(d));
  }
  // The merged range collapses into one dimension.
  flat_shape.push_back(merged_numel);
  // Trailing dimensions are kept unchanged.
  for (int64_t d = end_dim + 1; d < input.dim(); ++d) {
    flat_shape.push_back(input.size(d));
  }
  return input.reshape(flat_shape);
}

Tensor copy_to_host(const Tensor& input) {
MPSImage* X = imageFromTensor(input);
MetalCommandBuffer* commandBuffer = commandBufferFromInputTensor(input);
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/quantized/cpu/quant_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ inline TensorQuantizationParams ChooseQuantizationParams(
// to be a middle value between qmin and qmax.
// If either min or max is 0, then we just use 0 as zero_point.
if (min < 0 && max > 0 && preserve_sparsity) {
initial_zero_point = (qmin + qmax) / 2 + 1;
const auto midpoint = qmin + (qmax - qmin) / 2; // Overflow-safe midpoint
initial_zero_point = midpoint + 1;
}

// Now we need to nudge the zero point to be an integer
Expand Down
6 changes: 6 additions & 0 deletions benchmarks/operator_benchmark/benchmark_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def _build_test(configs, bench_op, OperatorTestCase, run_backward, op_name_funct
op._set_backward_test(run_backward)
op.init(**init_dict)

if not run_backward:
for _, attr in vars(op).items():
if isinstance(attr, torch.nn.Module):
for param in attr.parameters():
param.requires_grad = False

input_name = None

# _num_inputs_require_grads is used to track the number of tensors
Expand Down

0 comments on commit 3824fc7

Please sign in to comment.