From b0bd5c4508a8923685965614c2c74e6a8c82f7ba Mon Sep 17 00:00:00 2001
From: Nikita Shulga
Date: Wed, 7 Dec 2022 07:24:55 +0000
Subject: [PATCH] [MPS] Fix median_out_mps caching (#90326)

Cache the graph based on the input tensor's full key (dtype and shape),
not just its scalar type.

Fixes https://github.com/pytorch/pytorch/issues/90311
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90326
Approved by: https://github.com/kulinseth
---
 aten/src/ATen/native/mps/operations/ReduceOps.mm | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index d905107b8ffd4..8a321ffd2fb12 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -1809,11 +1809,11 @@ Tensor median_mps(const Tensor& input_t) {
   auto stream = at::mps::getCurrentMPSStream();
 
   @autoreleasepool {
-    string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getMPSTypeString(input_t.scalar_type());
-    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+    string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t);
+    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
     if(!cachedGraph) {
-      native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+      cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ native_mps::MPSCachedGraph * () {
         CachedGraph *newCachedGraph = nil;
 
         @autoreleasepool {
@@ -1849,7 +1849,6 @@ Tensor median_mps(const Tensor& input_t) {
         }
         return newCachedGraph;
       });
-      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
 
     auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
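
Why the key matters: getMPSTypeString() encodes only the dtype, so two
same-dtype tensors with different shapes produced identical cache keys, and
the graph compiled for the first shape was silently reused for the second.
getTensorsStringKey() folds the tensor's shape into the key as well. Below is
a minimal standalone C++ sketch of that failure mode; it is not PyTorch code,
and the Graph struct, g_cache map, and key helpers are hypothetical stand-ins
for MPSGraphCache and its key builders.

// Minimal sketch of the caching bug, assuming a string-keyed graph cache.
// Graph, g_cache, and the key helpers are hypothetical stand-ins.
#include <cstdint>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

struct Graph {
  std::vector<int64_t> shape;  // the shape this graph was compiled for
};

std::map<std::string, Graph> g_cache;

// Buggy key: dtype only, as with getMPSTypeString(input_t.scalar_type()).
// Two same-dtype tensors with different shapes collide on this key.
std::string key_by_dtype(const std::string& dtype) {
  return "median:" + dtype;
}

// Fixed key: dtype plus shape, the information getTensorsStringKey() encodes.
std::string key_by_dtype_and_shape(const std::string& dtype,
                                   const std::vector<int64_t>& shape) {
  std::ostringstream oss;
  oss << "median:" << dtype;
  for (int64_t d : shape) oss << ":" << d;
  return oss.str();
}

// Compile on miss, reuse on hit -- the lookup/create pattern in the patch.
Graph& lookup_or_create(const std::string& key,
                        const std::vector<int64_t>& shape) {
  auto it = g_cache.find(key);
  if (it == g_cache.end()) {
    it = g_cache.emplace(key, Graph{shape}).first;
  }
  return it->second;
}

int main() {
  // First call compiles a graph for a float tensor of shape [2, 3].
  lookup_or_create(key_by_dtype("float32"), {2, 3});
  // Second call, float tensor of shape [5]: the dtype-only key silently
  // reuses the [2, 3] graph -- the stale-graph reuse behind issue #90311.
  Graph& stale = lookup_or_create(key_by_dtype("float32"), {5});
  std::cout << "dtype-only key reused graph compiled for rank-"
            << stale.shape.size() << " input\n";  // prints rank-2

  // With the shape folded into the key, each shape gets its own graph.
  Graph& fresh = lookup_or_create(key_by_dtype_and_shape("float32", {5}), {5});
  std::cout << "dtype+shape key compiled graph for rank-"
            << fresh.shape.size() << " input\n";  // prints rank-1
  return 0;
}

Under this model, each distinct input shape compiles its own graph instead of
reusing one built for another shape, which is the effect of switching the key
to getTensorsStringKey(input_t) in the patch.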