test/test_operations.py: 36 additions & 0 deletions
@@ -32,6 +32,7 @@
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.distributed.data_parallel as dp
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.debug.metrics as met
import torch_xla.debug.model_comparator as mc
@@ -2334,6 +2335,41 @@ def test_aten_move_scalar_cuda_to_xla(self):
    self._test_move_tensor_cuda_to_xla(torch.tensor(42))


class SimpleModelWithDropout(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.x = torch.nn.Linear(128, 128)
    self.dropout = torch.nn.Dropout(p=0.1)
    self.to_save = []

  def save_output(self, output):
    self.to_save.append(output.detach().cpu())

  def forward(self, inp):
    x = self.x(inp)
    output = self.dropout(x)
    xm.add_step_closure(self.save_output, args=(output,), run_async=False)
    return output


class TestActivationCheckpoint(test_utils.XlaTestCase):

  def test_dropout(self):
    device = xm.xla_device()
    model = SimpleModelWithDropout().to(device)
    model = checkpoint_module(model)
    _input = torch.randn(128, 128, requires_grad=True)
    _input = _input.to(device)
    output = model(_input)
    output = torch.sum(output)
    output.backward()
    xm.mark_step()
    same_output = torch.allclose(model.to_save[0], model.to_save[1])
    self.assertTrue(same_output,
                    f"in fwd {model.to_save[0]}, in bwd {model.to_save[1]}")


if __name__ == '__main__':
  torch.set_default_dtype(torch.float32)
  torch.manual_seed(42)
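The assertion above holds because restoring the XLA seed replays the same dropout mask when the checkpointed forward is recomputed during backward. A standalone sketch of that property (illustrative only, not part of the PR; shapes and the dropout rate are arbitrary):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
drop = torch.nn.Dropout(p=0.5)  # stays in training mode, so the mask is random
x = torch.ones(4, 4, device=device)

seed = xm.get_rng_state()  # snapshot the XLA device seed
first = drop(x)
xm.set_rng_state(seed)     # rewind the seed, as the patched backward now does
second = drop(x)
xm.mark_step()

assert torch.allclose(first.cpu(), second.cpu())  # identical dropout masks
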
torch_xla/core/xla_model.py: 23 additions & 0 deletions
@@ -1,3 +1,4 @@
import contextlib
import io
import itertools
import logging
@@ -1263,6 +1264,28 @@ def get_rng_state(device=None):
  return torch_xla._XLAC._xla_get_rng_seed(str(device) if device else '')


@contextlib.contextmanager
def fork_rng(device=None, enabled=True):
  """
  Forks the RNG, so that when you return, the RNG is reset to the state that it was previously in.
  Args:
    device (string, optional): The device where the RNG state needs to be set. If missing, the default device seed will be set.
    enabled (bool): if ``False``, the RNG is not forked. This is a convenience argument for easily disabling the context manager without having to delete it and unindent your Python code under it.
  """
  if not enabled:
    yield
    return

  if device is None:
    device = torch_xla._XLAC._xla_get_default_device()
  xla_rng_state = get_rng_state(device=device)

  try:
    yield
  finally:
    set_rng_state(xla_rng_state, device=device)


def get_memory_info(device):
"""Retrieves the device memory information.

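The new helper mirrors torch.random.fork_rng, but for the XLA device seed: whatever the code inside the with block does to the RNG, the previous seed is restored on exit. A minimal usage sketch (illustrative, not part of the PR; it assumes a live XLA device):

import torch_xla.core.xla_model as xm

original = xm.get_rng_state()
with xm.fork_rng():
  xm.set_rng_state(1234)  # perturb the XLA seed inside the fork
assert xm.get_rng_state() == original  # previous seed restored on exit

# enabled=False turns the context manager into a no-op, so changes persist.
with xm.fork_rng(enabled=False):
  xm.set_rng_state(5678)
assert xm.get_rng_state() == 5678
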
torch_xla/utils/checkpoint.py: 14 additions & 9 deletions
@@ -88,6 +88,7 @@ def forward(ctx, run_function, preserve_rng_state, *args):
         "cache_enabled": torch.is_autocast_cache_enabled()
     }
     if preserve_rng_state:
+      ctx.fwd_xla_state = xm.get_rng_state()
       ctx.fwd_cpu_state = torch.get_rng_state()
       # Don't eagerly initialize the cuda context by accident.
       # (If the user intends that the context is initialized later, within their
@@ -143,17 +144,21 @@ def backward(ctx, *args):
       rng_devices = ctx.fwd_gpu_devices
     xm.optimization_barrier_(
         CheckpointFunction._extract_tensors_from_list(inputs + list(args)))
+    # torch.random.fork_rng will handle the cpu and gpu seed
+    # xm.fork_rng will handle the xla device seed
     with torch.random.fork_rng(
         devices=rng_devices, enabled=ctx.preserve_rng_state):
-      if ctx.preserve_rng_state:
-        torch.set_rng_state(ctx.fwd_cpu_state)
-        if ctx.had_cuda_in_fwd:
-          set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
-      detached_inputs = detach_variable(tuple(inputs))
-      with torch.enable_grad(), \
-          torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
-          torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
-        outputs = ctx.run_function(*detached_inputs)
+      with xm.fork_rng():
+        if ctx.preserve_rng_state:
+          xm.set_rng_state(ctx.fwd_xla_state)
+          torch.set_rng_state(ctx.fwd_cpu_state)
+          if ctx.had_cuda_in_fwd:
+            set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
+        detached_inputs = detach_variable(tuple(inputs))
+        with torch.enable_grad(), \
+            torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
+            torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
+          outputs = ctx.run_function(*detached_inputs)

     if isinstance(outputs, torch.Tensor):
       outputs = (outputs,)
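Taken together, the forward pass now stashes the XLA seed in ctx.fwd_xla_state and the backward recomputation replays it under xm.fork_rng(), so stochastic layers such as dropout see identical RNG state in both passes. A hedged end-to-end sketch, assuming the module-level checkpoint helper exported by torch_xla.utils.checkpoint (not shown in this diff) and illustrative shapes:

import torch
import torch_xla.core.xla_model as xm
from torch_xla.utils.checkpoint import checkpoint  # assumed export of this module

device = xm.xla_device()
block = torch.nn.Sequential(
    torch.nn.Linear(16, 16),
    torch.nn.Dropout(p=0.1),
).to(device)
inp = torch.randn(8, 16, device=device, requires_grad=True)

# Forward records the CPU, CUDA and (after this change) XLA RNG state;
# backward recomputes `block` under that same state, so the dropout mask
# used for the gradient matches the one used in forward.
out = checkpoint(block, inp)
out.sum().backward()
xm.mark_step()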