Fix lintrunner
DenisVieriu97 committed Apr 24, 2023
1 parent ce5e38e commit eeac863
Showing 4 changed files with 42 additions and 34 deletions.
5 changes: 3 additions & 2 deletions aten/src/ATen/mps/MPSDevice.mm
@@ -37,7 +37,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
   NSError* error = nil;
   if (!_mtl_indexing_library) {
     MTLCompileOptions* options = [MTLCompileOptions new];
-    [options setLanguageVersion: getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
+    [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
     [options setFastMathEnabled:YES];
     _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders
                                                                                  encoding:NSASCIIStringEncoding]
@@ -58,7 +58,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
     return state;
   }

-  id<MTLFunction> indexFunction = [[indexing_lib newFunctionWithName: [NSString stringWithUTF8String: kernel.c_str()]] autorelease];
+  id<MTLFunction> indexFunction =
+      [[indexing_lib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]] autorelease];
   TORCH_CHECK(indexFunction,
               "Failed to create specialized function state object: ",
               kernel,
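For context, the code these hunks reformat is PyTorch's runtime compilation of the MPS indexing shaders: a Metal library is built from embedded source, and kernels are then fetched from it by name. Below is a minimal standalone sketch of that pattern — the shader source, language-version constant, and helper name are placeholders for illustration, not the actual PyTorch values:

#import <Metal/Metal.h>

// Placeholder source; the real code embeds mps::indexing_metal_shaders.
static const char* kShaderSource = "kernel void kernel_index_offsets() {}";

// Compile a Metal library from source and look up one kernel by name.
// Assumes manual reference counting, as in the surrounding code.
static id<MTLFunction> compileKernel(id<MTLDevice> device, const char* name) {
  NSError* error = nil;
  MTLCompileOptions* options = [MTLCompileOptions new];
  [options setLanguageVersion:MTLLanguageVersion2_3]; // chosen per-OS by getMetalLanguageVersion()
  [options setFastMathEnabled:YES];
  id<MTLLibrary> lib = [device newLibraryWithSource:[NSString stringWithCString:kShaderSource
                                                                        encoding:NSASCIIStringEncoding]
                                            options:options
                                              error:&error];
  if (!lib) {
    return nil; // the real code surfaces [error description] via TORCH_CHECK
  }
  return [[lib newFunctionWithName:[NSString stringWithUTF8String:name]] autorelease];
}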
8 changes: 4 additions & 4 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -197,10 +197,10 @@ void binary_mps_impl(TensorIteratorBase& iter, const std::string func_name) {
     }
   }

-  id<MTLComputePipelineState> kernelDataOffsetsPSO =
-      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
-  id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
-                                                         options: 0] autorelease];
+  id<MTLComputePipelineState> kernelDataOffsetsPSO =
+      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
+  id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
+                                                         options:0] autorelease];
   TORCH_CHECK(
       kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
   [computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
10 changes: 6 additions & 4 deletions aten/src/ATen/native/mps/operations/CrossKernel.mm
@@ -155,10 +155,12 @@ void cross_mps_impl(const Tensor& out, const Tensor& input, const Tensor& other,
     }
   }

-  id<MTLComputePipelineState> kernelDataOffsetsPSO = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
-  id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
-                                                         options: 0] autorelease];
-  TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+  id<MTLComputePipelineState> kernelDataOffsetsPSO =
+      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
+  id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
+                                                         options:0] autorelease];
+  TORCH_CHECK(
+      kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
   [computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
   [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
   [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
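Note the pattern shared by BinaryKernel.mm and CrossKernel.mm above: a transient buffer holding one simd_uint3 of offsets per thread, allocated on the fly and bound at buffer index 1 for the kernel_index_offsets kernel. A minimal sketch of that allocation, assuming a valid device and compute encoder are in scope (illustrative helper name; manual reference counting as in the surrounding code):

#import <Metal/Metal.h>
#include <simd/simd.h>

// One simd_uint3 of per-thread offsets; bound where kernel_index_offsets
// expects it (buffer index 1 in the hunks above).
static id<MTLBuffer> makeOffsetsBuffer(id<MTLDevice> device,
                                       id<MTLComputeCommandEncoder> computeEncoder,
                                       NSUInteger numThreads) {
  id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength:numThreads * sizeof(simd_uint3)
                                                         options:0] autorelease];
  [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
  return kernelDataOffsets;
}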
53 changes: 29 additions & 24 deletions aten/src/ATen/native/mps/operations/Indexing.mm
@@ -7,14 +7,14 @@
 #include <ATen/MemoryOverlap.h>
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/ceil_div.h>
+#include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/native/IndexKernel.h>
 #include <ATen/native/IndexingUtils.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorAdvancedIndexing.h>
 #include <ATen/native/mps/MPSGraphVenturaOps.h>
 #include <ATen/native/mps/operations/Indexing.h>
-#include <ATen/mps/MPSAllocatorInterface.h>
 #include <c10/core/QScheme.h>
 #include <c10/util/SmallVector.h>
 #include <c10/util/irange.h>
@@ -78,9 +78,12 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,

   MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
   id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
-  id<MTLComputePipelineState> kernelDataOffsetsPSO = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
-  id<MTLBuffer> kernelDataOffsets = (id<MTLBuffer>)getIMPSAllocator()->allocate(numThreads * sizeof(simd_uint3)).get();
-  TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+  id<MTLComputePipelineState> kernelDataOffsetsPSO =
+      MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets");
+  id<MTLBuffer> kernelDataOffsets =
+      (id<MTLBuffer>)getIMPSAllocator()->allocate(numThreads * sizeof(simd_uint3)).get();
+  TORCH_CHECK(
+      kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

   [computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
   [computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
@@ -91,12 +94,11 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,

   NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
   if (kernelOffsetsTGSize > numThreads) {
-      kernelOffsetsTGSize = numThreads;
+    kernelOffsetsTGSize = numThreads;
   }

   MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
-  [computeEncoder dispatchThreads: gridSize
-            threadsPerThreadgroup: kernelOffsetsThreadGroupSize];
+  [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:kernelOffsetsThreadGroupSize];

   std::string indexFunction = getIndexFunctionName(inputTensor.scalar_type(), index_select, accumulate);
   id<MTLComputePipelineState> indexSelectPSO = nil;
@@ -108,34 +110,36 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
     indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
     uint64_t* indexABContents = (uint64_t*)(indexAB.contents);
     for (uint32_t idx = 0; idx < num_indices; idx++) {
-      const Tensor& indexTensor = iter.tensor(idx+2);
-      indexABContents[idx] = getMTLBufferStorage(indexTensor).gpuAddress + (indexTensor.storage_offset() * indexTensor.element_size());
+      const Tensor& indexTensor = iter.tensor(idx + 2);
+      indexABContents[idx] =
+          getMTLBufferStorage(indexTensor).gpuAddress + (indexTensor.storage_offset() * indexTensor.element_size());
       TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
       [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
     }
-  }
-  else
+  } else
 #endif
   {
     id<MTLLibrary> lib = MPSDevice::getInstance()->getMetalIndexingLibrary();
-    id<MTLFunction> indexKernelFunction = [[lib newFunctionWithName: [NSString stringWithUTF8String: indexFunction.c_str()]] autorelease];
-    id<MTLArgumentEncoder> argumentEncoder = [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
+    id<MTLFunction> indexKernelFunction =
+        [[lib newFunctionWithName:[NSString stringWithUTF8String:indexFunction.c_str()]] autorelease];
+    id<MTLArgumentEncoder> argumentEncoder =
+        [[indexKernelFunction newArgumentEncoderWithBufferIndex:0] autorelease];
     NSUInteger argumentBufferLength = argumentEncoder.encodedLength;
     indexAB = [[device newBufferWithLength:argumentBufferLength options:0] autorelease];
     [argumentEncoder setArgumentBuffer:indexAB offset:0];

     for (uint32_t idx = 0; idx < num_indices; idx++) {
-      const Tensor& indexTensor = iter.tensor(idx+2);
-      [argumentEncoder setBuffer: getMTLBufferStorage(indexTensor)
-                          offset: indexTensor.storage_offset() * indexTensor.element_size()
-                         atIndex: idx];
+      const Tensor& indexTensor = iter.tensor(idx + 2);
+      [argumentEncoder setBuffer:getMTLBufferStorage(indexTensor)
+                          offset:indexTensor.storage_offset() * indexTensor.element_size()
+                         atIndex:idx];
       TORCH_CHECK(indexTensor.scalar_type() == ScalarType::Long, "index(): Expected dtype int64 for Index");
       [computeEncoder useResource:getMTLBufferStorage(indexTensor) usage:MTLResourceUsageRead];
     }

-    indexSelectPSO = [[device newComputePipelineStateWithFunction: indexKernelFunction
-                                                            error: &error] autorelease];
-    TORCH_CHECK(indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
+    indexSelectPSO = [[device newComputePipelineStateWithFunction:indexKernelFunction error:&error] autorelease];
+    TORCH_CHECK(
+        indexSelectPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
   }

   [computeEncoder setComputePipelineState:indexSelectPSO];
@@ -144,17 +148,18 @@ static bool dispatchIndexKernel(TensorIteratorBase& iter,
   [computeEncoder setBytes:index_stride.data() length:sizeof(index_stride[0]) * index_stride.size() atIndex:2];
   [computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];
   [computeEncoder setBuffer:inputBuffer offset:inputTensor.storage_offset() * inputTensor.element_size() atIndex:4];
-  [computeEncoder setBuffer:outputBuffer offset:outputTensor.storage_offset() * outputTensor.element_size() atIndex:5];
+  [computeEncoder setBuffer:outputBuffer
+                     offset:outputTensor.storage_offset() * outputTensor.element_size()
+                    atIndex:5];
   [computeEncoder setBytes:&num_indices length:sizeof(uint32_t) atIndex:6];

   NSUInteger tgSize = indexSelectPSO.maxTotalThreadsPerThreadgroup;
   if (tgSize > numThreads) {
-      tgSize = numThreads;
+    tgSize = numThreads;
   }

   MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
-  [computeEncoder dispatchThreads: gridSize
-            threadsPerThreadgroup: threadGroupSize];
+  [computeEncoder dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
 }
});

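The Indexing.mm hunks also show the dispatch idiom used throughout these kernels: clamp the threadgroup size to the grid, since maxTotalThreadsPerThreadgroup can exceed the number of threads actually needed, then dispatch in a single call. A minimal sketch of that idiom, assuming the pipeline state and thread count are already computed (illustrative helper name):

#import <Metal/Metal.h>

// Dispatch `numThreads` threads against a pre-built pipeline state,
// never requesting a threadgroup larger than the grid itself.
static void dispatchClamped(id<MTLComputeCommandEncoder> encoder,
                            id<MTLComputePipelineState> pso,
                            NSUInteger numThreads) {
  [encoder setComputePipelineState:pso];
  NSUInteger tgSize = pso.maxTotalThreadsPerThreadgroup;
  if (tgSize > numThreads) {
    tgSize = numThreads;
  }
  [encoder dispatchThreads:MTLSizeMake(numThreads, 1, 1)
     threadsPerThreadgroup:MTLSizeMake(tgSize, 1, 1)];
}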
