MPS: add linspace op (#78570)
Summary:
Fixes #ISSUE_NUMBER

Pull Request resolved: #78570
Approved by: https://github.com/malfet

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/a3bdafece3a07aea186e34abc28e2540aa078393

Reviewed By: seemethere

Differential Revision: D36815745

Pulled By: seemethere

fbshipit-source-id: 8a6338c34b10b219e76a23e1577a0142ccd4c7b3
kulinseth authored and facebook-github-bot committed Jun 2, 2022
1 parent c8139f5 commit 6027dbe
Showing 3 changed files with 119 additions and 0 deletions.
105 changes: 105 additions & 0 deletions aten/src/ATen/native/mps/operations/RangeFactories.mm
@@ -63,4 +63,109 @@

return result;
}

Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {
using namespace mps;
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor *startTensor_ = nil;
MPSGraphTensor *endTensor_ = nil;
MPSGraphTensor *multiplyTensor_ = nil;
MPSGraphTensor *outputTensor_ = nil;
};

TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
if (result.numel() != steps) {
result.resize_({steps});
}

if (steps == 0) {
    // skip: empty result, nothing to compute
} else if (steps == 1) {
result.fill_(start);
} else {
Tensor r = result.is_contiguous() ? result : result.contiguous();

// Do the MPSGraph computation
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
MPSStream* stream = getCurrentMPSStream();

@autoreleasepool {
string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" + to_string(steps);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

CachedGraph *newCachedGraph = nil;

@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

int shapeVal[1] = {(int32_t)steps};
MPSGraphTensor *shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shapeVal length:sizeof(int32_t)]
shape: @[@1]
dataType:MPSDataTypeInt32];
MPSGraphTensor* coordsTensor = [mpsGraph coordinateAlongAxis:0
withShapeTensor:shapeTensor
name:nil];
coordsTensor = [mpsGraph castTensor:coordsTensor toType:MPSDataTypeFloat32 name:@"coords"];

MPSGraphTensor* startTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]);
MPSGraphTensor* endTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]);
MPSGraphTensor* multiplyTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeFloat32, @[@1]);
MPSGraphTensor* scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor
secondaryTensor:multiplyTensor
name:nil];
MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords
secondaryTensor:startTensor
name:nil];
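          // Clamp so floating-point rounding cannot push values past the endpoints;
          // pick min/max according to whether the range ascends or descends.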
if(start.to<double>() <= end.to<double>()) {
outputTensor = [mpsGraph clampWithTensor:outputTensor
minValueTensor:startTensor
maxValueTensor:endTensor
name:nil];
} else {
outputTensor = [mpsGraph clampWithTensor:outputTensor
minValueTensor:endTensor
maxValueTensor:startTensor
name:nil];
}

if(getMPSDataType(result.scalar_type()) != MPSDataTypeFloat32) {
outputTensor = [mpsGraph castTensor:outputTensor toType:getMPSDataType(result.scalar_type()) name:@"output"];
}

newCachedGraph->startTensor_ = startTensor;
newCachedGraph->endTensor_ = endTensor;
newCachedGraph->multiplyTensor_ = multiplyTensor;
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
      auto multiplyScalar = (end.to<double>() - start.to<double>()) / ((double)steps - 1.0); // spacing between consecutive samples
Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, r);

// Create dictionary of inputs and outputs
feeds[cachedGraph->startTensor_] = getMPSGraphTensorFromScalar(stream, start, MPSDataTypeFloat32);
feeds[cachedGraph->endTensor_] = getMPSGraphTensorFromScalar(stream, end, MPSDataTypeFloat32);
feeds[cachedGraph->multiplyTensor_] = getMPSGraphTensorFromScalar(stream, Scalar(multiplyScalar), MPSDataTypeFloat32);

NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
}

if (!result.is_contiguous()) {
result.copy_(r);
}
}
return result;
}
}} // namespace at::native
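
For reference, the graph above computes out[i] = start + i * (end - start) / (steps - 1) and then clamps the result to the closed interval between start and end so rounding cannot overshoot the endpoints. A minimal Python sketch of the equivalent computation (illustrative only; linspace_sketch is not part of this change):

import torch

def linspace_sketch(start: float, end: float, steps: int) -> torch.Tensor:
    # Mirrors the kernel's special cases for empty and single-element outputs.
    if steps == 0:
        return torch.empty(0)
    if steps == 1:
        return torch.full((1,), float(start))
    coords = torch.arange(steps, dtype=torch.float32)  # 0, 1, ..., steps - 1
    step = (end - start) / (steps - 1)                 # spacing between samples
    out = coords * step + start                        # multiply-add, as in the graph
    lo, hi = (start, end) if start <= end else (end, start)
    return out.clamp(lo, hi)                           # guard against FP overshoot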
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2782,6 +2782,7 @@
dispatch:
CPU, Meta: linspace_out
CUDA: linspace_cuda_out
MPS: linspace_out_mps

- func: log(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
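With the dispatch entry in place, linspace calls targeting an MPS device route to linspace_out_mps. A quick usage sketch (assumes a PyTorch build with MPS support on Apple silicon):

import torch

x = torch.linspace(0, 10, steps=5, device='mps')
# expected: tensor([ 0.0000,  2.5000,  5.0000,  7.5000, 10.0000], device='mps:0')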
13 changes: 13 additions & 0 deletions test/test_mps.py
@@ -3711,6 +3711,19 @@ def helper(shape, diag=0):
for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
helper(shape, diag=diag)

# Test linspace
def test_linspace(self):
def helper(start, end, steps, dtype=torch.float32):
cpu_result = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
result = torch.linspace(start, end, steps, dtype=dtype, device='mps')
self.assertEqual(cpu_result, result)

for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
helper(2, 5, 10, dtype)
helper(2, 2, 10, dtype)
helper(5, 2, 10, dtype)
helper(2, 2, 0, dtype)

# Test softmax
def test_softmax(self):
def helper(shape, dim, channels_last=False):
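The new test compares the MPS result against NumPy's linspace materialized on CPU, covering ascending (2 to 5), constant (2 to 2), and descending (5 to 2) ranges plus steps == 0, across float32, int32, uint8, and int64 dtypes.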
