Skip to content

Commit

Permalink
[MPS] Fix embedding cache key
Browse files Browse the repository at this point in the history
ghstack-source-id: 3ba9c3d6cc9cfa7977cbe43f412dad099aef5bad
Pull Request resolved: #101857
  • Loading branch information
qqaatw committed May 19, 2023
1 parent f994d0b commit 6b85dd2
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
5 changes: 2 additions & 3 deletions aten/src/ATen/native/mps/operations/Indexing.mm
Original file line number Diff line number Diff line change
Expand Up @@ -819,9 +819,8 @@ Tensor embedding_dense_backward_mps(const Tensor& grad_,
auto stream = at::mps::getCurrentMPSStream();

@autoreleasepool {
string key = "edb_mps:" + getMPSTypeString(grad_) + ":indices" + std::to_string(num_indices_dims) + ":num_weights" +
std::to_string(num_weights) + ":padding_idx" + std::to_string(padding_idx) + ":scaled" +
std::to_string(scale_grad_by_freq);
string key = "edb_mps:" + getTensorsStringKey({grad_, indices}) + ":num_weights" + std::to_string(num_weights) +
":padding_idx" + std::to_string(padding_idx) + ":scaled" + std::to_string(scale_grad_by_freq);
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* incomingGradTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(grad_));

Expand Down
1 change: 1 addition & 0 deletions test/test_mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6719,6 +6719,7 @@ def helper(n, d, m, idx):
self.assertEqual(a_CPU.grad, a_MPS.grad)

helper(3, 5, 7, [0, 1, 2])
helper(3, 6, 7, [0, 1, 2])
helper(3, 5, 7, 2) # test scalar index

# Test pytorch gather
Expand Down

0 comments on commit 6b85dd2

Please sign in to comment.