Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MPS: add linspace op #78570

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 109 additions & 0 deletions aten/src/ATen/native/mps/operations/RangeFactories.mm
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,113 @@

return result;
}

// Fills `result` with `steps` evenly spaced values from `start` to `end`,
// computed on the GPU via MPSGraph. Mirrors the CPU/CUDA linspace semantics:
// steps == 0 produces an empty tensor, steps == 1 produces [start].
Tensor& linspace_out_mps(const Scalar& start, const Scalar& end, int64_t steps, Tensor& result) {

  using namespace mps;

  // Cached-graph record: linspace has no graph inputs, only one output.
  struct CachedGraph : public MPSCachedGraph
  {
    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor *outputTensor_ = nil;
  };

  TORCH_CHECK(steps >= 0, "number of steps must be non-negative");
  if (result.numel() != steps) {
    result.resize_({steps});
  }

  if (steps == 0) {
    // Empty output: nothing to fill.
  } else if (steps == 1) {
    result.fill_(start);
  } else {
    // Run the graph into a contiguous buffer; copy back afterwards if needed.
    Tensor r = result.is_contiguous() ? result : result.contiguous();

    MPSGraphCache* cache_ = MPSGraphCache::getInstance();
    MPSStream* stream = getCurrentMPSStream();

    @autoreleasepool {
      // The key must capture everything baked into the graph as constants:
      // output shape/dtype (via the tensor key) plus both scalar endpoints.
      // A ":" separator keeps distinct (start, end) pairs from colliding.
      string key = "linspace_out_mps:" + getTensorsStringKey({result}) + ":" +
                   to_string(start.to<double>()) + ":" + to_string(end.to<double>());
      CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));

      if(!cachedGraph) {
        MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {

          CachedGraph *newCachedGraph = nil;

          @autoreleasepool {
            MPSGraph* mpsGraph = make_mps_graph();
            newCachedGraph = new CachedGraph(mpsGraph);

            // 1-D shape tensor holding `steps`; the element type must match
            // the byte layout handed to constantWithData (UInt32 here).
            uint32_t shapeVal[1] = {static_cast<uint32_t>(steps)};
            MPSGraphTensor *shapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:shapeVal length:sizeof(shapeVal)]
                                                               shape:@[@1]
                                                            dataType:MPSDataTypeUInt32];

            // Indices 0 .. steps-1 along axis 0, cast to float for arithmetic.
            MPSGraphTensor* coordsTensor = [mpsGraph coordinateAlongAxis:0
                                                         withShapeTensor:shapeTensor
                                                                    name:nil];
            coordsTensor = [mpsGraph castTensor:coordsTensor toType:MPSDataTypeFloat32 name:@"coords"];

            // value[i] = start + i * (end - start) / (steps - 1)
            auto multiplyScalar = (end.to<double>() - start.to<double>()) / (static_cast<double>(steps) - 1.0);
            MPSGraphTensor* startTensor = [mpsGraph constantWithScalar:start.to<double>()
                                                              dataType:MPSDataTypeFloat32];
            MPSGraphTensor* endTensor = [mpsGraph constantWithScalar:end.to<double>()
                                                            dataType:MPSDataTypeFloat32];
            MPSGraphTensor* multiplyTensor = [mpsGraph constantWithScalar:multiplyScalar
                                                                 dataType:MPSDataTypeFloat32];

            MPSGraphTensor* scaledCoords = [mpsGraph multiplicationWithPrimaryTensor:coordsTensor
                                                                     secondaryTensor:multiplyTensor
                                                                                name:nil];
            MPSGraphTensor* outputTensor = [mpsGraph additionWithPrimaryTensor:scaledCoords
                                                               secondaryTensor:startTensor
                                                                          name:nil];
            // Clamp to [min(start,end), max(start,end)] so float rounding
            // cannot push the last element past `end` (or before it when
            // the sequence is descending).
            if(start.to<double>() <= end.to<double>())
              outputTensor = [mpsGraph clampWithTensor:outputTensor
                                        minValueTensor:startTensor
                                        maxValueTensor:endTensor
                                                  name:nil];
            else
              outputTensor = [mpsGraph clampWithTensor:outputTensor
                                        minValueTensor:endTensor
                                        maxValueTensor:startTensor
                                                  name:nil];

            // Graph math is done in float32; cast to the requested dtype.
            if(getMPSDataType(result.scalar_type()) != MPSDataTypeFloat32)
              outputTensor = [mpsGraph castTensor:outputTensor toType:getMPSDataType(result.scalar_type()) name:@"output"];

            newCachedGraph->outputTensor_ = outputTensor;
          }
          return newCachedGraph;
        });
        cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
      }

      Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, r);

      // No graph inputs; only the output placeholder is bound.
      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;

      NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
        outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
      };

      runMPSGraph(stream, cachedGraph->graph(), feeds, results);

    }

    // If we computed into a temporary contiguous buffer, copy it back.
    if (!result.is_contiguous()) {
      result.copy_(r);
    }
  }

  return result;

}

}} // namespace at::native
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2782,6 +2782,7 @@
dispatch:
CPU, Meta: linspace_out
CUDA: linspace_cuda_out
MPS: linspace_out_mps

- func: log(Tensor self) -> Tensor
device_check: NoCheck # TensorIterator
Expand Down
13 changes: 13 additions & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3711,6 +3711,19 @@ def helper(shape, diag=0):
for diag in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
helper(shape, diag=diag)

# Test linspace
def test_linspace(self):
    def helper(start, end, steps, dtype=torch.float32):
        # Reference computed with NumPy on the host, then cast to the target dtype.
        expected = torch.tensor(np.linspace(start, end, steps), dtype=dtype)
        actual = torch.linspace(start, end, steps, dtype=dtype, device='mps')
        self.assertEqual(expected, actual)

    for dtype in [torch.float32, torch.int32, torch.uint8, torch.int64]:
        helper(2, 5, 10, dtype)   # ascending sequence
        helper(2, 2, 10, dtype)   # constant sequence (start == end)
        helper(5, 2, 10, dtype)   # descending sequence
        helper(2, 2, 0, dtype)    # empty output (steps == 0)

# Test softmax
def test_softmax(self):
def helper(shape, dim, channels_last=False):
Expand Down