test/test_xla_sharding.py (21 additions & 0 deletions)
@@ -6,8 +6,10 @@
 import numpy as np
 
 import torch
+from torch import nn
 import torch_xla
 import torch_xla.core.xla_model as xm
+import torch_xla.debug.metrics as met
 import torch_xla.utils.utils as xu
 import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
@@ -137,6 +139,25 @@ def test_mark_sharding(self):
                          dtype=torch.float,
                          device=xm.xla_device())))
+
+  def test_metrics_recorded(self):
+    met.clear_counters()
+    partition_spec = (0, 1)
+    xt1 = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]],
+                       dtype=torch.float,
+                       device=xm.xla_device())
+    xs.mark_sharding(xt1, self._get_mesh((1, self.n_devices)), partition_spec)
+    self.assertIn("VirtualDeviceUsage", met.counter_names())
+    self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0)
Contributor:
This looks good to me. Could we add another test case showing that model param sharding doesn't use the virtual device?

Collaborator (Author):
Added a test for sharding on model weights. It looks like the virtual device will still be used to delay the transfer of the model weights in nn.Linear(128, 64).to(xm.xla_device()) until the sharding is applied.
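
For reference, a minimal standalone sketch of that behavior; the xs.Mesh construction and the xm.xrt_world_size() call below are illustrative stand-ins for the test's self._get_mesh helper and are not part of this diff:

import numpy as np
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.experimental.xla_sharding as xs
from torch import nn

met.clear_counters()

# Moving the module to the XLA device does not eagerly materialize the
# weights; the virtual device defers the transfer until sharding is known.
model = nn.Linear(128, 64).to(xm.xla_device())

# Illustrative mesh over all devices, mirroring self._get_mesh((1, n_devices)).
n_devices = xm.xrt_world_size()
mesh = xs.Mesh(np.array(range(n_devices)), (1, n_devices))

xs.mark_sharding(model.weight, mesh, (0, 1))

# Expected to be non-zero: the deferred weight transfer still went through
# the virtual device before the sharding annotation was applied.
print(met.counter_value("VirtualDeviceUsage"))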

Contributor:
Nice!


+
+  def test_model_weight_metrics(self):
+    met.clear_counters()
+    partition_spec = (0, 1)
+    model = nn.Linear(128, 64).to(xm.xla_device())
+    xs.mark_sharding(model.weight, self._get_mesh((1, self.n_devices)),
+                     partition_spec)
+    self.assertIn("VirtualDeviceUsage", met.counter_names())
+    self.assertNotEqual(met.counter_value("VirtualDeviceUsage"), 0)


 if __name__ == '__main__':
   test = unittest.main()