Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .circleci/common.sh
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ function run_torch_xla_tests() {

# GPU tests
if [ -x "$(command -v nvidia-smi)" ]; then
# Broken by functionalization.
# python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
# python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
# Syncfree SGD optimizer tests
if [ -d ./torch_xla/amp/syncfree ]; then
echo "Running Syncfree Optimizer Test"
Expand Down
42 changes: 18 additions & 24 deletions test/test_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,16 +905,6 @@ def test_inplace_view_non_contig(self):
x.sum().backward()
self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])

@unittest.skip(
"functorch.functionalize doesn't seem to support updating .data directly")
def test_view_data_update(self):
a = torch.zeros(4, device=xm.xla_device())
v = a.view(2, 2)
a.data = a.data + 1
self.assertEqual(a.tolist(), [1, 1, 1, 1])
# Updating a.data should not update v's value.
self.assertEqual(v.tolist(), [[0.0, 0.0], [0.0, 0.0]])

def test_view_out_computation(self):

def func(a, b):
Expand All @@ -926,34 +916,38 @@ def func(a, b):
b = torch.ones([2, 2])
self.runAtenTest((a, b), func)

@unittest.skip("Broken by functionalization")
def test_set(self):
met.clear_all()

t1 = torch.zeros(50, device=xm.xla_device())
t1 += 1
xm.mark_step()
self.assertEqual(met.counter_value('DestroyXlaTensor'), 2)

t1.data = torch.zeros(20, device=xm.xla_device())
self.assertEqual(met.counter_value('DestroyXlaTensor'), 3)

t1.set_(torch.zeros(10, device=xm.xla_device()))
t2 = torch.zeros(10, device=xm.xla_device())
self.assertEqual(met.counter_value('DestroyXlaTensor'), 4)

t2 = torch.zeros(10, device=xm.xla_device())
t1.set_(t2)
self.assertEqual(met.counter_value('DestroyXlaTensor'), 6)

# shouldn't crash
t2.cpu()
self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))

def test_replace_xla_tensor(self):
met.clear_all()

@unittest.skip(
"functorch.functionalize doesn't seem to support updating .data directly")
def test_view_data_slice(self):
t1 = torch.zeros(50, device=xm.xla_device())
t1_slice = t1.data[:5]
# Assigning the view back to the original tensor's data should be OK.
t1.data = t1_slice
self.assertEqual(t1.tolist(), [0, 0, 0, 0, 0])
t1 += 1
xm.mark_step()
self.assertEqual(met.counter_value('DestroyXlaTensor'), 3)

t2 = torch.zeros(10, device=xm.xla_device())
self.assertEqual(met.counter_value('DestroyXlaTensor'), 4)
torch_xla._XLAC._replace_xla_tensor(t1, t2)
self.assertEqual(met.counter_value('DestroyXlaTensor'), 5)

# shouldn't crash
self.assertTrue(torch.allclose(t2.cpu(), torch.zeros(10)))

def test_pred_type(self):
xla_device = xm.xla_device()
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,10 @@ void InitXlaModuleBindings(py::module m) {
MapXlaEnvVarsToLazy();
InitXlaBackend();
});
m.def("_replace_xla_tensor",
[](at::Tensor& self, const at::Tensor& source) -> at::Tensor& {
return XLANativeFunctions::set_(self, source);
});

/* The distributed runtime service is used by the PjRt GPU client. */
py::class_<xla::DistributedRuntimeService,
Expand Down
17 changes: 10 additions & 7 deletions torch_xla/distributed/fsdp/xla_fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
import torch_xla
import torch_xla.core.xla_model as xm

from .xla_flatten_params_wrapper import XlaFlattenParamsWrapper
Expand Down Expand Up @@ -656,14 +657,14 @@ def _shard_parameters_(self, params_to_shard) -> None:
".", "_FSDP_SHARD_SEPARATOR_")
self.register_parameter(p_shard._name, p_shard)
self.sharded_params.append(p_shard)
# Free the full parameter storage (here we free its internal XLATensor) but keep the tensor itself
# for auto-grad tracing (like `torch.autograd.Variable` before the tensor-variable merge).
p.set_(p.new_zeros(1))
if p.device != self.xla_device:
# cast to XLA device if not already on XLA
p = p.to(self.xla_device).requires_grad_(p.requires_grad)
# update p in full_params since id(p) changed after the casting
self.full_params[idx] = p
# Free the full parameter storage (here we free its internal XLATensor) but keep the tensor itself
# for auto-grad tracing (like `torch.autograd.Variable` before the tensor-variable merge).
torch_xla._XLAC._replace_xla_tensor(p, p.new_zeros(1))
p._sharded_param = p_shard # add a handle to the sharded parameter
p._has_full_param = False
# deregister the full parameter tensors from their modules (so that they won't
Expand Down Expand Up @@ -1361,10 +1362,12 @@ def _rebuild_full_params(self,
self.optimization_barrier_op([p_padded])
with torch.autograd._unsafe_preserve_version_counter(p):
if self._shard_param_on_dim_0:
p.set_(p_padded[:p_shard._orig_size[0]])
torch_xla._XLAC._replace_xla_tensor(
p, p_padded[:p_shard._orig_size[0]])
else:
p.set_(p_padded[:p_shard._orig_size.numel()].view(
p_shard._orig_size))
torch_xla._XLAC._replace_xla_tensor(
p,
p_padded[:p_shard._orig_size.numel()].view(p_shard._orig_size))
p._has_full_param = True

self.has_full_params = True
Expand Down Expand Up @@ -1395,7 +1398,7 @@ def _free_full_params(self,
if p._has_full_param:
# free the original full parameter
with torch.autograd._unsafe_preserve_version_counter(p):
p.set_(self._dummy_data_placeholder)
torch_xla._XLAC._replace_xla_tensor(p, self._dummy_data_placeholder)
p._has_full_param = False

if apply_opt_barrier:
Expand Down