31 changes: 19 additions & 12 deletions scripts/gen.py
@@ -290,24 +290,31 @@ class TensorFetcher(object):
def __init__(self, var_name):
self.var_name = var_name
self.tvar_name = '{}_tensors'.format(self.var_name)
self.wvar_name = '{}_writeables'.format(self.var_name)
self.tensors = []
self.writeable = []

def add(self, name, writeable):
if writeable:
self.writeable.append(len(self.tensors))
self.tensors.append(name)
self.writeable.append(writeable)
return '{}[{}]'.format(self.var_name, len(self.tensors) - 1)

def generate(self):
def generate_fetches(self):
code = ''
code += ' std::vector<at::Tensor> {} = {{{}}};\n'.format(
self.tvar_name, ', '.join(self.tensors))
writeable_strings = ['true' if w else 'false' for w in self.writeable]
code += ' std::vector<bool> {} = {{{}}};\n'.format(
self.wvar_name, ', '.join(writeable_strings))
code += (' auto {} = bridge::XlaCreateTensorList({}, &{});\n').format(
self.var_name, self.tvar_name, self.wvar_name)
code += (' auto {} = bridge::XlaCreateTensorList({});\n').format(
self.var_name, self.tvar_name)
return code

def generate_updates(self):
code = ''
if self.writeable:
ivar_name = '{}_update_indices'.format(self.var_name)
code += ' std::vector<size_t> {} = {{{}}};\n'.format(
ivar_name, ', '.join(str(x) for x in self.writeable))
code += ' bridge::XlaUpdateTensors({}, {}, {});\n'.format(
self.tvar_name, self.var_name, ivar_name)
return code


@@ -795,9 +802,8 @@ def generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
pname = param_name(p)
if cptype == 'TensorList':
xname = 'l_{}'.format(pname)
code += (
' auto {} = bridge::XlaCreateTensorList({}, /*writeable=*/nullptr);\n'
).format(xname, pname)
code += (' auto {} = bridge::XlaCreateTensorList({});\n').format(
xname, pname)
param_vars.append(xname)
elif cptype == 'TensorOptions':
gcode, xname = rewrite_tensor_options(fname, pname)
@@ -813,11 +819,12 @@ def generate_aten_to_xla(ctx, tree, rwxtree, fname, sig, rwsig, params, fnopts):
param_vars.append(xname)
if p == ref_param and not get_optional(fnopts, 'ref_param'):
xla_ref_param = param_vars[-1]
code += tfetcher.generate()
code += tfetcher.generate_fetches()
result_assign = generate_result_assignment(tree, _RESULT_NAME)
code += ' {}{};\n'.format(
result_assign, get_handling_function(ctx, fname, xla_ref_param,
param_vars))
code += tfetcher.generate_updates()
if result_assign:
code += (' static_cast<void>({}); // Avoid warnings in case not '
'used\n'.format(_RESULT_NAME))
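Taken together, generate_fetches() and generate_updates() bracket the fallback ATen call: the XLA tensors are materialized on the ATen side, the op runs, and only the positions registered as writeable are pushed back to the device. Below is a minimal sketch of the emitted C++, assuming an out= variant of index_select as the wrapped op; the override signature, the variable prefix w, and the at::index_select_out call are illustrative assumptions, not verbatim generator output:

  at::Tensor& AtenXlaType::index_select_out(at::Tensor& out, const at::Tensor& self,
                                            int64_t dim, const at::Tensor& index) {
    // Emitted by generate_fetches(): fetch ATen copies of the XLA tensors.
    std::vector<at::Tensor> w_tensors = {out, self, index};
    auto w = bridge::XlaCreateTensorList(w_tensors);
    // The fallback ATen kernel runs on the fetched copies; add() recorded
    // index 0 for 'out' because it was registered as writeable.
    at::index_select_out(w[0], w[1], dim, w[2]);
    // Emitted by generate_updates(): write the mutated copy back.
    std::vector<size_t> w_update_indices = {0};
    bridge::XlaUpdateTensors(w_tensors, w, w_update_indices);
    return out;
  }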
20 changes: 15 additions & 5 deletions test/test_operations.py
@@ -480,8 +480,7 @@ def test_rand_like(self):
def test_randint_like(self):
shape = (5, 1, 1)
x = torch.randint_like(
torch.zeros(shape, device=xm.xla_device(), dtype=torch.uint8),
6, 10)
torch.zeros(shape, device=xm.xla_device(), dtype=torch.uint8), 6, 10)
self.assertEqual(x.device.type, 'xla')

def test_no_storage(self):
@@ -671,8 +670,8 @@ def test_norm_p0(self):
xla_device = xm.xla_device()
a = torch.randn(3, 2)
xla_a = a.to(xla_device)
norm = a.norm(p = 0)
xla_norm = xla_a.norm(p = 0)
norm = a.norm(p=0)
xla_norm = xla_a.norm(p=0)
self.assertEqual(norm, xla_norm)

def test_slice_start_end(self):
@@ -701,7 +700,8 @@ def test_fn(a):

def test_scatter_add_bool(self):
xla_device = xm.xla_device()
a = torch.tensor([[True, True, True, True, True], [True, True, True, True, True]])
a = torch.tensor([[True, True, True, True, True],
[True, True, True, True, True]])
b = torch.zeros(3, 5, dtype=torch.bool)
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
b.scatter_add_(0, index, a)
@@ -732,6 +732,16 @@ def test_max_throw(self):
self.assertRaises(RuntimeError, lambda: torch.max(xla_a, dim=1))
self.assertRaises(RuntimeError, lambda: torch.max(xla_a))

def test_writeable_tensors_updates(self):

def test_fn(s, i):
out = torch.zeros(2, 4, device=s.device)
return torch.index_select(s, 0, i, out=out)

self.runAtenTest(
[torch.randn(3, 4),
torch.tensor([2, 1], dtype=torch.long)], test_fn)

def test_save(self):
xla_device = xm.xla_device()
x = torch.randn(5, device=xla_device)
33 changes: 18 additions & 15 deletions torch_xla/csrc/aten_xla_bridge.cpp
@@ -86,31 +86,24 @@ XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device) {
return xtensor ? *xtensor : XLATensor::Create(tensor, device);
}

std::vector<at::Tensor> XlaCreateTensorList(
const at::TensorList& tensors, const std::vector<bool>* writeable) {
std::vector<at::Tensor> XlaCreateTensorList(const at::TensorList& tensors) {
std::vector<at::Tensor> aten_xla_tensors(tensors.size());
std::vector<XLATensor> xla_tensors;
// We need to separate out the defined tensors first, as GetXlaTensor() doesn't
// work with undefined tensors.
std::vector<bool> defined_writeable;
std::vector<bool> to_translate(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
const at::Tensor& tensor = tensors[i];
if (!tensor.defined()) {
XLA_CHECK(writeable == nullptr || !(*writeable)[i])
<< "Trying to write to an undefined tensor";
} else if (tensor.device().is_cpu()) {
aten_xla_tensors[i] = tensor;
} else {
to_translate[i] = true;
xla_tensors.push_back(GetXlaTensorUnwrap(tensor));
if (writeable != nullptr) {
defined_writeable.push_back((*writeable)[i]);
if (tensor.defined()) {
if (tensor.device().is_cpu()) {
aten_xla_tensors[i] = tensor;
} else {
to_translate[i] = true;
xla_tensors.push_back(GetXlaTensorUnwrap(tensor));
}
}
}
auto defined_aten_xla_tensors = XLATensor::GetTensors(
&xla_tensors, writeable ? &defined_writeable : nullptr);
auto defined_aten_xla_tensors = XLATensor::GetTensors(&xla_tensors);
// Insert the undefined tensors back into the result, at their original
// positions.
for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
@@ -121,6 +114,16 @@ std::vector<at::Tensor> XlaCreateTensorList(
return aten_xla_tensors;
}

void XlaUpdateTensors(
tensorflow::gtl::ArraySlice<const at::Tensor> dest_xla_tensors,
tensorflow::gtl::ArraySlice<const at::Tensor> source_cpu_tensors,
tensorflow::gtl::ArraySlice<const size_t> indices) {
for (auto index : indices) {
XLATensor xtensor = GetXlaTensorUnwrap(dest_xla_tensors.at(index));
xtensor.UpdateFromTensor(source_cpu_tensors.at(index));
}
}

c10::optional<Device> GetXlaDevice(const at::Tensor& tensor) {
auto xtensor = TryGetXlaTensor(tensor);
if (!xtensor) {
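The removal of the writeable parameter makes the data flow explicit: XlaCreateTensorList() always hands back detached ATen copies, and a mutation only survives if it is pushed back through XlaUpdateTensors(). A minimal sketch of the round trip, assuming dst and src are ATen wrappers of XLA-device tensors:

  std::vector<at::Tensor> xla_side = {dst, src};          // XLA-device tensors
  auto cpu_side = bridge::XlaCreateTensorList(xla_side);  // detached ATen copies
  cpu_side[0].add_(cpu_side[1]);                          // mutates the copy only
  std::vector<size_t> updated = {0};                      // dst is the one written
  bridge::XlaUpdateTensors(xla_side, cpu_side, updated);  // write back to dst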
11 changes: 6 additions & 5 deletions torch_xla/csrc/aten_xla_bridge.h
@@ -32,11 +32,12 @@ XLATensor GetXlaTensorUnwrap(const at::Tensor& tensor);
XLATensor GetOrCreateXlaTensor(const at::Tensor& tensor, const Device& device);

// Creates a vector of at::Tensor objects extracted from a list of XLA tensors.
// If the writeable vector is not nullptr, it must be the same size as tensors,
// and the corresponding bool tells whether the ATEN tensor to be retrieved
// should be a writeable copy.
std::vector<at::Tensor> XlaCreateTensorList(const at::TensorList& tensors,
const std::vector<bool>* writeable);
std::vector<at::Tensor> XlaCreateTensorList(const at::TensorList& tensors);

void XlaUpdateTensors(
tensorflow::gtl::ArraySlice<const at::Tensor> dest_xla_tensors,
tensorflow::gtl::ArraySlice<const at::Tensor> source_cpu_tensors,
tensorflow::gtl::ArraySlice<const size_t> indices);

// Tries to extract the device out of the XLA tensor. Returns nullopt if the
// input is not an XLA tensor.
3 changes: 1 addition & 2 deletions torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
@@ -254,8 +254,7 @@ at::Tensor AtenXlaType::_copy_from(const at::Tensor& self,
// Do not mark the tensor creation as writeable, to avoid discarding the XLA
// tensor device context, but make a copy so the underlying data is not shared.
std::vector<at::Tensor> tensors = {self};
auto xla_tensors =
bridge::XlaCreateTensorList(tensors, /*writeable=*/nullptr);
auto xla_tensors = bridge::XlaCreateTensorList(tensors);
// Hack in an overwrite of a const tensor.
at::Tensor t = CopyTensor(xla_tensors.front(), dst.scalar_type());
const_cast<at::Tensor&>(dst).unsafeGetTensorImpl()->shallow_copy_from(
57 changes: 13 additions & 44 deletions torch_xla/csrc/tensor.cpp
@@ -743,23 +743,22 @@ void XLATensor::SetScalarType(
data()->logical_element_type = logical_element_type;
}

void XLATensor::MakeWriteableTensorDataSource() {
c10::optional<at::Tensor> tensor_data = CurrentTensorData();
XLA_CHECK(tensor_data);
if (data()->view != nullptr) {
ir::Value ir_value = GetIrValueForTensor(*tensor_data, GetDevice());
data()->view = UpdateView(data()->view, std::move(ir_value));
}
void XLATensor::SetTensor(at::Tensor tensor) {
SetTensorData(tensor);
data()->view = nullptr;
data()->xla_data = nullptr;
AssignIrValue(ir::Value());
data()->generation += 1;
}

void XLATensor::SetTensor(at::Tensor tensor) {
void XLATensor::UpdateFromTensor(at::Tensor tensor) {
SetTensorData(tensor);
data()->view = nullptr;
if (data()->view != nullptr) {
ir::Value ir_value = GetIrValueForTensor(tensor, GetDevice());
data()->view = UpdateView(data()->view, std::move(ir_value));
}
data()->xla_data = nullptr;
AssignIrValue(ir::Value());
data()->generation += 1;
}

std::vector<XLATensor> XLATensor::GetLiveTensors(const Device* device) {
@@ -790,7 +789,7 @@ std::vector<xla::ComputationClient::DataPtr> XLATensor::GatherTensorsXlaData(
}

std::vector<at::Tensor> XLATensor::GetTensorsOpByOp(
std::vector<XLATensor>* tensors, const std::vector<bool>* writeable) {
std::vector<XLATensor>* tensors) {
SyncTensorsConfig config;
config.force_xla_data = false;
SyncTensorCollection coll = CollectSyncTensors(*tensors, config);
@@ -827,33 +826,17 @@ std::vector<at::Tensor> XLATensor::GetTensorsOpByOp(
++literals_index;
}
}
if (writeable != nullptr) {
XLA_CHECK_EQ(tensors->size(), writeable->size());
for (size_t i = 0; i < tensors->size(); ++i) {
if ((*writeable)[i]) {
// If all we have for this tensor is ATEN tensor data, we need to set it
// before calling MakeWriteableTensorDataSource(), which will otherwise
// error out.
if (!(*tensors)[i].CurrentTensorData()) {
(*tensors)[i].SetTensorData(results[i]);
}
(*tensors)[i].MakeWriteableTensorDataSource();
}
}
}
return results;
}

std::vector<at::Tensor> XLATensor::GetTensors(
std::vector<XLATensor>* tensors, const std::vector<bool>* writeable) {
std::vector<at::Tensor> XLATensor::GetTensors(std::vector<XLATensor>* tensors) {
static const bool op_by_op =
xla::sys_util::GetEnvBool("GET_TENSORS_OPBYOP", false);
return op_by_op ? GetTensorsOpByOp(tensors, writeable)
: GetTensorsFused(tensors, writeable);
return op_by_op ? GetTensorsOpByOp(tensors) : GetTensorsFused(tensors);
}

std::vector<at::Tensor> XLATensor::GetTensorsFused(
std::vector<XLATensor>* tensors, const std::vector<bool>* writeable) {
std::vector<XLATensor>* tensors) {
SyncTensorsConfig config;
config.force_xla_data = false;
auto async = SyncTensorsGraphInternal(tensors, {}, config);
@@ -884,20 +867,6 @@ std::vector<at::Tensor> XLATensor::GetTensorsFused(
++literals_index;
}
}
if (writeable != nullptr) {
XLA_CHECK_EQ(tensors->size(), writeable->size());
for (size_t i = 0; i < tensors->size(); ++i) {
if ((*writeable)[i]) {
// If all we have for this tensor is ATEN tensor data, we need to set it
// before calling MakeWriteableTensorDataSource(), which will otherwise
// error out.
if (!(*tensors)[i].CurrentTensorData()) {
(*tensors)[i].SetTensorData(results[i]);
}
(*tensors)[i].MakeWriteableTensorDataSource();
}
}
}
return results;
}

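The split between SetTensor() and the new UpdateFromTensor() is what keeps aliases working: SetTensor() clears any existing view and re-roots the tensor on the new data, while UpdateFromTensor() threads the new value through UpdateView() so tensors sharing that view observe the write. A hedged sketch of the two call sites, assuming an XLATensor xtensor and an ATen tensor cpu_value:

  // Replace the value outright; aliases of xtensor stop tracking it.
  xtensor.SetTensor(cpu_value);

  // Update in place; if xtensor is a view, UpdateView() propagates the new
  // value to aliased tensors. This is the call bridge::XlaUpdateTensors()
  // makes for each updated index.
  xtensor.UpdateFromTensor(cpu_value);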
20 changes: 7 additions & 13 deletions torch_xla/csrc/tensor.h
@@ -50,6 +50,8 @@ class XLATensor {
// Assigns the tensor value to the XLA tensor.
void SetTensor(at::Tensor tensor);

void UpdateFromTensor(at::Tensor tensor);

at::ScalarType dtype() const;
xla::util::MaybeRef<xla::Shape> shape() const;

@@ -129,12 +131,9 @@ class XLATensor {
// the computation boundaries.
static void MarkStep(const Device* device);

// Retrieves the PyTorch tensors behind the XLA tensors. If the writeable
// vector is not nullptr, it must be the same size as tensors, and the
// corresponding bool tells whether the ATEN tensor to be retrieved should be
// a writeable copy. All the tensors must be on the same device.
static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors,
const std::vector<bool>* writeable);
// Retrieves the PyTorch tensors behind the XLA tensors. All the tensors must
// be on the same device.
static std::vector<at::Tensor> GetTensors(std::vector<XLATensor>* tensors);

// Operation which creates XLA tensors out of autograd variables by batching
// the requests to the computation servers.
@@ -1030,11 +1029,6 @@ class XLATensor {

void SetScalarType(c10::optional<at::ScalarType> logical_element_type);

// Discards all the XLA and IR data, making the ATEN tensor the only
// source for this XLA tensor. An error is generated if the XLA tensor does
// not have ATEN tensors data.
void MakeWriteableTensorDataSource();

// We build an XLA graph accumulating XLA operations, but at a given point we
// need to force a rendering; otherwise the graph can grow without control.
// Think:
@@ … @@

// Implementation of the GetTensors() API using the op-by-op executor.
static std::vector<at::Tensor> GetTensorsOpByOp(
std::vector<XLATensor>* tensors, const std::vector<bool>* writeable);
std::vector<XLATensor>* tensors);

static std::vector<at::Tensor> GetTensorsFused(
std::vector<XLATensor>* tensors, const std::vector<bool>* writeable);
std::vector<XLATensor>* tensors);

// Runs an asynchronous sync operation using the op-by-op executor.
using OpByOpAsync = xla::util::AsyncTask<xla::Status>;