Fixes #1096 - bad XLA comp model configuration on single device (#1105)
Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
vfdev-5 and sdesrozis committed Jun 4, 2020
1 parent 2281af0 commit f9645c0
Showing 5 changed files with 75 additions and 15 deletions.
15 changes: 15 additions & 0 deletions ignite/distributed/comp_models/base.py
@@ -21,6 +21,18 @@ def __init__(self):
         self._nnodes = None
         self._node = None
 
+    def _setup_attrs(self):
+        if self._ntasks_per_node is None:
+            self._ntasks_per_node = self._compute_ntasks_per_node() if self.get_world_size() > 1 else 1
+        if self._nnodes is None:
+            self._nnodes = self.get_world_size() // self._ntasks_per_node
+        if self._node is None:
+            self._node = self.get_rank() // self._ntasks_per_node
+
+    @abstractmethod
+    def _compute_ntasks_per_node(self) -> int:
+        pass
+
     @abstractmethod
     def get_local_rank(self) -> int:
         pass
@@ -189,6 +201,9 @@ def backend(self) -> None:
     def finalize(self):
         pass
 
+    def _compute_ntasks_per_node(self) -> int:
+        return 1
+
     @staticmethod
     def create_from_context() -> "_SerialModel":
         return _SerialModel()
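
For illustration, a minimal standalone sketch of the arithmetic the shared _setup_attrs now performs, using hypothetical rank and world-size values (2 nodes x 4 processes per node):

    # Hypothetical values, not taken from the diff.
    world_size = 8
    ntasks_per_node = 4   # on multi-process runs, the result of the MAX all-reduce in _compute_ntasks_per_node
    rank = 5              # global rank of the current process

    nnodes = world_size // ntasks_per_node   # -> 2
    node = rank // ntasks_per_node           # -> 1, i.e. this process lives on the second node

    # On a single device, get_world_size() == 1, so the collective is skipped
    # and ntasks_per_node is simply set to 1, which is the core of this fix.
    assert (nnodes, node) == (2, 1)
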
8 changes: 0 additions & 8 deletions ignite/distributed/comp_models/native.py
@@ -85,14 +85,6 @@ def _init_from_context(self):
         self._master_addr = None
         self._setup_attrs()
 
-    def _setup_attrs(self):
-        if self._ntasks_per_node is None:
-            self._ntasks_per_node = self._compute_ntasks_per_node()
-        if self._nnodes is None:
-            self._nnodes = self.get_world_size() // self._ntasks_per_node
-        if self._node is None:
-            self._node = self.get_rank() // self._ntasks_per_node
-
     def _compute_ntasks_per_node(self):
         tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device())
         dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
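
As a standalone sketch of why the MAX reduction above yields the number of tasks per node (hypothetical ranks, no process group involved): local ranks restart from 0 on every node, so the largest local_rank + 1 seen anywhere equals the per-node process count, assuming all nodes run the same number of processes.

    # Hypothetical cluster: 2 nodes x 4 processes, local ranks restart on each node.
    local_ranks = [0, 1, 2, 3, 0, 1, 2, 3]
    ntasks_per_node = max(r + 1 for r in local_ranks)   # what the MAX all-reduce computes
    assert ntasks_per_node == 4
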
10 changes: 4 additions & 6 deletions ignite/distributed/comp_models/xla.py
@@ -5,6 +5,7 @@
 from ignite.distributed.comp_models.base import ComputationModel
 
 try:
+    import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
 
@@ -15,6 +16,7 @@
 
 class _XlaDistModel(ComputationModel):
     """Private class for PyTorch XLA basic distributed computation model.
+    It handles single/multi-device computation model.
     Supported XLA devices:
@@ -58,11 +60,6 @@ def _init_from_context(self):
         self._backend = "xla-tpu"
         self._setup_attrs()
 
-    def _setup_attrs(self):
-        self._ntasks_per_node = self._compute_ntasks_per_node()
-        self._nnodes = self.get_world_size() // self.get_ntasks_per_node()
-        self._node = self.get_rank() // self._ntasks_per_node
-
     def _compute_ntasks_per_node(self):
         tensor = torch.tensor([self.get_local_rank() + 1.0], dtype=torch.float).to(self.device())
         xm.all_reduce("max", [tensor,])
@@ -87,7 +84,8 @@ def get_node_rank(self) -> int:
         return self._node
 
     def device(self) -> torch.device:
-        return xm.xla_device()
+        dev = torch_xla._XLAC._xla_get_default_device()
+        return torch.device(dev)
 
     def backend(self) -> str:
         return self._backend
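
The device() change makes the XLA model report the device already selected by the calling thread, via torch_xla._XLAC._xla_get_default_device(), instead of asking xm.xla_device() for a device again. A hedged sketch of the check the new test performs, assuming a host with multiple XLA devices and torch_xla installed; the ordinal is hypothetical:

    import torch
    import torch_xla
    import torch_xla.core.xla_model as xm

    dev = xm.xla_device(2)  # hypothetical ordinal; selects the XLA device for this thread
    reported = torch.device(torch_xla._XLAC._xla_get_default_device())
    assert reported == dev  # device() should now reflect the thread's current device
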
2 changes: 1 addition & 1 deletion ignite/metrics/metric.py
@@ -452,7 +452,7 @@ def another_wrapper(self: Metric, *args, **kwargs) -> Callable:
             if len(attrs) > 0 and not self._is_reduced:
                 for attr in attrs:
                     t = getattr(self, attr, None)
-                    if t is not None:
+                    if t is not None and idist.get_world_size() > 1:
                         t = idist.all_reduce(t)
                         self._is_reduced = True
                         setattr(self, attr, t)
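
The metric fix follows the same guard pattern: only issue the collective when there is more than one participant, so serial and single-device XLA runs never call all_reduce. A minimal sketch with a hypothetical accumulator tensor:

    import torch
    import ignite.distributed as idist

    acc = torch.tensor([3.0])         # hypothetical metric accumulator
    if idist.get_world_size() > 1:    # guard: skip the collective on serial runs
        acc = idist.all_reduce(acc)
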
55 changes: 55 additions & 0 deletions tests/ignite/distributed/comp_models/test_xla.py
@@ -143,3 +143,58 @@ def _test__xla_dist_model_create_from_context_in_child_proc(index):
 def test__xla_dist_model_create_from_context_in_child_proc(xmp_executor):
     n = int(os.environ["NUM_TPU_WORKERS"])
     xmp_executor(_test__xla_dist_model_create_from_context_in_child_proc, args=(), nprocs=n)
+
+
+def main_fold(fold):
+    import time
+    import torch.nn as nn
+    import torch.optim as optim
+    import torch_xla.core.xla_model as xm
+    from ignite.engine import Engine, Events
+
+    device = xm.xla_device(fold + 1)
+
+    comp_model = _XlaDistModel.create_from_context()
+    assert comp_model.device() == device
+
+    model = nn.Linear(100, 10)
+    device = xm.xla_device(fold + 1)
+
+    model.to(device)  # Move model before creating optimizer
+    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
+
+    def training_step(engine, _):
+        data = torch.rand(4, 100, device=device)
+        model.train()
+        data = data.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = output.sum()
+        loss.backward()
+        xm.optimizer_step(optimizer, barrier=True)
+        return loss.item()
+
+    trainer = Engine(training_step)
+
+    # THIS CAN BE A CAUSE OF CRASH if DEVICE is OTHER THAN device
+    tensor = torch.tensor([fold + 1.0], dtype=torch.float).to(comp_model.device())
+    xm.all_reduce("max", [tensor,])
+
+    time.sleep(0.01 * fold)
+
+    @trainer.on(Events.ITERATION_COMPLETED)
+    def log_progress():
+        print(".", end=" ")
+
+    trainer.run([0] * 100, max_epochs=2)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not has_xla_support, reason="Skip if no PyTorch XLA package")
+def test__xla_dist_model_run_parallel_n_threads_without_sync():
+    # tests issue : https://github.com/pytorch/ignite/issues/1096
+    from joblib import Parallel, delayed
+
+    folds = 5
+    Parallel(n_jobs=folds, backend="threading")(delayed(main_fold)(i) for i in range(folds))
