[MPS] Fwd-fix for clamp regression
Forward fix for the regression introduced by
#121381

- Do not call `minimumWithNaNPropagationWithPrimaryTensor` for integer-typed
  tensors, as it crashes with the following assertion (a minimal repro sketch
  follows this list):
  ```
    /AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Utility/MPSKernelDAG.mm:805: failed assertion `Error getting visible function:
 (null) Function isNaN_i16_i8 was not found in the library'
   ```
- Change the order of the max and min calls, as the order matters for
  consistency: `min(max(a, b), c)` need not equal `max(min(a, c), b)` when the
  lower bound `b` can exceed the upper bound `c` (see the worked example below)
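
A minimal user-level sketch of the crash (a hypothetical repro: the `int16` dtype is an assumption inferred from the `isNaN_i16_i8` symbol in the assertion; any integer dtype on an MPS build without this fix should trigger it):

```python
import torch

# Hypothetical repro, assuming an MPS build prior to this fix: clamp on an
# integer-typed tensor was routed through the NaN-propagating min/max kernels,
# which are not compiled for integer types.
x = torch.arange(10, dtype=torch.int16, device="mps")
y = x.clamp(min=2, max=5)  # hit the MPSKernelDAG assertion before this fix
print(y)  # tensor([2, 2, 2, 3, 4, 5, 5, 5, 5, 5], ...) after the fix
```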
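A worked example of the ordering claim (plain Python, with values chosen so the lower bound exceeds the upper bound):

```python
a, b, c = 0, 5, 2  # input a, lower bound b, upper bound c; note b > c

print(min(max(a, b), c))  # 2: the upper bound, applied last, wins
print(max(min(a, c), b))  # 5: the lower bound, applied last, wins
```

Since `torch.clamp` defines the `min > max` case to return `max`, applying the maximum (lower bound) first and the minimum (upper bound) last matches that convention.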
malfet committed Mar 19, 2024
1 parent e6cf3e9 commit e9f105e
Showing 1 changed file with 26 additions and 15 deletions.
41 changes: 26 additions & 15 deletions aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -30,41 +30,52 @@ static void clamp_mps_graph(CachedGraph* cachedGraph,
                             const Tensor& min_tensor,
                             const Tensor& max_tensor) {
   auto input_dtype = input_tensor.scalar_type();
-  auto min_dtype = input_dtype;
-  auto max_dtype = input_dtype;
-  if (cachedGraph->minTensor) {
-    min_dtype = min_tensor.scalar_type();
-  }
-  if (cachedGraph->maxTensor) {
-    max_dtype = max_tensor.scalar_type();
-  }
+  auto min_dtype = cachedGraph->minTensor ? min_tensor.scalar_type() : input_dtype;
+  auto max_dtype = cachedGraph->maxTensor ? max_tensor.scalar_type() : input_dtype;
 
   MPSGraph* mpsGraph = cachedGraph->graph();
 
   cachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
   auto minTensor = cachedGraph->minTensor;
   auto maxTensor = cachedGraph->maxTensor;
-  auto outputTensor = cachedGraph->inputTensor;
 
   if (input_dtype != min_dtype) {
     minTensor = castMPSTensor(mpsGraph, cachedGraph->minTensor, input_dtype);
   }
   if (input_dtype != max_dtype) {
     maxTensor = castMPSTensor(mpsGraph, cachedGraph->maxTensor, input_dtype);
   }
-  // clampWithTensor doesn't propagate NaN through so simulate it as composition of
-  // minimumWithNaNPropagationWithPrimaryTensor and maximumWithNaNPropagationWithPrimaryTensor
-  if (maxTensor) {
-    outputTensor = [mpsGraph minimumWithNaNPropagationWithPrimaryTensor:outputTensor
-                                                        secondaryTensor:maxTensor
-                                                                   name:nil];
+  if (c10::isIntegralType(input_dtype, /*includeBool=*/true)) {
+    if (minTensor && maxTensor) {
+      cachedGraph->outputTensor = [mpsGraph clampWithTensor:cachedGraph->inputTensor
+                                             minValueTensor:minTensor
+                                             maxValueTensor:maxTensor
+                                                       name:nil];
+    } else if (maxTensor) {
+      cachedGraph->outputTensor = [mpsGraph minimumWithPrimaryTensor:cachedGraph->inputTensor
+                                                     secondaryTensor:maxTensor
+                                                                name:nil];
+    } else if (minTensor) {
+      cachedGraph->outputTensor = [mpsGraph maximumWithPrimaryTensor:cachedGraph->inputTensor
+                                                     secondaryTensor:minTensor
+                                                                name:nil];
+    }
+    return;
   }
+  // clampWithTensor doesn't propagate NaN through so simulate it as composition of
+  // maximumWithNaNPropagationWithPrimaryTensor and minimumWithNaNPropagationWithPrimaryTensor
+  auto outputTensor = cachedGraph->inputTensor;
   if (minTensor) {
     outputTensor = [mpsGraph maximumWithNaNPropagationWithPrimaryTensor:outputTensor
                                                         secondaryTensor:minTensor
                                                                    name:nil];
   }
+  if (maxTensor) {
+    outputTensor = [mpsGraph minimumWithNaNPropagationWithPrimaryTensor:outputTensor
+                                                        secondaryTensor:maxTensor
+                                                                   name:nil];
+  }
   cachedGraph->outputTensor = outputTensor;
 }

