Skip to content

Commit

Permalink
Fix for MPS regression in #122016
Browse files Browse the repository at this point in the history
  • Loading branch information
jhavukainen authored and pytorchmergebot committed Apr 3, 2024
1 parent 74b3a79 commit 16eb447
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions aten/src/ATen/native/mps/operations/ConstantOps.mm
Expand Up @@ -28,43 +28,32 @@

// Cache entry used by the surrounding graph lookup: extends MPSCachedGraph
// with two tensor handles that the graph-building callback stores and that
// are read back when the graph is run.
struct CachedGraph : public MPSCachedGraph {
  // Hands the graph to the MPSCachedGraph base; both handles below stay nil
  // until they are assigned.
  CachedGraph(MPSGraph* mpsGraph) : MPSCachedGraph(mpsGraph) {}

  // Tensor used as the key of the feeds dictionary when the graph is run.
  MPSGraphTensor* inputTensor_ = nil;

  // Tensor wrapped by the output Placeholder when the graph is run.
  MPSGraphTensor* outputTensor_ = nil;
};

@autoreleasepool {
string key = "fill_scalar_mps_impl" + getTensorsStringKey(self) + ":" + to_string(value.toDouble());

auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
auto isBool = self.scalar_type() == c10::ScalarType::Bool;
auto isUInt8 = self.scalar_type() == c10::ScalarType::Byte;
auto dataType = !isUInt8 ? !isBool ? getMPSScalarType(self.scalar_type()) : MPSDataTypeInt8 : MPSDataTypeUInt32;
// constantWithScalar does not work for bool types on MacOS-12.[34];
// work around it by filling an int8 tensor and then casting it to bool.
// See https://github.com/pytorch/pytorch/issues/82427
// constantWithScalar does not work for UInt8 types on MacOS-12.[34]/Ventura preview;
// work around it by filling a uint32 tensor and then casting it to uint8.
// See https://github.com/pytorch/pytorch/issues/83692
MPSGraphTensor* inputTensor = [mpsGraph constantWithScalar:value.toDouble()
shape:getMPSShape(self)
dataType:dataType];
MPSGraphTensor* inputTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()));
MPSGraphTensor* outputTensor = [mpsGraph identityWithTensor:inputTensor name:nil];
if (isBool) {
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"constWithBool-workaround"];
}
if (isUInt8) {
outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeUInt8 name:@"constWithUInt8-workaround"];
}

newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
});

auto mpsScalar = getMPSScalar(value, self.scalar_type());
auto mpsScalarData = getMPSGraphTensorFromScalar(getCurrentMPSStream(), mpsScalar);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
@{cachedGraph->inputTensor_ : mpsScalarData};

Placeholder outputPlaceholder =
Placeholder(cachedGraph->outputTensor_, needsCopyToOutput ? output : self, nullptr, !needsCopyToOutput);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results =
@{outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()};

runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), /*feeds*/ nil, results);
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);

if (needsCopyToOutput) {
self.copy_(output);
Expand Down

0 comments on commit 16eb447

Please sign in to comment.