[MPS] Add fmax fmin op #95191

Closed · wants to merge 2 commits

2 changes: 2 additions & 0 deletions aten/src/ATen/native/mps/OperationUtils.h
@@ -1,5 +1,7 @@
// Copyright © 2022 Apple Inc.

#pragma once
Collaborator (Author) commented:
Guards possible duplicate includes in the future.

Collaborator commented:
Looks good.


#include <ATen/ATen.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
199 changes: 199 additions & 0 deletions aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -0,0 +1,199 @@
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/BinaryOps.h>

namespace at::native {
namespace mps {

static const char* METAL_BINARY = R"BINARY_METAL(

#include <metal_stdlib>
using namespace metal;

template<typename T>
kernel void fmax(constant void * input_ [[buffer(0)]],
constant void * other_ [[buffer(1)]],
device void * out_ [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);

*out = fmax(*input, *other);
}

template<typename T>
kernel void fmin(constant void * input_ [[buffer(0)]],
constant void * other_ [[buffer(1)]],
device void * out_ [[buffer(2)]],
constant uint3 * offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);

*out = fmin(*input, *other);
}

#define REGISTER_FMAX_OP(DTYPE) \
template \
[[host_name("fmax_" #DTYPE)]] \
kernel void fmax<DTYPE>( \
constant void * input_ [[buffer(0)]], \
constant void * other_ [[buffer(1)]], \
device void * out_ [[buffer(2)]], \
constant uint3 * offsets [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);

#define REGISTER_FMIN_OP(DTYPE) \
template \
[[host_name("fmin_" #DTYPE)]] \
kernel void fmin<DTYPE>( \
constant void * input_ [[buffer(0)]], \
constant void * other_ [[buffer(1)]], \
device void * out_ [[buffer(2)]], \
constant uint3 * offsets [[buffer(3)]], \
uint tid [[thread_position_in_grid]]);

REGISTER_FMAX_OP(float);
REGISTER_FMAX_OP(half);
REGISTER_FMIN_OP(float);
REGISTER_FMIN_OP(half);

)BINARY_METAL";

using namespace mps;

static id<MTLLibrary> compileBinaryOpsLibrary(id<MTLDevice> device) {
static id<MTLLibrary> binaryLibrary = nil;
if (binaryLibrary) {
return binaryLibrary;
}

NSError *error = nil;
MTLCompileOptions *options = [[MTLCompileOptions new] autorelease];
[options setLanguageVersion: MTLLanguageVersion2_3];
binaryLibrary = [device newLibraryWithSource:[NSString stringWithCString: METAL_BINARY encoding:NSASCIIStringEncoding]
options:options
error:&error];
TORCH_CHECK(binaryLibrary, "Failed to create metal binary library, error: ", [[error description] UTF8String]);
return binaryLibrary;
}

static id<MTLComputePipelineState> binaryPipelineState(id<MTLDevice> device, const std::string& kernel) {
static std::unordered_map<std::string, id<MTLComputePipelineState>> psoCache;
id<MTLComputePipelineState> pso = psoCache[kernel];
if (pso) {
return pso;
}

NSError* error = nil;
id<MTLLibrary> binaryLib = compileBinaryOpsLibrary(device);
id<MTLFunction> binaryFunc = [binaryLib newFunctionWithName:[NSString stringWithUTF8String:kernel.c_str()]];
TORCH_CHECK(binaryFunc, "Failed to create function state object for: ", kernel);
pso = [device newComputePipelineStateWithFunction:binaryFunc error:&error];
TORCH_CHECK(pso, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);

psoCache[kernel] = pso;
return pso;
}

void fmax_fmin_mps_impl(TensorIteratorBase& iter, const std::string max_min) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
Contributor commented:
Hi,
May I know why the double data type is asserted against here? Is it because the Apple GPU doesn't support double precision?
By the way, why don't the unary ops have such an assertion?
Thank you.

Collaborator (Author) commented:
Yes, it doesn't support double precision AFAIK. As for why the unary ops don't have such an assertion, I think it was simply missed. But in general, you're not allowed to create an fp64 MPS tensor, nor to type-cast or move an fp64 tensor to MPS.
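
For illustration, a minimal Python-level sketch of that behavior (assumes an MPS-capable PyTorch build; not part of this diff, and the exact exception type may vary by version):

import torch

# Creating an fp64 tensor directly on MPS is rejected.
try:
    torch.zeros(2, dtype=torch.float64, device="mps")
except (TypeError, RuntimeError) as e:
    print(e)  # "... MPS framework doesn't support float64 ..."

# Moving an existing fp64 tensor to MPS is rejected as well;
# casting to float32 first is the supported path.
x = torch.zeros(2, dtype=torch.float64)
try:
    x.to("mps")
except (TypeError, RuntimeError) as e:
    print(e)

x.to(torch.float32).to("mps")  # works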

Contributor commented:
Thank you for the help. I found the code below:

TORCH_CHECK_TYPE(false,
"Cannot convert a float64 Tensor to MPS as the MPS framework doesn't support float64. "
"Please use float32 instead.")

and
static const std::string& getMetalType(const c10::ScalarType& t) {
// Mapping from c10::ScalarType to integral type that can be used for bitwise ops
// As bitwise ops sign-agnostic map signed/unsigned char and boolean to the same type
static std::unordered_map<c10::ScalarType, std::string> scalar_to_metal_type = {
{c10::ScalarType::Long, "long"},
{c10::ScalarType::Int, "int"},
{c10::ScalarType::Short, "short"},
{c10::ScalarType::Byte, "char"},
{c10::ScalarType::Char, "char"},
{c10::ScalarType::Bool, "char"},
};
auto it = scalar_to_metal_type.find(t);
TORCH_CHECK(it != scalar_to_metal_type.end(), "Unsupported type ", t);

Any FP64 data type will throw an error in these MPS utilities.

By the way, the code here is used to register the supported data types for the fmax kernel, right? Does that mean only the half and float32 kernels will be precompiled ahead of time?

REGISTER_FMAX_OP(float);
REGISTER_FMAX_OP(half);

Thank you.
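
For context, only the float and half kernels are registered in the Metal source, and the wrapper functions later in this file (fmax_mps_kernel / fmin_mps_kernel) route non-floating dtypes to maximum_out / minimum_out, since fmax and maximum only differ in NaN handling. A rough Python-level illustration (assumes an MPS-capable build; the outputs in comments are expected values, not captured runs):

import torch

# Floating dtypes take the custom Metal kernel path: fmax/fmin ignore a NaN
# operand when the other operand is a number.
a = torch.tensor([1.0, float("nan"), 3.0], device="mps")
b = torch.tensor([2.0, 2.0, float("nan")], device="mps")
print(torch.fmax(a, b))  # expected: [2., 2., 3.]
print(torch.fmin(a, b))  # expected: [1., 2., 3.]

# Integer dtypes cannot hold NaN, so the wrapper simply falls back to
# maximum_out / minimum_out instead of dispatching the Metal kernel.
ai = torch.tensor([1, 5, 3], device="mps", dtype=torch.int32)
bi = torch.tensor([4, 2, 3], device="mps", dtype=torch.int32)
print(torch.fmax(ai, bi))  # expected: [4, 5, 3]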


Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output(0);
id<MTLBuffer> inputBuffer = getMTLBufferStorage(input);
id<MTLBuffer> otherBuffer = getMTLBufferStorage(other);
id<MTLBuffer> outputBuffer = getMTLBufferStorage(out);
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
dispatch_sync(mpsStream->queue(), ^(){
@autoreleasepool {
NSError* error = nil;
id<MTLCommandBuffer> commandBuffer = mpsStream->commandBuffer();
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
MTLSize gridSize = MTLSizeMake(numThreads, 1, 1);
const IntArrayRef& iterShape = iter.shape();
std::vector<uint32_t> iterShapeData(iterShape.size());
std::vector<std::array<uint32_t, nOffsets>> strides(nDim);

for (const auto i: c10::irange(iterShape.size())) {
TORCH_CHECK(i <= UINT32_MAX);
iterShapeData[i] = (uint32_t)(iterShape[i]);
}

for (const auto i: c10::irange(nDim)) {
for (const auto offset: c10::irange(nOffsets)) {
strides[i][offset] = iter.strides(offset)[i];
}
}

id<MTLFunction> kernelDataOffsetsFunction = MPSDevice::getInstance()->metalIndexingFunction("kernel_index_offsets", nil);
id<MTLComputePipelineState> kernelDataOffsetsPSO = [[device newComputePipelineStateWithFunction: kernelDataOffsetsFunction
error: &error] autorelease];
id<MTLBuffer> kernelDataOffsets = [[device newBufferWithLength: numThreads * sizeof(simd_uint3)
options: 0] autorelease];
TORCH_CHECK(kernelDataOffsetsPSO, "Failed to created pipeline state object, error: ", [[error description] UTF8String]);
[computeEncoder setComputePipelineState:kernelDataOffsetsPSO];
[computeEncoder setBytes:strides.data() length:sizeof(uint32_t) * nDim * nOffsets atIndex:0];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:1];
[computeEncoder setBytes:iterShapeData.data() length:sizeof(uint32_t) * iterShape.size() atIndex:2];
[computeEncoder setBytes:&nDim length:sizeof(uint32_t) atIndex:3];
[computeEncoder setBytes:&nOffsets length:sizeof(uint32_t) atIndex:4];

NSUInteger kernelOffsetsTGSize = kernelDataOffsetsPSO.maxTotalThreadsPerThreadgroup;
if (kernelOffsetsTGSize > numThreads)
kernelOffsetsTGSize = numThreads;

MTLSize kernelOffsetsThreadGroupSize = MTLSizeMake(kernelOffsetsTGSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: kernelOffsetsThreadGroupSize];

const std::string kernel = "f" + max_min + "_" + scalarToMetalTypeString(out.scalar_type());
id<MTLComputePipelineState> fmaxfminPSO = binaryPipelineState(device, kernel);
[computeEncoder setComputePipelineState:fmaxfminPSO];
[computeEncoder setBuffer:inputBuffer offset:input.storage_offset() * input.element_size() atIndex:0];
[computeEncoder setBuffer:otherBuffer offset:other.storage_offset() * other.element_size() atIndex:1];
[computeEncoder setBuffer:outputBuffer offset:out.storage_offset() * out.element_size() atIndex:2];
[computeEncoder setBuffer:kernelDataOffsets offset:0 atIndex:3];

NSUInteger tgSize = fmaxfminPSO.maxTotalThreadsPerThreadgroup;
if (tgSize > numThreads) {
tgSize = numThreads;
}

MTLSize threadGroupSize = MTLSizeMake(tgSize, 1, 1);
[computeEncoder dispatchThreads: gridSize
threadsPerThreadgroup: threadGroupSize];

[computeEncoder endEncoding];
mpsStream->commit(true);
}
});
}
} // namespace mps

void fmax_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::fmax_fmin_mps_impl(iter, "max");
} else {
at::maximum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}
void fmin_mps_kernel(TensorIteratorBase& iter) {
if (isFloatingType(iter.common_dtype())) {
mps::fmax_fmin_mps_impl(iter, "min");
} else {
at::minimum_out(const_cast<Tensor&>(iter.output()), iter.input(0), iter.input(1));
}
}

REGISTER_DISPATCH(fmax_stub, &fmax_mps_kernel);
REGISTER_DISPATCH(fmin_stub, &fmin_mps_kernel);

} // namespace at::native
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -9297,7 +9297,7 @@
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: fmin_out
CPU, CUDA, MPS: fmin_out
tags: pointwise

- func: max(Tensor self) -> Tensor
@@ -9319,7 +9319,7 @@
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: fmax_out
CPU, CUDA, MPS: fmax_out
tags: pointwise

- func: maximum(Tensor self, Tensor other) -> Tensor
4 changes: 4 additions & 0 deletions test/test_mps.py
@@ -9245,6 +9245,8 @@ class TestConsistency(TestCaseMPS):
'float': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'floor': ['f32', 'f16', 'i16', 'i32', 'i64'],
'floor_divide': ['f32', 'f16'],
'fmax': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
'fmin': ['b8', 'f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
'fmod': ['f32', 'f16', 'i16', 'i32', 'i64', 'u8'],
'frac': ['f16', 'f32'],
'gather': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9505,6 +9507,8 @@ class TestConsistency(TestCaseMPS):
'flipud': ['f16', 'f32'],
'float': ['f32'],
'floor': ['f32'],
'fmax': ['f16', 'f32'],
'fmin': ['f16', 'f32'],
'gradient': ['f32'],
'half': ['f16'],
'hstack': ['f16', 'f32'],