Skip to content

Commit

Permalink
Fix lintrunner
Browse files Browse the repository at this point in the history
  • Loading branch information
DenisVieriu97 committed Apr 24, 2023
1 parent ad4a8cd commit b497efc
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {

NSArray<NSNumber*>* getTensorAxes(int64_t ndim) {
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:i];
}
return axes;
Expand All @@ -180,7 +180,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
IntArrayRef dimValues = dim.value();
int ndim = dimValues.size();
auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
for (const auto i: c10::irange(ndim)) {
for (const auto i : c10::irange(ndim)) {
axes[i] = [NSNumber numberWithInteger:dimValues[i]];
}

Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void reduction_out_mps(const Tensor& input_t,
}

if (input_shape.size() >= 5 && canSqueezeLastDim) {
for (const auto i: c10::irange(4, input_shape.size())) {
for (const auto i : c10::irange(4, input_shape.size())) {
if (input_shape[i] != 1) {
canSqueezeLastDim = false;
}
Expand All @@ -171,11 +171,10 @@ void reduction_out_mps(const Tensor& input_t,

MPSShape* mpsShape = getMPSShape(input_t);
if (canSqueezeLastDim) {
mpsShape = @[@(input_shape[0]), @(input_shape[1]), @(input_shape[2]), @(input_shape[3])];
mpsShape = @[ @(input_shape[0]), @(input_shape[1]), @(input_shape[2]), @(input_shape[3]) ];
input_shape = makeArrayRef(input_shape.begin(), input_shape.end() - (input_t.dim() - 4));
}


NSMutableArray<NSNumber*>* axes = nil;
NSMutableArray<NSNumber*>* apparent_input_shape = nil;
NSMutableArray<NSNumber*>* apparent_output_shape = nil;
Expand Down Expand Up @@ -203,7 +202,8 @@ void reduction_out_mps(const Tensor& input_t,
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto inputScalarType = input_t.scalar_type();

MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), mpsShape);
MPSGraphTensor* inputTensor =
mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()), mpsShape);
MPSGraphTensor* castInputTensor = inputTensor;
MPSDataType inputCastType = MPSDataTypeInvalid;
if (dtype.has_value() &&
Expand Down

0 comments on commit b497efc

Please sign in to comment.