Skip to content

Commit

Permalink
[MPS] Fix embedding cache key
Browse files Browse the repository at this point in the history
ghstack-source-id: d73319d91d3ff66d49d5781dd2e65fefaaed1b73
Pull Request resolved: #101857
  • Loading branch information
qqaatw committed May 19, 2023
1 parent f994d0b commit d4fd6d4
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions aten/src/ATen/native/mps/operations/Indexing.mm
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,8 @@ Tensor embedding_dense_backward_mps(const Tensor& grad_,
auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {
string key = "edb_mps:" + getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" +
std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" +
std::to_string(scale_grad_by_freq);
string key = "edb_mps:" + getTensorsStringKey({grad_, indices}) + ":num_weights" + std::to_string(num_weights) +
":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* incomingGradTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_));

Expand Down

0 comments on commit d4fd6d4

Please sign in to comment.