[MPS] Fwd-fix for clamp regression (#122148)
Forward fix for regressions introduced by #121381, as we twice failed to run MPS CI on it.

- Do not call `minimumWithNaNPropagationWithPrimaryTensor` for integral tensors as it will crash with
  ```
    /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805: failed assertion `Error getting visible function: (null) Function isNaN_i16_i8 was not found in the library'
   ```
- Change the order of the max and min calls, as it is apparently important for
  consistency: `min(max(a, b), c)` need not equal `max(min(a, c), b)` when `c` is less than `b` (see the sketch below)
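
For illustration only (not part of the commit): a minimal plain C++ sketch with arbitrary values `a`, `b`, and `c`, showing that the two compositions disagree exactly when the upper bound `c` is smaller than the lower bound `b`, which is why the graph has to commit to one order to stay consistent with the other backends.

```cpp
// Illustrative sketch only (plain C++, values are arbitrary).
#include <algorithm>
#include <cstdio>

int main() {
  const int a = 5;   // input value
  const int b = 10;  // lower bound ("min")
  const int c = 3;   // upper bound ("max"), deliberately smaller than b

  // Lower bound applied first, upper bound last: yields c == 3.
  const int upper_last = std::min(std::max(a, b), c);
  // Upper bound applied first, lower bound last: yields b == 10.
  const int lower_last = std::max(std::min(a, c), b);

  // Prints 3 vs 10: the two compositions only agree when b <= c.
  std::printf("min(max(a,b),c) = %d, max(min(a,c),b) = %d\n", upper_last, lower_last);
  return 0;
}
```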

Pull Request resolved: #122148
Approved by: https://github.com/huydhn

(cherry picked from commit 34f36a2)
malfet authored and pytorchbot committed Apr 4, 2024
1 parent a8f93a5 commit b874749
Showing 1 changed file with 35 additions and 23 deletions.
aten/src/ATen/native/mps/operations/TensorCompare.mm (35 additions, 23 deletions):

```diff
@@ -30,41 +30,53 @@ static void clamp_mps_graph(CachedGraph* cachedGraph,
                             const Tensor& min_tensor,
                             const Tensor& max_tensor) {
   auto input_dtype = input_tensor.scalar_type();
-  auto min_dtype = input_dtype;
-  auto max_dtype = input_dtype;
-  if (cachedGraph->minTensor) {
-    min_dtype = min_tensor.scalar_type();
-  }
-  if (cachedGraph->maxTensor) {
-    max_dtype = max_tensor.scalar_type();
-  }
+  auto min_dtype = cachedGraph->minTensor ? min_tensor.scalar_type() : input_dtype;
+  auto max_dtype = cachedGraph->maxTensor ? max_tensor.scalar_type() : input_dtype;
 
   MPSGraph* mpsGraph = cachedGraph->graph();
 
   cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
-  MPSGraphTensor* minTensor = cachedGraph->minTensor;
-  MPSGraphTensor* maxTensor = cachedGraph->maxTensor;
+  auto minTensor = cachedGraph->minTensor;
+  auto maxTensor = cachedGraph->maxTensor;
 
   if (input_dtype != min_dtype) {
     minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype);
   }
   if (input_dtype != max_dtype) {
     maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype);
   }
-  if (cachedGraph->minTensor && cachedGraph->maxTensor) {
-    cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
-                                           minValueTensor:minTensor
-                                           maxValueTensor:maxTensor
-                                                     name:nil];
-  } else if (cachedGraph->maxTensor) {
-    cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:maxTensor
-                                                              name:nil];
-  } else if (cachedGraph->minTensor) {
-    cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
-                                                   secondaryTensor:minTensor
-                                                              name:nil];
+  if (c10::isIntegralType(input_dtype, /*includeBool=*/true)) {
+    if (minTensor && maxTensor) {
+      cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
+                                             minValueTensor:minTensor
+                                             maxValueTensor:maxTensor
+                                                       name:nil];
+    } else if (maxTensor) {
+      cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
+                                                     secondaryTensor:maxTensor
+                                                                name:nil];
+    } else if (minTensor) {
+      cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
+                                                     secondaryTensor:minTensor
+                                                                name:nil];
+    }
+    return;
   }
+  // clampWithTensor doesn't propagate NaN through so simulate it as composition of
+  // maximumWithNaNPropagationWithPrimaryTensor and minimumWithNaNPropagationWithPrimaryTensor
+  auto outputTensor = cachedGraph->inputTensor;
+  if (minTensor) {
+    outputTensor = [mpsGraph maximumWithNaNPropagationWithPrimaryTensor:outputTensor
+                                                        secondaryTensor:minTensor
+                                                                   name:nil];
+  }
+  if (maxTensor) {
+    outputTensor = [mpsGraph minimumWithNaNPropagationWithPrimaryTensor:outputTensor
+                                                        secondaryTensor:maxTensor
+                                                                   name:nil];
+  }
+  cachedGraph->outputTensor = outputTensor;
 }
 
 static void check_min_max_dims(const OptionalTensorRef clamp_opt, const Tensor& input_t, string op_name) {
```
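Below is a hypothetical smoke test, not part of this commit, sketching how the two code paths of the patched `clamp_mps_graph` could be exercised from the ATen C++ API. It assumes a libtorch build with MPS support and that `at::hasMPS()` and `at::kMPS` are available as in current ATen.

```cpp
// Hypothetical smoke test (not from the commit): exercises both branches of the
// patched clamp_mps_graph, assuming a macOS libtorch build with MPS support.
#include <ATen/ATen.h>
#include <limits>
#include <iostream>

int main() {
  if (!at::hasMPS()) {  // skip on machines without an MPS device
    std::cout << "MPS not available\n";
    return 0;
  }
  const auto mps = at::TensorOptions().device(at::kMPS);

  // Integral path: previously dispatched to the NaN-propagating min/max kernels,
  // which aborted with the "isNaN_i16_i8 was not found" assertion for int16.
  auto i = at::arange(6, mps.dtype(at::kShort));
  std::cout << i.clamp(1, 4).cpu() << "\n";  // expected: 1 1 2 3 4 4

  // Floating-point path: keeps the NaN-propagating composition, so NaN inputs
  // should still come out as NaN after clamping.
  auto f = at::full({3}, std::numeric_limits<float>::quiet_NaN(), mps);
  std::cout << f.clamp(0.0, 1.0).cpu() << "\n";  // expected: nan nan nan
  return 0;
}
```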
