Disable autocast cache for tensor views as fix for #48049 (#48696)

Summary:
Fixes #48049

The root cause of the issue is explained [here](#48049 (comment)).

This PR implements albanD's suggestion to add the `!t.is_view()` check and disable autocast caching for views of tensors.
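
For intuition, a minimal sketch of the situation the check targets (illustrative only, not the reproduction from the issue; it assumes a CUDA device, and the local names `linear` and `view` are made up here — on its `addmm` path, `nn.Linear` works with a transposed view of its weight):

```python
import torch

# Illustrative sketch: each forward pass takes a fresh transposed view of the
# weight, so a cast cache keyed per tensor can keep gaining entries; the new
# `!arg.is_view()` condition makes autocast skip caching for such views.
linear = torch.nn.Linear(10, 10).cuda()

with torch.cuda.amp.autocast(), torch.no_grad():
    view = linear.weight.t()   # a brand-new view tensor every time it is taken
    print(view._is_view())                    # True -> excluded from the cast cache now
    print(view.requires_grad, view.is_leaf)   # the attributes the old heuristic checked
```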

The added test checks for growing memory usage by comparing the memory allocated after the first forward pass with the memory allocated after 3 further iterations of a single `nn.Linear` layer inside a `no_grad` and `autocast` context.

With this PR, the memory usage in the original issue's reproduction no longer grows and yields:
```
autocast: True
0: 0MB (peak 1165MB)
1: 0MB (peak 1264MB)
2: 0MB (peak 1265MB)
3: 0MB (peak 1265MB)
4: 0MB (peak 1265MB)
5: 0MB (peak 1265MB)
6: 0MB (peak 1265MB)
7: 0MB (peak 1265MB)
8: 0MB (peak 1265MB)
9: 0MB (peak 1265MB)
```
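
For context, a rough sketch of the kind of measurement loop behind the listing above (the actual script lives in #48049; the layer size, batch size, iteration count, and MB rounding below are illustrative assumptions, and a CUDA device is required):

```python
import torch

# Illustrative loop: run a linear layer under autocast + no_grad and report
# how allocated / peak CUDA memory evolve per iteration.
model = torch.nn.Linear(1024, 1024).cuda()
data = torch.randn(64, 1024, device='cuda')
base = torch.cuda.memory_allocated()

print("autocast: True")
for i in range(10):
    with torch.cuda.amp.autocast(), torch.no_grad():
        out = model(data)
    alloc_mb = (torch.cuda.memory_allocated() - base) // 1024 ** 2
    peak_mb = torch.cuda.max_memory_allocated() // 1024 ** 2
    print(f"{i}: {alloc_mb}MB (peak {peak_mb}MB)")
```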

CC ngimel mcarilli

Pull Request resolved: #48696

Reviewed By: bdhirsh

Differential Revision: D25276231

Pulled By: ngimel

fbshipit-source-id: e2571e9f166c0a6f6f569b0c28e8b9ca34132743
pbialecki authored and facebook-github-bot committed Dec 3, 2020
1 parent 0e4f9a7 commit 22c3ae8
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -68,7 +68,7 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg) {
   if (is_eligible(arg) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache fp16 casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf());
+    bool can_try_cache = (to_type == at::kHalf && arg.scalar_type() == at::kFloat && arg.requires_grad() && arg.is_leaf() && !arg.is_view());
     if (can_try_cache) {
       auto it = cached_casts.find(arg.unsafeGetTensorImpl());
       if (it != cached_casts.end()) {
16 changes: 16 additions & 0 deletions test/test_cuda.py
@@ -2838,6 +2838,22 @@ def test_autocast_rnn(self):
                 for grad, grad_control in zip(grads, grads_control):
                     self.assertEqual(grad.half(), grad_control)

+    def test_autocast_cache_leak(self):
+        # Reported at https://github.com/pytorch/pytorch/issues/48049
+        # The test checks that autocast does not recache the same parameters
+        # when executed inside a `torch.no_grad()` block.
+
+        linear = torch.nn.Linear(10, 10).to('cuda')
+        data = torch.randn(1, 10, device='cuda')
+
+        with torch.cuda.amp.autocast():
+            with torch.no_grad():
+                out = linear(data)
+                first_iter_mem = torch.cuda.memory_allocated()
+                for _ in range(3):
+                    out = linear(data)
+                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())
+
     @slowTest
     @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
     def test_max_large_axis(self):
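
The added test can be run on its own against a CUDA-enabled build, e.g. with something like `pytest test/test_cuda.py -k test_autocast_cache_leak` (any equivalent unittest invocation should work as well).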
