1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -16,6 +16,7 @@ xla_model
.. autofunction:: xrt_world_size
.. autofunction:: all_reduce
.. autofunction:: all_gather
.. autofunction:: all_to_all
.. autofunction:: add_step_closure
.. autofunction:: wait_device_ops
.. autofunction:: optimizer_step
2 changes: 2 additions & 0 deletions test/pytorch_test_base.py
@@ -169,6 +169,8 @@
        'test_masked_select_mem_overlap',  # doesn't raise
        'test_scatter_mem_overlap',  # doesn't raise
        'test_index_mem_overlap',  # doesn't raise
        'test_topk_nonfinite_xla_float32',  # TFXLA update HLO changed for 1.6
        'test_topk_nonfinite_xla_float64',  # TFXLA update HLO changed for 1.6
    },
    'TestViewOpsXLA': {
        'test_contiguous_nonview',
14 changes: 14 additions & 0 deletions test/test_operations.py
@@ -663,6 +663,20 @@ def test_get_xla_tensor(self):
    self.assertEqual(tx, sx.data.cpu())


class TestBinaryCrossEntropyLimitValue(XlaTestCase):

  def test_cross_entropy_loss(self):

    def test_fn(pred, target):
      lossfn = nn.BCELoss()
      return lossfn(pred, target)

    pred = torch.tensor(1.0)
    target = torch.tensor(1.0)
    for offset in [1, 0, 1e-8, 1e-7]:
      self.runAtenTest([pred - offset, target], test_fn)


class TestDynamicShape(XlaTestCase):

  def test_nonzero_shape(self):
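Note on what the new test exercises: with target = 1, BCELoss reduces to -log(pred), which diverges as pred approaches 0 (the offset = 1 case above). Eager PyTorch clamps the log term at -100, and the reduction.cpp hunks below mirror that clamp in the XLA lowering. A minimal eager-mode sketch of the limit values, assuming stock PyTorch:

import torch
import torch.nn as nn

lossfn = nn.BCELoss()
target = torch.tensor(1.0)
for offset in [1, 0, 1e-8, 1e-7]:
  pred = torch.tensor(1.0) - offset
  # With target == 1 the loss is -log(pred); PyTorch clamps the log term
  # at -100, so pred == 0 yields a loss of 100 instead of inf.
  print(offset, lossfn(pred, target).item())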
2 changes: 1 addition & 1 deletion third_party/tensorflow
Submodule tensorflow updated 5737 files
9 changes: 5 additions & 4 deletions torch_xla/__init__.py
@@ -4,7 +4,7 @@
import socket
import time

from .version import __version__ as version
from .version import __version__


def _maybe_select_tpu_version():
@@ -40,17 +40,18 @@ def _wait_for_open(version, timeout=100, interval=10, log=True):

    import cloud_tpu_client
    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(f'pytorch-{version}', restart_type='ifNeeded')
    client.configure_tpu_version(
        f'pytorch-{__version__}', restart_type='ifNeeded')
    # client.wait_for_healthy() API doesn't work as we don't have TPU API access
    _wait_for_open(version)
    _wait_for_open(__version__)
  except ImportError:
    logging.warning((
        'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
        'installed. Ignore if not running on Colab/Kaggle TPU.'))
  except Exception:
    # This path is hit when we get throttled by the version changer
    # when we import torch_xla from xmp.spawn-ed processes.
    _wait_for_open(version, log=False)
    _wait_for_open(__version__, log=False)


def _setup_grpc():
3 changes: 0 additions & 3 deletions torch_xla/core/xla_model.py
@@ -478,9 +478,6 @@ def all_to_all(value,
               groups=None):
  """Performs an XLA `AllToAll()` operation on the input tensor.

  WARNING: This function is not very reliable and may produce wrong results
  under certain inputs. Use it at your own risk.

  See: https://www.tensorflow.org/xla/operation_semantics#alltoall

  Args:
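With the WARNING removed, all_to_all is now documented as a regular collective alongside all_reduce and all_gather. A minimal usage sketch, assuming an xmp.spawn-ed context where split_count equals the number of participating devices (the tensor shapes are illustrative):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  device = xm.xla_device()
  world_size = xm.xrt_world_size()
  # One shard per device: AllToAll splits dim 0 into world_size pieces,
  # exchanges them, and concatenates what each device receives along dim 1.
  value = torch.full((world_size, 4), float(index), device=device)
  result = xm.all_to_all(
      value, split_dimension=0, concat_dimension=1, split_count=world_size)
  print(index, result.shape)  # (1, 4 * world_size)


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())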
14 changes: 11 additions & 3 deletions torch_xla/csrc/reduction.cpp
@@ -127,6 +127,7 @@ xla::XlaOp CreateProduct(xla::XlaOp input,
xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
                                   const absl::optional<xla::XlaOp>& weight,
                                   ReductionMode reduction) {
  static const float kLogBound = -100;
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp xweight;
  if (weight) {
@@ -137,8 +138,11 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
        XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
  }
  xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
  xla::XlaOp result = -xweight * (target * xla::Log(input) +
                                  (one - target) * xla::Log(one - input));
  xla::XlaOp log_bound = XlaHelpers::ScalarValue(
      kLogBound, input_shape.element_type(), input.builder());
  xla::XlaOp result =
      -xweight * (target * xla::Max(xla::Log(input), log_bound) +
                  (one - target) * xla::Max(xla::Log(one - input), log_bound));
  if (reduction == ReductionMode::kNone) {
    return result;
  }
@@ -154,6 +158,7 @@ xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target,
xla::XlaOp BuildBinaryCrossEntropyBackward(
    xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target,
    const absl::optional<xla::XlaOp>& weight, ReductionMode reduction) {
  static const float kEpsilon = 1e-12;
  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp xweight;
  if (weight) {
@@ -164,7 +169,10 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
        XlaHelpers::ScalarBroadcast<float>(1.0, input_shape, target.builder());
  }
  xla::XlaOp one = xla::One(input.builder(), input_shape.element_type());
  xla::XlaOp result = xweight * (input - target) / input / (one - input);
  xla::XlaOp epsilon = XlaHelpers::ScalarValue(
      kEpsilon, input_shape.element_type(), input.builder());
  xla::XlaOp result =
      xweight * (input - target) / xla::Max(input * (one - input), epsilon);
  if (reduction == ReductionMode::kNone) {
    return result * grad_output;
  }
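Spelling out the numerics of the two hunks: the forward pass now clamps each log term at kLogBound = -100 so log(0) cannot inject -inf into the graph (matching eager BCELoss), and the backward pass rewrites (input - target) / input / (one - input) with an epsilon floor of kEpsilon = 1e-12 on the denominator so the gradient stays finite at input == 0 or input == 1. A Python sketch of the same formulas, with the constants taken from the diff:

import torch

K_LOG_BOUND = -100.0  # mirrors kLogBound above
K_EPSILON = 1e-12  # mirrors kEpsilon above


def bce_forward(input, target, weight=None):
  weight = torch.ones_like(input) if weight is None else weight
  return -weight * (
      target * torch.log(input).clamp(min=K_LOG_BOUND) +
      (1 - target) * torch.log(1 - input).clamp(min=K_LOG_BOUND))


def bce_backward(grad_output, input, target, weight=None):
  weight = torch.ones_like(input) if weight is None else weight
  # The old expression divided by input * (1 - input) directly, which is
  # inf or nan at the endpoints; the epsilon floor keeps it finite.
  return grad_output * weight * (input - target) / (
      input * (1 - input)).clamp(min=K_EPSILON)


# Both stay finite at the limit point input == 0, target == 1:
x, t = torch.tensor(0.0), torch.tensor(1.0)
print(bce_forward(x, t))  # tensor(100.)
print(bce_backward(torch.tensor(1.0), x, t))  # tensor(-1.0000e+12)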
8 changes: 6 additions & 2 deletions torch_xla/distributed/xla_multiprocessing.py
@@ -226,9 +226,13 @@ def _start_fn(index, pf_cfg, fn, args):
  # Calling _setup_replication() will trigger XLA library initialization, so the
  # environment must be fully set up before doing so.
  _setup_replication()
  fn(gindex, *args)


def _mp_start_fn(index, pf_cfg, fn, args):
  exit_code = 0
  try:
    fn(gindex, *args)
    _start_fn(index, pf_cfg, fn, args)
  except Exception as e:
    print(
        'Exception in device={}: {}'.format(_get_multiprocessing_device(),
@@ -288,7 +292,7 @@ def spawn(fn,
    _start_fn(0, pf_cfg, fn, args)
  else:
    return torch.multiprocessing.start_processes(
        _start_fn,
        _mp_start_fn,
        args=(pf_cfg, fn, args),
        nprocs=pf_cfg.num_devices,
        join=join,
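The net effect of the last two hunks: multi-process launches now enter through _mp_start_fn, which wraps the per-process body in a try/except, while the nprocs=1 fast path still calls _start_fn directly. A small sketch of what this enables, assuming the except branch (elided above) logs the failing device and exits the worker with the nonzero exit_code rather than letting the raw exception escape:

import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
  if index == 0:
    # With _mp_start_fn in place this is caught per worker, printed along
    # with the device it happened on, and surfaced to start_processes as
    # a nonzero exit code instead of an unhandled traceback.
    raise RuntimeError('worker failure on ordinal 0')


if __name__ == '__main__':
  xmp.spawn(_mp_fn, args=())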