
Commit 36e2aca

model should be on xla device for Dynamo torchxla_trace_once backend (#4205)
* model should be on xla device for Dynamo torchxla_trace_once backend
* torch pin
* Add _clear_pending_irs API and tests
* Delete .torch_pin
1 parent 6a792dd commit 36e2aca
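
The gist of the change: with the torchxla_trace_once Dynamo backend, both the model and its inputs now have to live on the XLA device before the optimized call is made. Below is a minimal sketch of that pattern, assuming the import aliases used in the tests in this commit (torch._dynamo as dynamo, torch_xla.core.xla_model as xm); the exact dynamo import path is an assumption tied to the torch version pinned by this commit.

import torch
import torch._dynamo as dynamo        # import path assumed for the pinned torch build
import torch_xla.core.xla_model as xm
import torchvision

device = xm.xla_device()

# Model and data are both placed on the XLA device before tracing.
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
data = torch.randn(4, 3, 224, 224, device=device)

@dynamo.optimize('torchxla_trace_once')
def run_model(model, data):
  return model(data)

output = run_model(xla_resnet18, data)  # traced once, replayed from the cached graph afterwards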

File tree

5 files changed: +77 −12 lines

* test/dynamo/test_dynamo.py
* test/dynamo/test_dynamo_integrations_util.py
* torch_xla/csrc/init_python_bindings.cpp
* torch_xla/csrc/tensor.cpp
* torch_xla/csrc/tensor.h

test/dynamo/test_dynamo.py

Lines changed: 20 additions & 12 deletions
@@ -23,40 +23,48 @@ def fn_simple_dynamo(self, x, y):
     return self.fn_simple(x, y)
 
   @dynamo.optimize('torchxla_trace_once')
-  def resetnet_18_dynamo(self, model, data):
+  def run_model_with_dynamo(self, model, data):
     return model(data)
 
   def test_simple_model(self):
+    device = xm.xla_device()
     x = torch.tensor(100.0)
     y = torch.tensor(200.0)
+    xla_x = x.to(device)
+    xla_y = y.to(device)
     res_cpu = self.fn_simple(x, y)
-    res_xla_dynamo = self.fn_simple_dynamo(x, y)
+    res_xla_dynamo = self.fn_simple_dynamo(xla_x, xla_y)
     self.assertIn('xla::add', met.counter_names())
     torch.allclose(res_cpu, res_xla_dynamo.cpu())
     # verify that tracing is skipped in following runs
     met.clear_counters()
-    res_xla_dynamo_2 = self.fn_simple_dynamo(x, y)
+    res_xla_dynamo_2 = self.fn_simple_dynamo(xla_x, xla_y)
     self.assertNotIn('xla::add', met.counter_names())
     torch.allclose(res_cpu, res_xla_dynamo_2.cpu())
     # verify that dynamo can handle different inputs
-    res_xla_dynamo_3 = self.fn_simple_dynamo(x + y, y * 3)
+    res_xla_dynamo_3 = self.fn_simple_dynamo(xla_x + xla_y, xla_y * 3)
     res_cpu_3 = self.fn_simple(x + y, y * 3)
     torch.allclose(res_cpu, res_xla_dynamo_3.cpu())
 
   def test_resnet18(self):
+    device = xm.xla_device()
     batch_size = xu.getenv_as('BATCH_SIZE', int, defval=4)
     sample_count = xu.getenv_as('SAMPLE_COUNT', int, defval=10)
     loader = xu.SampleGenerator(
-        data=(torch.randn(batch_size, 3, 224,
-                          224), torch.zeros(batch_size, dtype=torch.int64)),
+        data=(torch.randn(batch_size, 3, 224, 224, device=device),
+              torch.zeros(batch_size, dtype=torch.int64, device=device)),
         sample_count=sample_count)
-    model = torchvision.models.resnet18()
-    model.eval()
+    resnet18 = torchvision.models.resnet18()
+    resnet18.eval()
+    xla_resnet18 = torchvision.models.resnet18().to(device)
+    xla_resnet18.eval()
     for data, _ in loader:
-      output = self.resetnet_18_dynamo(model, data)
-      torch.allclose(model(data), output.cpu())
-    self.assertEqual(met.metric_data('CompileTime')[0], 1)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 1)
+      output = self.run_model_with_dynamo(xla_resnet18, data)
+      torch.allclose(resnet18(data.cpu()), output.cpu())
+    # One graph for the initial input data materialization. Another graph for
+    # the real model code.
+    self.assertEqual(met.metric_data('CompileTime')[0], 2)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], sample_count + 2)
     self.assertEqual(
         met.metric_data('RunCachedGraphInputData')[0], sample_count)
     self.assertEqual(
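
The assertions above lean on the torch_xla.debug.metrics helpers. A small illustrative sketch of how those counters and metrics can be inspected around a run; the met alias matches the tests, and the printed values only show where the numbers in the assertions come from.

import torch_xla.debug.metrics as met

met.clear_counters()   # reset op counters such as 'xla::add'
met.clear_metrics()    # reset metrics such as CompileTime / ExecuteTime

# ... run the model under test here ...

print(met.counter_names())                 # contains 'xla::add' only if ops were retraced
print(met.metric_data('CompileTime')[0])   # number of graph compilations recorded
print(met.metric_data('ExecuteTime')[0])   # number of graph executions recorded
print(met.metrics_report())                # full human-readable report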

test/dynamo/test_dynamo_integrations_util.py

Lines changed: 19 additions & 0 deletions
@@ -89,6 +89,25 @@ def test_get_graph_hash(self):
     xla_out_2 = xla_dummy_model(xla_input)
     assert (hash == torch_xla._XLAC._get_graph_hash([xla_out_2]))
 
+  def test_clear_pending_irs(self):
+    xla_device = xm.xla_device()
+    xm.mark_step()
+    t1 = torch.randn(20, 5).to(xla_device)
+    t2 = torch.randn(20, 5).to(xla_device)
+    t3 = t2 + t1
+    t4 = t3 * t2
+    met.clear_metrics()
+    torch_xla._XLAC._xla_sync_multi([t4], devices=[], wait=True)
+    # only t4 is materialized
+    self.assertIn("aten::add", torch_xla._XLAC._get_xla_tensors_text([t3]))
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
+    torch_xla._XLAC._clear_pending_irs(str(xla_device))
+    self.assertNotIn("aten::add", torch_xla._XLAC._get_xla_tensors_text([t3]))
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
+    xm.mark_step()
+    # mark_step should not incur new execution
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 1)
+
   def test_run_cached_graph(self):
     xla_device = xm.xla_device()
     xla_input = torch.randn(64, 256, 14, 14).to(xla_device)
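
Condensing the new test into a standalone flow: a lazy tensor keeps its pending IR until something forces execution, and _clear_pending_irs drops whatever is still pending on the device. A sketch of that behaviour through the private torch_xla._XLAC module, assuming an XLA device is available (the shapes are arbitrary).

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
xm.mark_step()                        # start from a clean graph
a = torch.randn(20, 5).to(device)
b = torch.randn(20, 5).to(device)
c = a + b                             # pending IR: the add has not been executed yet

# The IR text for `c` still mentions the pending add.
assert 'aten::add' in torch_xla._XLAC._get_xla_tensors_text([c])

# Drop every pending IR on the device; `c` is now backed by a placeholder
# and must not be read afterwards.
torch_xla._XLAC._clear_pending_irs(str(device))
assert 'aten::add' not in torch_xla._XLAC._get_xla_tensors_text([c])

xm.mark_step()                        # nothing pending any more, so no new graph is executed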

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 13 additions & 0 deletions
@@ -400,6 +400,13 @@ std::string GetLiveTensorsReport(size_t nodes_threshold,
   return ss.str();
 }
 
+void ClearPendingIrs(const std::string& device_str) {
+  auto opt_device = GetOptionalDevice(device_str);
+  XLA_CHECK(opt_device);
+  auto tensors = XLATensor::GetLiveTensors(&opt_device.value());
+  XLATensor::ClearPendingIrs(tensors, opt_device.value());
+}
+
 std::ptrdiff_t GetTensorViewAliasId(const at::Tensor& tensor) {
   XLATensorPtr xtensor = bridge::GetXlaTensor(tensor);
   return xtensor->GetViewAliasId();
@@ -1621,6 +1628,12 @@ void InitXlaModuleBindings(py::module m) {
           return py::bytes(bin);
         });
 
+  m.def("_clear_pending_irs", [](const std::string& device) {
+    // Use with caution. Tensors whose IR was cleared will be replaced with a
+    // placeholder XLAData and SHOULD NOT be accessed afterwards.
+    ClearPendingIrs(device);
+  });
+
   m.def("_run_cached_graph",
         [](const std::string& hash_str,
           const std::vector<at::IValue>& graph_inputs)
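
One detail worth noting from the helper above: the binding does not take tensors as arguments; it resolves the device string and collects every live tensor on that device via XLATensor::GetLiveTensors before clearing. A hedged Python-side illustration of that scope (variable names are only for the example).

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
a = torch.randn(4, 4).to(device) + 1   # leaves a pending add on `a`
b = torch.randn(4, 4).to(device) * 2   # leaves a pending mul on `b`

# The clear applies to all live tensors on the device, not to a chosen subset,
# so both pending IRs above are dropped.
torch_xla._XLAC._clear_pending_irs(str(device))
# `a` and `b` are now backed by placeholder XLAData and should not be accessed.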

torch_xla/csrc/tensor.cpp

Lines changed: 22 additions & 0 deletions
@@ -1083,6 +1083,28 @@ XLATensor::ExecuteComputationWithBarrier(
       device);
 }
 
+void XLATensor::ClearPendingIrs(std::vector<XLATensorPtr> tensors,
+                                const torch::lazy::BackendDevice& device) {
+  std::unordered_set<int64_t> tensor_ids;
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    if (tensor_ids.insert(tensors[i]->GetUniqueId()).second &&
+        tensors[i]->CurrentXlaData() == nullptr) {
+      torch::lazy::Value ir_value = tensors[i]->CurrentIrValue();
+      if (ir_value) {
+        xla::Shape shape = MakeShapeWithDeviceLayout(
+            tensors[i]->shape(), static_cast<XlaDeviceType>(device.type()));
+        torch::lazy::BackendDataPtr xla_data =
+            WrapXlaData(xla::ComputationClient::Get()->CreateDataPlaceholder(
+                device.toString(), std::move(shape)));
+        tensors[i]->AssignIrValue(torch::lazy::Value());
+        tensors[i]->data()->xla_data = xla_data;
+        tensors[i]->data()->view = nullptr;
+        tensors[i]->data()->tensor_data = c10::nullopt;
+      }
+    }
+  }
+}
+
 std::vector<at::Tensor> XLATensor::GetTensorsOpByOp(
     std::vector<XLATensorPtr>* tensors) {
   SyncTensorsConfig config;

torch_xla/csrc/tensor.h

Lines changed: 3 additions & 0 deletions
@@ -1250,6 +1250,9 @@ class XLATensor : public c10::intrusive_ptr_target {
       c10::ArrayRef<torch::lazy::BackendDataPtr> arguments,
       const torch::lazy::BackendDevice& device);
 
+  static void ClearPendingIrs(std::vector<XLATensorPtr> tensors,
+                              const torch::lazy::BackendDevice& device);
+
  private:
   struct SyncTensorsConfig {
     // Whether we want to force XLA data on the target tensors (hence trimming

0 commit comments