From 06b6543132fe04ad4b7c872c9992a81198959827 Mon Sep 17 00:00:00 2001
From: Davide Libenzi
Date: Sat, 27 Jun 2020 13:56:07 -0700
Subject: [PATCH 1/7] Do not issue a sys.exit() for non-forked MP.

---
 torch_xla/distributed/xla_multiprocessing.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/torch_xla/distributed/xla_multiprocessing.py b/torch_xla/distributed/xla_multiprocessing.py
index ba9709a5b0da..abf6c346bb15 100644
--- a/torch_xla/distributed/xla_multiprocessing.py
+++ b/torch_xla/distributed/xla_multiprocessing.py
@@ -226,9 +226,13 @@ def _start_fn(index, pf_cfg, fn, args):
   # Calling _setup_replication() will trigger XLA library initialization, so the
   # environment must be fully setup before doing so.
   _setup_replication()
+  fn(gindex, *args)
+
+
+def _mp_start_fn(index, pf_cfg, fn, args):
   exit_code = 0
   try:
-    fn(gindex, *args)
+    _start_fn(index, pf_cfg, fn, args)
   except Exception as e:
     print(
         'Exception in device={}: {}'.format(_get_multiprocessing_device(),
@@ -288,7 +292,7 @@ def spawn(fn,
     _start_fn(0, pf_cfg, fn, args)
   else:
     return torch.multiprocessing.start_processes(
-        _start_fn,
+        _mp_start_fn,
         args=(pf_cfg, fn, args),
         nprocs=pf_cfg.num_devices,
         join=join,

From ca3585741699a870cdb708432c9e51242ed76a77 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Fri, 21 Aug 2020 19:10:22 -0700
Subject: [PATCH 2/7] Clamp the intermediate value in bce lowering to avoid nan (#2448)

---
 test/test_operations.py      | 14 ++++++++++++++
 torch_xla/csrc/reduction.cpp | 14 +++++++++++---
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/test/test_operations.py b/test/test_operations.py
index 706f134799fe..9a7e47fbea07 100644
--- a/test/test_operations.py
+++ b/test/test_operations.py
@@ -663,6 +663,20 @@ def test_get_xla_tensor(self):
     self.assertEqual(tx, sx.data.cpu())
 
 
+class TestBinaryCrossEntropyLimitValue(XlaTestCase):
+
+  def test_cross_entropy_loss(self):
+
+    def test_fn(pred, target):
+      lossfn = nn.BCELoss()
+      return lossfn(pred, target)
+
+    pred = torch.tensor(1.0)
+    target = torch.tensor(1.0)
+    for offset in [1, 0, 1e-8, 1e-7]:
+      self.runAtenTest([pred - offset, target], test_fn)
+
+
 class TestDynamicShape(XlaTestCase):
 
   def test_nonzero_shape(self):
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index e339f1778080..20e1330d90dd 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
 xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
                                    const absl::optional<xla::XlaOp>& weight,
                                    ReductionMode reduction) {
+  static const float kLogBound = -100;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = -xweight * (target * xla::Log(input) +
-                                  (one - target) * xla::Log(one - input));
+  xla::XlaOp log_bound = XlaHelpers::ScalarValue<float>(
+      kLogBound, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      -xweight * (target * xla::Max(xla::Log(input), log_bound) +
+                  (one - target) * xla::Max(xla::Log(one - input), log_bound));
   if (reduction == ReductionMode::kNone) {
     return result;
   }
@@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
 xla::XlaOp BuildBinaryCrossEntropyBackward(
     xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
     const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
+  static const float kEpsilon = 1e-12;
   const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
   xla::XlaOp xweight;
   if (weight) {
@@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
         XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
   }
   xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
-  xla::XlaOp result = xweight * (input - target) / input / (one - input);
+  xla::XlaOp epsilon = XlaHelpers::ScalarValue<float>(
+      kEpsilon, input_shape.element_type(), input.builder());
+  xla::XlaOp result =
+      xweight * (input - target) / xla::Max(input * (one - input), epsilon);
   if (reduction == ReductionMode::kNone) {
     return result * grad_output;
   }
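Note on PATCH 2/7: the clamp guards the pointwise BCE formula -w * (t * log(p) + (1 - t) * log(1 - p)) at the boundary values p = 0 and p = 1, where one log term becomes -inf and the IEEE product 0 * -inf turns the whole loss into NaN even though that term should contribute nothing. The standalone Python sketch below is illustrative only, not part of the patch; the -100 bound mirrors the kLogBound used in the lowering above, and matches the clamp PyTorch's native BCELoss applies.

import math

def bce(pred, target, log_bound=None):
  # Pointwise binary cross entropy. math.log(0) raises in Python, so map it
  # to -inf the way a hardware float unit would.
  log_p = math.log(pred) if pred > 0 else float('-inf')
  log_1mp = math.log(1 - pred) if pred < 1 else float('-inf')
  if log_bound is not None:
    log_p = max(log_p, log_bound)
    log_1mp = max(log_1mp, log_bound)
  return -(target * log_p + (1 - target) * log_1mp)

print(bce(1.0, 1.0))                  # nan: 0 * -inf poisons the sum
print(bce(1.0, 1.0, log_bound=-100))  # -0.0: finite once the log is bounded

The kEpsilon clamp in the backward pass plays the same role for the p * (1 - p) denominator, which reaches 0 at the same boundary values.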
From cde74785d6037d571897fa4fc1e4cbadfb257730 Mon Sep 17 00:00:00 2001
From: Jin Young Sohn
Date: Tue, 1 Sep 2020 22:27:59 +0000
Subject: [PATCH 3/7] Fix version string

---
 torch_xla/__init__.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index 7f4cb6da04d7..a56b38f7330a 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -4,7 +4,7 @@
 import socket
 import time
 
-from .version import __version__ as version
+from .version import __version__
 
 
 def _maybe_select_tpu_version():
@@ -40,9 +40,9 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):
     import cloud_tpu_client
 
     client = cloud_tpu_client.Client(tpu_name)
-    client.configure_tpu_version(f'pytorch-{version}', restart_type='ifNeeded')
+    client.configure_tpu_version(f'pytorch-{__version__}', restart_type='ifNeeded')
     # client.wait_for_healthy() API doesn't work as we dont have TPU API access
-    _wait_for_open(version)
+    _wait_for_open(__version__)
   except ImportError:
     logging.warning((
         'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
@@ -50,7 +50,7 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):
   except Exception:
     # This path is hit, when we get throttled by the verison changer
     # when we import torch_xla from xmp.spawn-ed processes.
-    _wait_for_open(version, log=False)
+    _wait_for_open(__version__, log=False)
 
 
 def _setup_grpc():

From 39bf27f719a8281af2008aa68077ffd8f821f538 Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Thu, 27 Aug 2020 22:55:21 +0000
Subject: [PATCH 4/7] Update tensorflow to include the all_to_all fix

---
 third_party/tensorflow | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/tensorflow b/third_party/tensorflow
index 44067f0783c5..21ca8897d7f1 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 44067f0783c56ad092f6ef5ea1034e6926559d86
+Subproject commit 21ca8897d7f1f82331c5fb8e1c8864b394e9a157

From cba060355ca55dfc5d35288d665a6c78a16281c8 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Sun, 30 Aug 2020 22:04:39 -0700
Subject: [PATCH 5/7] Add all_to_all back to our doc (#2472)

---
 docs/source/index.rst       | 1 +
 torch_xla/core/xla_model.py | 3 ---
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 545aee858728..072084665a78 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -16,6 +16,7 @@ xla_model
 .. autofunction:: xrt_world_size
 .. autofunction:: all_reduce
 .. autofunction:: all_gather
+.. autofunction:: all_to_all
 .. autofunction:: add_step_closure
 .. autofunction:: wait_device_ops
 .. autofunction:: optimizer_step
diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py
index 13dc18d27292..e2ae87bc9d41 100644
--- a/torch_xla/core/xla_model.py
+++ b/torch_xla/core/xla_model.py
@@ -478,9 +478,6 @@ def all_to_all(value, groups=None):
   """Performs an XLA `AllToAll()` operation on the input tensor.
 
-  WARNING: This function is not very reliable, may produce wrong results under
-  certain inputs. Use it at your own risk.
-
   See: https://www.tensorflow.org/xla/operation_semantics#alltoall
 
   Args:
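Note on PATCH 5/7: PATCH 4/7 pulled in the TensorFlow commit that fixes AllToAll, so the reliability warning can be dropped and the API documented again. The sketch below is a minimal usage example, not part of the patch; it assumes the torch_xla.core.xla_model.all_to_all signature of this release, all_to_all(value, split_dimension, concat_dimension, split_count, groups=None).

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # Each replica contributes `world_size` slices along dim 0 and receives one
  # slice from every replica, concatenated back along the same dimension.
  value = torch.full((world_size, 4), float(index), device=device)
  result = xm.all_to_all(
      value, split_dimension=0, concat_dimension=0, split_count=world_size)
  xm.mark_step()  # materialize the collective
  # Row i of `result` now holds replica i's contribution, i.e. the value i.


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())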
From 1b372dd1c7494fe0413da7fc692524ed10a86552 Mon Sep 17 00:00:00 2001
From: Jin Young Sohn
Date: Tue, 1 Sep 2020 23:09:51 +0000
Subject: [PATCH 6/7] Pin TF to all_to_all fix

---
 third_party/tensorflow | 2 +-
 torch_xla/__init__.py  | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/third_party/tensorflow b/third_party/tensorflow
index 21ca8897d7f1..21133c9daffe 160000
--- a/third_party/tensorflow
+++ b/third_party/tensorflow
@@ -1 +1 @@
-Subproject commit 21ca8897d7f1f82331c5fb8e1c8864b394e9a157
+Subproject commit 21133c9daffe5fd991d45359f97bf0be642ecd8b
diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py
index a56b38f7330a..2b7a311cf477 100644
--- a/torch_xla/__init__.py
+++ b/torch_xla/__init__.py
@@ -40,7 +40,8 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):
     import cloud_tpu_client
 
     client = cloud_tpu_client.Client(tpu_name)
-    client.configure_tpu_version(f'pytorch-{__version__}', restart_type='ifNeeded')
+    client.configure_tpu_version(
+        f'pytorch-{__version__}', restart_type='ifNeeded')
     # client.wait_for_healthy() API doesn't work as we dont have TPU API access
     _wait_for_open(__version__)
   except ImportError:

From df0dec303e22fce21fadf0b135c681d624e269dc Mon Sep 17 00:00:00 2001
From: Jin Young Sohn
Date: Thu, 3 Sep 2020 14:54:16 +0000
Subject: [PATCH 7/7] Skip test_topk_nonfinite due to TFXLA HLO change

---
 test/pytorch_test_base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/test/pytorch_test_base.py b/test/pytorch_test_base.py
index 58d75e7c92e5..da024c3a30b7 100644
--- a/test/pytorch_test_base.py
+++ b/test/pytorch_test_base.py
@@ -169,6 +169,8 @@
         'test_masked_select_mem_overlap',  # doesn't raise
         'test_scatter_mem_overlap',  # doesn't raise
         'test_index_mem_overlap',  # doesn't raise
+        'test_topk_nonfinite_xla_float32',  # TFXLA update HLO changed for 1.6
+        'test_topk_nonfinite_xla_float64',  # TFXLA update HLO changed for 1.6
     },
     'TestViewOpsXLA': {
         'test_contiguous_nonview',
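Note on PATCH 7/7: the skip map in test/pytorch_test_base.py ties a test-class name to the set of test names the XLA harness skips when running the upstream PyTorch test suite. A minimal sketch of that lookup follows; the should_skip helper and the 'TestTorchDeviceTypeXLA' key are illustrative assumptions, not the actual pytorch_test_base API.

# Sketch only: mirrors the skip-map structure from the diff above.
DISABLED_TORCH_TESTS = {
    'TestTorchDeviceTypeXLA': {  # assumed class key, for illustration
        'test_topk_nonfinite_xla_float32',
        'test_topk_nonfinite_xla_float64',
    },
    'TestViewOpsXLA': {
        'test_contiguous_nonview',
    },
}


def should_skip(cls_name, test_name):
  # A test is skipped when its name is in the disabled set for its class.
  return test_name in DISABLED_TORCH_TESTS.get(cls_name, set())


assert should_skip('TestTorchDeviceTypeXLA', 'test_topk_nonfinite_xla_float64')
assert not should_skip('TestViewOpsXLA', 'test_topk_nonfinite_xla_float64')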