torch.utils.checkpoint.checkpoint + torch.cuda.amp #49757

Closed · wants to merge 1 commit

12 changes: 12 additions & 0 deletions test/test_cuda.py
@@ -16,6 +16,7 @@
import torch.cuda.comm as comm
from torch import multiprocessing as mp
from torch.nn.parallel import scatter_gather
from torch.utils.checkpoint import checkpoint_sequential
from torch._six import inf, nan, container_abcs

from test_torch import AbstractTestCases
@@ -2882,6 +2883,17 @@ def test_autocast_cache_leak(self):
                    out = linear(data)
                self.assertTrue(first_iter_mem == torch.cuda.memory_allocated())

    def test_autocast_checkpointing(self):
        model = torch.nn.Sequential(torch.nn.Linear(8, 8),
                                    torch.nn.Linear(8, 8),
                                    torch.nn.Linear(8, 8)).cuda()
        input = torch.rand((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)
        with torch.cuda.amp.autocast():
            output = checkpoint_sequential(model, 2, input)
        self.assertTrue(output.requires_grad)
        self.assertTrue(output.dtype is torch.float16)
        output.sum().backward()

    @slowTest
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    def test_max_large_axis(self):
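
For reference, a minimal standalone sketch of what the new test exercises: calling checkpoint_sequential inside a torch.cuda.amp.autocast region and backpropagating through the result. The layer sizes and tensor shapes below are illustrative only and assume a CUDA-capable device; they are not part of the patch.

import torch
from torch.utils.checkpoint import checkpoint_sequential

# Illustrative sizes; any Sequential of autocast-eligible modules works.
model = torch.nn.Sequential(torch.nn.Linear(8, 8),
                            torch.nn.Linear(8, 8),
                            torch.nn.Linear(8, 8)).cuda()
inp = torch.rand((8, 8), device="cuda", dtype=torch.float16, requires_grad=True)

with torch.cuda.amp.autocast():
    # Each checkpointed segment drops its activations and recomputes them in
    # backward; with this patch the recomputation replays the autocast state
    # that was active during the original forward pass.
    out = checkpoint_sequential(model, 2, inp)

out.sum().backward()  # backward recomputes the segments under the same autocast state
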
3 changes: 2 additions & 1 deletion torch/utils/checkpoint.py
@@ -59,6 +59,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
@@ -91,7 +92,7 @@ def backward(ctx, *args):
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
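
To make the checkpoint.py change easier to follow, here is a simplified, self-contained sketch of the mechanism it implements: the forward pass records whether autocast was enabled, and the backward pass re-enters autocast with that recorded state while recomputing under enable_grad, so the recomputed ops produce the same dtypes as the original forward. The names below (_AutocastAwareCheckpoint, run_fn) are illustrative only, and the RNG-state handling of the real CheckpointFunction is omitted.

import torch

class _AutocastAwareCheckpoint(torch.autograd.Function):
    """Simplified, illustrative stand-in for CheckpointFunction's autocast handling."""

    @staticmethod
    def forward(ctx, run_fn, inp):
        ctx.run_fn = run_fn
        # Record the autocast state that was active during the real forward.
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        ctx.save_for_backward(inp)
        with torch.no_grad():
            return run_fn(inp)

    @staticmethod
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        detached = inp.detach().requires_grad_(inp.requires_grad)
        # Recompute with grad enabled and the same autocast state as the
        # original forward, so the recomputed graph has matching dtypes.
        with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
            out = ctx.run_fn(detached)
        torch.autograd.backward(out, grad_out)
        # No gradient for run_fn; the input gradient comes from the recomputation.
        return None, detached.grad

Without re-entering autocast, the recomputation would run in full precision and could produce dtypes that differ from the original autocast forward; recording ctx.had_autocast_in_fwd and replaying it in backward is what keeps the two passes consistent.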