Skip to content

Commit

Permalink
[MPS] Fix median_out_mps caching (#90326)
Browse files Browse the repository at this point in the history
We should cache the graph based on the input tensor's type (and shape), not just its scalar type, so that lookups with differently-shaped inputs do not reuse a stale cached graph.

Fixes #90311

Pull Request resolved: #90326
Approved by: https://github.com/kulinseth
  • Loading branch information
malfet authored and pytorchmergebot committed Dec 7, 2022
1 parent 85ae28b commit b0bd5c4
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions aten/src/ATen/native/mps/operations/ReduceOps.mm
Original file line number Diff line number Diff line change
Expand Up @@ -1809,11 +1809,11 @@ Tensor median_mps(const Tensor& input_t) {
auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {
string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type());
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);

if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ native_mps::MPSCachedGraph * () {

CachedGraph *newCachedGraph = nil;

Expand Down Expand Up @@ -1849,7 +1849,6 @@ Tensor median_mps(const Tensor& input_t) {
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
Expand Down

0 comments on commit b0bd5c4

Please sign in to comment.