Skip to content

Commit

Permalink
[MPS] Squeeze last dimensions if possible for 5D (or bigger) reductions (#99856)
Browse files Browse the repository at this point in the history

Summary of changes:
- Reduction ops optimization - squeeze all dimensions after 4th dim if they are all 1
- Disable type inference only for 1D unary ops
Pull Request resolved: #99856
Approved by: https://github.com/kulinseth
  • Loading branch information
DenisVieriu97 authored and pytorchmergebot committed Apr 25, 2023
1 parent 87a2af6 commit cf21240
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ static inline std::string getMPSTypeString(const Tensor& t, bool short_name = fa
}
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false);
std::string getArrayRefString(const IntArrayRef s);
Expand Down
15 changes: 11 additions & 4 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,23 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
}
}

NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
int64_t ndim = t.dim();
// Builds the full axes list [0, 1, ..., ndim-1] as NSNumbers, i.e. every
// dimension of a rank-`ndim` tensor, in the form MPSGraph reduction APIs expect.
NSArray<NSNumber*>* getTensorAxes(int64_t ndim) {
  NSMutableArray<NSNumber*>* allAxes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
  for (int64_t axis = 0; axis < ndim; axis++) {
    [allAxes addObject:@(axis)];
  }
  return allAxes;
}

NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) {
// Convenience overload: the all-axes list for a tensor, derived from its rank.
NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
  const int64_t rank = t.dim();
  return getTensorAxes(rank);
}

// Convenience overload: the all-axes list for a shape, derived from its length.
NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes) {
  const auto rank = static_cast<int64_t>(sizes.size());
  return getTensorAxes(rank);
}

NSArray<NSNumber*>* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim) {
if (dim.has_value() && dim.value().size() != 0) {
IntArrayRef dimValues = dim.value();
int ndim = dimValues.size();
Expand All @@ -180,7 +187,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
return axes;
}

return getTensorAxes(t);
return getTensorAxes(sizes);
}

std::string getMPSShapeString(MPSShape* shape) {
Expand Down
42 changes: 30 additions & 12 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void set_apparent_shapes(NSMutableArray<NSNumber*>*& apparent_out_shape,
NSMutableArray<NSNumber*>*& apparent_in_shape,
int64_t num_reduce_dims,
int64_t num_output_dims,
IntArrayRef& input_shape,
const IntArrayRef& input_shape,
NSMutableArray<NSNumber*>*& axes) {
if (num_reduce_dims == 0) {
/* Output shape becomes a one
Expand Down Expand Up @@ -108,14 +108,12 @@ void set_axes(NSMutableArray<NSNumber*>*& axes,
}

// Helper function to prepare axes and tensor shapes
void set_axes_and_shapes(const Tensor& input_t,
void set_axes_and_shapes(const IntArrayRef& input_shape,
OptionalIntArrayRef opt_dims,
NSMutableArray<NSNumber*>*& axes,
NSMutableArray<NSNumber*>*& apparent_input_shape,
NSMutableArray<NSNumber*>*& apparent_output_shape,
NSMutableArray<NSNumber*>*& output_shape) {
IntArrayRef input_shape = input_t.sizes();

int64_t num_input_dims = input_shape.size();
int64_t num_reduce_dims = opt_dims.has_value() ? opt_dims.value().size() : 0;
int64_t num_output_dims;
Expand Down Expand Up @@ -146,25 +144,44 @@ void reduction_out_mps(const Tensor& input_t,
const std::string& func_name) {
bool macOS13_3_plus = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
MPS_CHECK_INT64_OP_SUPPORTED(input_t, macOS13_3_plus, func_name);

auto input_shape = input_t.sizes();
bool canSqueezeLastDim = true;
IntArrayRef input_shape = input_t.sizes();
if (opt_dim.has_value()) {
IntArrayRef dim = opt_dim.value();
for (const auto dim_val : dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
if (wrap_dim >= 4) {
canSqueezeLastDim = false;
}
TORCH_CHECK(
wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size() == 0 ? input_t.numel() : input_shape.size()),
func_name + ": reduction dim must be in the range of input shape")
}
}

if (input_shape.size() >= 5 && canSqueezeLastDim) {
for (const auto i : c10::irange(4, input_shape.size())) {
if (input_shape[i] != 1) {
canSqueezeLastDim = false;
}
}
} else {
canSqueezeLastDim = false;
}

MPSShape* mpsShape = getMPSShape(input_t);
if (canSqueezeLastDim) {
mpsShape = @[ @(input_shape[0]), @(input_shape[1]), @(input_shape[2]), @(input_shape[3]) ];
input_shape = makeArrayRef(input_shape.begin(), input_shape.end() - (input_t.dim() - 4));
}

NSMutableArray<NSNumber*>* axes = nil;
NSMutableArray<NSNumber*>* apparent_input_shape = nil;
NSMutableArray<NSNumber*>* apparent_output_shape = nil;
NSMutableArray<NSNumber*>* output_shape = nil;

set_axes_and_shapes(input_t, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_t, opt_dim);
set_axes_and_shapes(input_shape, opt_dim, axes, apparent_input_shape, apparent_output_shape, output_shape);
NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_shape, opt_dim);

if (output_t.numel() == 0 || input_t.numel() == 0) {
if (reduction_type == MPSReductionType::PROD) {
Expand All @@ -185,7 +202,8 @@ void reduction_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputScalarType = input_t.scalar_type();

MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), mpsShape);
MPSGraphTensor* castInputTensor = inputTensor;
MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
Expand Down Expand Up @@ -250,7 +268,7 @@ void reduction_out_mps(const Tensor& input_t,
newCachedGraph->outputTensor_ = outputTensor;
});

auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, mpsShape);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
Expand Down Expand Up @@ -309,7 +327,7 @@ void impl_func_norm_mps(const Tensor& input_tensor,

set_apparent_shapes(apparent_output_shape, apparent_input_shape, num_reduce_dims, num_output_dims, input_shape, axes);

NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_t, dim);
NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_shape, dim);
if (cdist) {
apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
Expand Down Expand Up @@ -426,7 +444,7 @@ Tensor std_var_common_impl_mps(const Tensor& input_t,
const auto correction_value = correction.value_or(1.0).toDouble();
int64_t correction_n = 1;

NSArray<NSNumber*>* wrappedAxes = getTensorAxes(input_t, dim);
NSArray<NSNumber*>* wrappedAxes = getTensorAxes(input_t.sizes(), dim);

int64_t num_output_dims = 0;
NSMutableArray<NSNumber*>* axes = nil;
Expand Down

0 comments on commit cf21240

Please sign in to comment.