
Commit c6fca12

POC of dynamicShape+functionalization integration (#4611)
* Support dynamism on "transpose", unsqueeze_copy, and view_copy_symint.

  Support dynamism on the "transpose" operation (#4606):
  * added a failing test_t_copy test
  * made transpose work
  * disabled the backward pass
  * fixed the linter

  Add a dynamic shape test for unsqueeze_copy (#4608):
  * added a failing test
  * test correctness

  Support dynamism on view_copy_symint (#4629):
  * disable failing XLA tests
  * re-enable dynamo tests (#4454)
  * turn on keep-going
  * re-enable more dynamic shape tests (#4558)
  * fix the size node ne op kind to be size_ne
  * re-enable test_nonzero_cast
  * temporarily disable a new test that started failing after the rebase
  * fix the linter
  * [Functionalization] Fix ScatterReduce (#4576)
    Summary: ScatterReduce::reduce_ uses an unsafe type, c10::string_view, to store the string. Replace it with std::string and re-enable all previously failing tests.
    Test Plan: TPU_LIBRARY_PATH=/home/ptxla/.local/lib/python3.8/site-packages/libtpu/libtpu.so PJRT_DEVICE=CPU test/cpp/build/test_ptxla --gtest_filter=AtenXlaTensorTest.TestScatterReduce*
  * Revert "Turn on keep going" (reverts commit 224bdf3) and turn it back on; repeated two more times during rebases
  * code compiles
  * wrote a new test, TestDynamicShapes.test_sizeMod, and it succeeds
  * the first view_copy_symint test, test_view_copy_symint_with_dyn_input_shape, passed
  * added two more tests
  * added a new failing test with a negative shape, then fixed all tests
  * clean up
  * fixed a small test failure
  * addressed PR comments
  * updated the tests for dynamic input and static input shape
  * do not support a dynamic input with a static input shape
  * fixed a build error
  * resolved a rebase conflict
  * added two more tests; they pass
  * ran the linter
  * removed the experimental test

  Also fixed a merge error and another merge error during the rebase.

  Co-authored-by: Wonjoo Lee <wonjoo@google.com>
  Co-authored-by: Jiewen Tan <jwtan@google.com>

* Track the type of the SymNode for XLASymNodeImpl::mod (#4732)
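In user-facing terms, here is a minimal sketch of what this change enables, distilled from the new tests in this commit; it assumes an XLA device is available and the experimental dynamic-shape flags the tests require (XLA_EXPERIMENTAL with nonzero dynamism) are set:

    import torch
    import torch_xla.core.xla_model as xm

    dev = xm.xla_device()

    # nonzero produces a tensor whose first dimension is bounded-dynamic.
    t1 = torch.tensor([[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]], device=dev)
    t2 = torch.nonzero(t1)         # shape [<=12, 2], real size [7, 2]

    # With this commit, transpose and unsqueeze propagate the dynamic
    # dimension instead of dropping it.
    t2_t = torch.t(t2)             # shape [2, <=12]
    t2_u = torch.unsqueeze(t2, 0)  # shape [1, <=12, 2]
    print(t2_t.shape, t2_u.shape)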
1 parent ea135c6 commit c6fca12

10 files changed (+288, -26 lines)

test/test_dynamic_shape_models.py

Lines changed: 7 additions & 3 deletions
@@ -1,4 +1,5 @@
 import argparse
+import os
 import sys
 
 parser = argparse.ArgumentParser(add_help=False)
@@ -12,6 +13,9 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
 
+# This enables us to run Python implementations of CompositeAutogradImplicit ops.
+# CompositeAutogradImplicit means we don't have an explicit backward formula for an op; instead, the op is composed of ops that do have backward formulas, and combining those formulas is equivalent to differentiating the op explicitly.
+pd = torch._C._EnablePythonDispatcher()
 xla_dev = xm.xla_device()
 
 
@@ -92,9 +96,6 @@ def test_forward_pass_dynamic_input_compile_once(self):
 met.metric_data('CompileTime')[0],
 'number of compilation should not increase.')
 
-@unittest.skip(
-"disable it due to https://github.com/pytorch/xla/pull/4322#issuecomment-1374312614."
-)
 def test_backward_pass_with_dynamic_input(self):
 num_features = 2
 num_test_samples = 5
@@ -156,5 +157,8 @@ def create_dynamic_test_data(self,
 
 
 if __name__ == '__main__':
+assert os.environ['XLA_EXPERIMENTAL'] != ''
 test = unittest.main(verbosity=FLAGS.verbosity, exit=False)
+# Disable the Python dispatcher flag.
+del pd
 sys.exit(0 if test.result.wasSuccessful() else 1)
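As a side note, the enable/delete pattern above could be written a bit more defensively; this is only a sketch of the intent, as a hypothetical standalone script rather than part of this diff:

    import os
    import torch

    # The guard routes CompositeAutogradImplicit ops through their Python
    # decompositions, which the dynamic-shape backward test relies on.
    assert os.environ.get('XLA_EXPERIMENTAL', '') != ''
    pd = torch._C._EnablePythonDispatcher()
    try:
        pass  # run the tests here
    finally:
        del pd  # drop the guard to disable the Python dispatcher again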

test/test_dynamic_shapes.py

Lines changed: 177 additions & 1 deletion
@@ -144,6 +144,22 @@ def test_empty_symint(self):
 self.assertIsInstance(t2.shape[1], int)
 self.assertEqual(t2.shape[1], 2)
 
+def test_t_copy(self):
+t1 = torch.tensor([[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]], device=dev)
+t2 = torch.nonzero(t1)
+# t2.shape=torch.Size([<=12, 2]) with real size [7, 2]
+self.assertEqual(str(t2.shape[0]), '<=12')
+self.assertEqual(str(t2.shape[1]), '2')
+
+t2_t = torch.t(t2)
+
+self.assertIsInstance(t2_t.shape[0], int)
+self.assertIsInstance(t2_t.shape[1], torch.SymInt)
+self.assertEqual(str(t2_t.shape[0]), '2')
+self.assertEqual(str(t2_t.shape[1]), '<=12')
+self.assertEqual(t2_t.shape[0], 2)
+self.assertEqual(t2_t.shape[1], 7)
+
 def test_nonzero_shape(self):
 x = torch.tensor((0, 1, 2, 0, 3, 4), device=xm.xla_device())
 x_dim0_shape = torch_xla._XLAC._get_xla_tensor_dimension_size(
@@ -175,13 +191,173 @@ def test_expand_symint_correctness(self):
 t2 = torch.zeros([size1, size2], device=dev)
 t2[3][0] = 1
 t2[3][1] = 1
-# t2 has size [<=10, 2]
+# t3 has size [<=10, 2]
 t3 = torch.nonzero(t2)
 t4 = torch.ones([size1, size2], device=dev)
 expand_out_xla = t4.expand(t3.shape[0], size1, size2)
 self.assertEqual(t3.shape[0], 2)
 self.assertEqual(expand_out_aten.cpu(), expand_out_xla.cpu())
 
+def test_unsqueeze_copy_dynamism(self):
+t1 = torch.tensor([[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]], device=dev)
+t2 = torch.nonzero(t1)
+# t2.shape=torch.Size([<=12, 2]) with real size [7, 2]
+
+t2_unsqueeze = torch.unsqueeze(t2, 0)
+
+self.assertEqual(len(t2_unsqueeze.size()), 3)
+self.assertIsInstance(t2_unsqueeze.shape[0], int)
+self.assertIsInstance(t2_unsqueeze.shape[1], torch.SymInt)
+self.assertIsInstance(t2_unsqueeze.shape[2], int)
+self.assertEqual(str(t2_unsqueeze.shape[0]), '1')
+self.assertEqual(str(t2_unsqueeze.shape[1]), '<=12')
+self.assertEqual(str(t2_unsqueeze.shape[2]), '2')
+self.assertEqual(t2_unsqueeze.shape[0], 1)
+self.assertEqual(t2_unsqueeze.shape[1], 7)
+self.assertEqual(t2_unsqueeze.shape[2], 2)
+
+# test correctness
+t3 = torch.tensor([[1, 0, 0, 5, 0, 6], [1, 3, 2, 0, 0, 1]])
+t4 = torch.nonzero(t3)
+t4_unsqueeze = torch.unsqueeze(t4, 0)
+self.assertEqual(t2_unsqueeze.cpu(), t4_unsqueeze.cpu())
+
+def test_view_copy_symint_with_static_input_dyn_input_shape(self):
+# If the input tensor and shape are “statically” incompatible, a compilation error is raised.
+t1 = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
+# t2.shape=torch.Size([<=6, 1]) with real size [4, 1]
+# t2 = [[0], [2], [3], [5]]
+t2 = torch.nonzero(t1)
+t3 = torch.randint(10, (2, 2), device=dev)
+self.assertRaises(RuntimeError, lambda: t3.view(t2.shape[0]))
+
+# If their “dynamic” values are incompatible, a RuntimeError is raised.
+t4 = torch.randint(10, (2, 3), device=dev)
+self.assertRaises(RuntimeError, lambda: t4.view(t2.shape[0]))
+
+# verify if dynamism is propagated correctly.
+t5 = torch.tensor([1, 1, 3, 5, 1, 6], device=dev)
+t6 = torch.nonzero(t5)
+t7 = torch.ones((2, 3), device=dev)
+t8 = t7.view(t6.shape[0])
+self.assertIsInstance(t8.shape[0], torch.SymInt)
+self.assertEqual(str(t8.shape[0]), '<=6')
+self.assertEqual(t8.shape[0], 6)
+
+# verify correctness.
+t5_aten = torch.tensor([1, 1, 3, 5, 1, 6])
+t6_aten = torch.nonzero(t5_aten)
+t7_aten = torch.ones((2, 3))
+t8_aten = t7_aten.view(t6_aten.shape[0])
+self.assertEqual(t8.cpu(), t8_aten.cpu())
+
+def test_view_copy_symint_with_static_input_dyn_input_shape2(self):
+# If the input tensor and shape are “statically” incompatible, a compilation error is raised.
+t1 = torch.tensor([[1, 0, 3]], device=dev)
+# t2.shape=torch.Size([<=3, 2]) with real size [2, 2]
+# t2 = [[0, 0], [0, 2]]
+t2 = torch.nonzero(t1)
+t3 = torch.ones((2, 4), device=dev)
+# Should fail in pytorch utils.infer_size
+self.assertRaises(RuntimeError, lambda: t3.view(t2.shape))
+
+# If their “dynamic” values are incompatible, a RuntimeError is raised.
+t4 = torch.ones((2, 3), device=dev)
+# Also fails in pytorch utils.infer_size
+self.assertRaises(RuntimeError, lambda: t4.view(t2.shape))
+
+# verify if dynamism is propagated correctly.
+t5 = torch.tensor([[1, 1, 3]], device=dev)
+t6 = torch.nonzero(t5)
+# t6.shape=[<=3, 2] with real size [3, 2]
+t7 = torch.ones((2, 3), device=dev)
+t8 = t7.view(t6.shape)
+self.assertIsInstance(t8.shape[0], torch.SymInt)
+self.assertEqual(str(t8.shape[0]), '<=3')
+self.assertEqual(t8.shape[0], 3)
+self.assertIsInstance(t8.shape[1], int)
+self.assertEqual(str(t8.shape[1]), '2')
+self.assertEqual(t8.shape[1], 2)
+
+# verify correctness.
+t5_aten = torch.tensor([[1, 1, 3]])
+t6_aten = torch.nonzero(t5_aten)
+t7_aten = torch.ones((2, 3))
+t8_aten = t7_aten.view(t6_aten.shape)
+self.assertEqual(t8.cpu(), t8_aten.cpu())
+
+def test_view_copy_symint_with_dyn_input_static_input_shape(self):
+# If the input tensor is dynamic and input shape is static,
+# it should fail because we will not likely have this case
+# in reality so we don't support this feature.
+t1 = torch.tensor([1, 1, 3, 5, 1, 6], device=dev)
+# t2.shape=torch.Size([<=6, 1]) with real size [6, 1]
+t2 = torch.nonzero(t1)
+self.assertRaises(RuntimeError, lambda: t2.view(2, 3))
+
+def test_view_copy_symint_with_dyn_input_dyn_input_shape(self):
+# If the input tensor and shape are “statically” incompatible, a compilation error is raised.
+t1 = torch.tensor([1, 0, 3, 5, 0, 6], device=dev)
+# t2.shape=torch.Size([<=6, 1]) with real size [4, 1]
+# t2 = [[0], [2], [3], [5]]
+t2 = torch.nonzero(t1)
+t3 = torch.tensor([1, 0, 3, 5, 0, 6, 7], device=dev)
+# t4.shape=torch.Size([<=7, 1]) with real size [5, 1]
+t4 = torch.nonzero(t3)
+self.assertRaises(RuntimeError, lambda: t2.view(t4.shape[0]))
+
+# If their “dynamic” values are incompatible, a RuntimeError is raised.
+t5 = torch.tensor([1, 2, 3, 4, 5, 6, 0], device=dev)
+# t6.shape=torch.Size([<=7, 1]) with real size [6, 1]
+t6 = torch.nonzero(t5)
+# statically compatible but dynamically incompatible.
+# It will fail in pytorch layer.
+self.assertRaises(RuntimeError, lambda: t6.view(t4.shape[0]))
+
+# verify if dynamism is propagated correctly.
+t7 = torch.tensor([1, 0, 3, 5, 0, 6, 7], device=dev)
+t8 = torch.nonzero(t7)
+# t8.shape=torch.Size([<=7, 1]) with real size [5, 1]
+t9 = t8.view(t4.shape[0])
+self.assertIsInstance(t9.shape[0], torch.SymInt)
+self.assertEqual(str(t9.shape[0]), '<=7')
+self.assertEqual(t9.shape[0], 5)
+
+# verify correctness.
+t7_aten = torch.tensor([1, 0, 3, 5, 0, 6, 7])
+t8_aten = torch.nonzero(t7_aten)
+# t8_aten.size=[5, 1]
+t3_aten = torch.tensor([1, 0, 3, 5, 0, 6, 7])
+t4_aten = torch.nonzero(t3_aten)
+# t4_aten.size=[5, 1]
+t9_aten = t8_aten.view(t4_aten.shape[0])
+self.assertEqual(t9.cpu(), t9_aten.cpu())
+
+def test_sizeMod(self):
+met.clear_all()
+
+size1 = 5
+size2 = 2
+t1 = torch.zeros([size1, size2], device=dev)
+t1[3][0] = 1
+# t2 has size [<=10, 2] with real size [1, 2]
+t2 = torch.nonzero(t1)
+# Create a SizeMod IR node.
+# t2.shape[1] generates a SizeConstant node.
+dyn_size = t2.shape[0] % t2.shape[1]
+self.assertGreater(met.counter_value("xla::size_mod"), 0)
+# Exercises SizeMod::getDynamicValue.
+dynamic_size = int(dyn_size)
+self.assertEqual(dynamic_size, 1)
+self.assertEqual(str(dyn_size), '<=0')
+
+# t3 has size [<=10, 2] with real size [1, 2]
+t3 = torch.nonzero(t1)
+dyn_size = t2.shape[0] % t3.shape[0]
+dynamic_size = int(dyn_size)
+self.assertEqual(dynamic_size, 0)
+self.assertEqual(str(dyn_size), '<=0')
+
 def test_sizeGe(self):
 met.clear_all()

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 12 additions & 5 deletions
@@ -3186,12 +3186,19 @@ std::tuple<at::Tensor, at::Tensor> XLANativeFunctions::var_mean(
 }
 
 at::Tensor XLANativeFunctions::view_copy_symint(const at::Tensor& self,
-at::SymIntArrayRef sym_size) {
-// TODO: support symbolic sizes
-auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
+at::SymIntArrayRef shape) {
 TORCH_LAZY_FN_COUNTER("xla::");
-return bridge::AtenFromXlaTensor(tensor_methods::view(
-bridge::GetXlaTensor(self), XlaHelpers::I64List(size)));
+c10::optional<at::IntArrayRef> int_shape = c10::asIntArrayRefSlowOpt(shape);
+bool input_shape_static = int_shape.has_value();
+XLATensorPtr xla_input = bridge::GetXlaTensor(self);
+bool input_has_dyn_shape = xla_input->shape().get().is_dynamic();
+
+XLA_CHECK(!(input_has_dyn_shape && input_shape_static))
+<< "This view op has dynamic input tensor but static input shape. This "
+"behavior is currently unsupported; if the user believes this must be "
+"supported, please file a feature request against PyTorch/XLA.";
+return bridge::AtenFromXlaTensor(
+tensor_methods::view_symint(xla_input, shape));
 }
 
 at::Tensor XLANativeFunctions::where(const at::Tensor& condition,
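In user-facing terms, the new check amounts to the following; this is only a sketch mirroring the tests above, with dev assumed to be an XLA device and the dynamic-shape experiments enabled:

    # Static input tensor viewed to a dynamic target shape: supported.
    s = torch.nonzero(torch.tensor([1, 1, 3, 5, 1, 6], device=dev))  # shape [<=6, 1]
    out = torch.ones((2, 3), device=dev).view(s.shape[0])            # str(out.shape[0]) == '<=6'

    # Dynamic input tensor viewed to a static target shape: rejected by the
    # XLA_CHECK above, surfacing as a RuntimeError in Python.
    s.view(2, 3)  # raises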

torch_xla/csrc/helpers.cpp

Lines changed: 8 additions & 8 deletions
@@ -254,27 +254,27 @@ xla::XlaOp XlaHelpers::ReshapeToRank(xla::XlaOp input, int64_t expected_rank,
 absl::optional<XlaHelpers::DynamicReshapeInfo>
 XlaHelpers::GetDynamicReshapeInfo(const xla::Shape& input_shape,
 absl::Span<const int64_t> output_sizes) {
-int64_t input_dynamic_dimension = GetDynamicDimension(input_shape);
-if (input_dynamic_dimension < 0) {
+int64_t input_dyndim_idx = GetDynamicDimension(input_shape);
+if (input_dyndim_idx < 0) {
 return absl::nullopt;
 }
 DynamicReshapeInfo info;
 info.output_shape =
 xla::ShapeUtil::MakeShape(input_shape.element_type(), output_sizes);
 if (info.output_shape.rank() > 0) {
-int64_t size_at_dyndim = 1;
-for (int64_t i = 0; i <= input_dynamic_dimension; ++i) {
-size_at_dyndim *= input_shape.dimensions(i);
+int64_t size_prod_until_dyndim = 1;
+for (int64_t i = 0; i <= input_dyndim_idx; ++i) {
+size_prod_until_dyndim *= input_shape.dimensions(i);
 }
 int64_t dynamic_dimension = -1;
 int64_t out_size = 1;
 for (int64_t i = 0; i < output_sizes.size(); ++i) {
-XLA_CHECK_LE(out_size, size_at_dyndim / input_shape.dimensions(
-input_dynamic_dimension))
+XLA_CHECK_LE(out_size, size_prod_until_dyndim /
+input_shape.dimensions(input_dyndim_idx))
 << "Unable to map dynamic dimension of shape " << input_shape
 << " to output sizes (" << absl::StrJoin(output_sizes, ", ") << ")";
 out_size *= output_sizes[i];
-if (out_size >= size_at_dyndim) {
+if (out_size >= size_prod_until_dyndim) {
 dynamic_dimension = i;
 break;
 }

torch_xla/csrc/ops/dynamic_ir.cpp

Lines changed: 37 additions & 0 deletions
@@ -290,4 +290,41 @@ XlaOpVector SizeDiv::Lower(LoweringContext* loctx) const {
 return ReturnOp(xla::Div(input1, input2), loctx);
 }
 
+SizeMod::SizeMod(torch::lazy::Value a, torch::lazy::Value b)
+: XlaNode(
+torch::lazy::OpKind{c10::Symbol::fromQualString("aten::size_mod")},
+{a, b},
+xla::ShapeUtil::MakeShape(GetShapeDimensionType(/*device=*/nullptr),
+{}),
+1) {
+const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+// SizeMod can only be performed between two DimensionNodes.
+XLA_CHECK(dim_node_0);
+XLA_CHECK(dim_node_1);
+// We don't need to hash upper_bound_ because it is computed from the input
+// shapes, and the input nodes already hash their shapes.
+XLA_CHECK(dim_node_1->getStaticValue() != 0)
+<< "Can't mod a dimension by zero";
+upper_bound_ = dim_node_0->getStaticValue() % dim_node_1->getStaticValue();
+};
+
+int64_t SizeMod::getDynamicValue() const {
+const torch::lazy::DimensionNode* dim_node_0 = DimCast(operand(0));
+const torch::lazy::DimensionNode* dim_node_1 = DimCast(operand(1));
+XLA_CHECK(dim_node_0);
+XLA_CHECK(dim_node_1);
+XLA_CHECK(dim_node_1->getDynamicValue() != 0)
+<< "Can't mod a dynamic dimension by zero";
+return dim_node_0->getDynamicValue() % dim_node_1->getDynamicValue();
+}
+
+std::string SizeMod::ToString() const { return "aten::size_mod"; }
+
+XlaOpVector SizeMod::Lower(LoweringContext* loctx) const {
+auto input1 = loctx->GetOutputOp(operand(0));
+auto input2 = loctx->GetOutputOp(operand(1));
+return ReturnOp(xla::Rem(input1, input2), loctx);
+}
+
 } // namespace torch_xla
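From Python, the new node is reached through the % operator on SymInts, as exercised by the new test_sizeMod test; a condensed sketch (assuming dev is an XLA device with nonzero dynamism enabled):

    t1 = torch.zeros([5, 2], device=dev)
    t1[3][0] = 1
    t2 = torch.nonzero(t1)           # shape [<=10, 2], real size [1, 2]

    dyn = t2.shape[0] % t2.shape[1]  # creates a SizeMod IR node, lowered to xla::Rem
    print(int(dyn))                  # 1, via SizeMod::getDynamicValue (1 % 2)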

torch_xla/csrc/ops/dynamic_ir.h

Lines changed: 13 additions & 0 deletions
@@ -165,6 +165,19 @@ class SizeDiv : public XlaNode, public torch::lazy::DimensionNode {
 int64_t upper_bound_;
 };
 
+class SizeMod : public XlaNode, public torch::lazy::DimensionNode {
+public:
+SizeMod(torch::lazy::Value a, torch::lazy::Value b);
+int64_t getDynamicValue() const override;
+int64_t getStaticValue() const override { return upper_bound_; }
+bool isSymbolic() const override { return true; }
+std::string ToString() const override;
+virtual XlaOpVector Lower(LoweringContext* loctx) const override;
+
+private:
+int64_t upper_bound_;
+};
+
 const torch::lazy::DimensionNode* DimCast(torch::lazy::Output output);
 const torch::lazy::DimensionNode* DimCast(const torch::lazy::Node* node);
 const std::shared_ptr<torch::lazy::DimensionNode> DimCast(

torch_xla/csrc/ops/permute.cpp

Lines changed: 6 additions & 3 deletions
@@ -44,9 +44,12 @@ std::string Permute::ToString() const {
 
 xla::Shape Permute::MakePermuteShape(const xla::Shape& source_shape,
 absl::Span<const int64_t> permutation) {
-return XlaHelpers::GetDynamicReshape(
-source_shape,
-XlaHelpers::Permute(permutation, source_shape.dimensions()));
+auto output_static_dims =
+XlaHelpers::Permute(permutation, source_shape.dimensions());
+auto output_dyn_dims =
+XlaHelpers::Permute(permutation, source_shape.dynamic_dimensions());
+return xla::ShapeUtil::MakeShape(source_shape.element_type(),
+output_static_dims, output_dyn_dims);
 }
 
 } // namespace torch_xla

torch_xla/csrc/tensor.cpp

Lines changed: 8 additions & 2 deletions
@@ -669,8 +669,14 @@ c10::SymNode XLASymNodeImpl::floordiv(const c10::SymNode& other) {
 }
 
 c10::SymNode XLASymNodeImpl::mod(const c10::SymNode& other) {
-XLA_CHECK(false) << "XLASymNodeImpl::" << __FUNCTION__
-<< " has not been implemented.";
+TORCH_LAZY_FN_COUNTER("xla::size_");
+torch_xla::XLASymNodeImpl* p_other =
+dynamic_cast<XLASymNodeImpl*>(other.get());
+XLA_CHECK(is_int()) << __FUNCTION__ << " with non-int NYI";
+XLA_CHECK(p_other->is_int()) << __FUNCTION__ << " with non-int NYI";
+torch::lazy::NodePtr n_mod =
+torch::lazy::MakeNode<SizeMod>(node(), p_other->node());
+return c10::make_intrusive<XLASymNodeImpl>(n_mod, PyType::INT);
 }
 
 c10::SymNode XLASymNodeImpl::eq(const c10::SymNode& other) {
