
[NestedTensor] Graph breaks with SDPA + NT constructor #126472

Open
davidberard98 opened this issue May 16, 2024 · 2 comments
Labels
module: nestedtensor (NestedTensor tag, see issue #25032) · triaged (This issue has been looked at by a team member and triaged and prioritized into an appropriate module)

Comments

davidberard98 (Contributor) commented May 16, 2024

🐛 Describe the bug

When we use SDPA, we need max_seqlen and min_seqlen. Getting max/min_seqlen normally requires a .item() call (which usually causes a graph break, I think?).
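
For context, a minimal sketch of why deriving these in-graph is problematic (illustrative, not part of the repro below; it assumes offsets is the usual 1-D cumulative-boundary tensor):

seq_lens = offsets.diff()           # per-sequence lengths
max_seqlen = seq_lens.max().item()  # data-dependent .item() -> graph break
min_seqlen = seq_lens.min().item()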

So this issue focuses on removing graph breaks where:

  • we construct the NT in the graph,
  • we use SDPA, and
  • we pass in max/min_seqlen manually.

General repro: the approach is to call nested_view_from_values_offsets_lengths with max_seqlen and min_seqlen passed in:

import torch
from torch.nested._internal.nested_tensor import ViewNestedFromBuffer, nested_view_from_values_offsets_lengths
import torch._dynamo

# note: for testing with ViewNestedFromBuffer, which I wasn't able to get working
torch._dynamo.allow_in_graph(ViewNestedFromBuffer)


def convert_jagged_to_nested_tensor(
    values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
    # metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
    # nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)
    return nt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_linear = torch.nn.Linear(8, 32*4)
        self.k_linear = torch.nn.Linear(8, 32*4)
        self.v_linear = torch.nn.Linear(8, 32*4)

    def forward(self, values, offsets):
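        # construct the NJT in-graph, passing min/max_seqlen manually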
        nt = convert_jagged_to_nested_tensor(values, offsets, 5)
        q, k, v = [mod(nt) for mod in (self.q_linear, self.k_linear, self.v_linear)]
        q, k, v = [
            x.view(4, -1, 4, 32).transpose(1, 2)
            for x in (q, k, v)
        ]
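        # SDPA over the NJT consumes the manually-provided max/min_seqlen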
        sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
        return sdpa_out.values()


values = torch.randn(10, 8, device='cuda')
offsets = torch.tensor([0, 1, 3, 6, 10], device='cuda')
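# offsets [0, 1, 3, 6, 10] describe 4 sequences of lengths 1, 2, 3, 4 (10 rows total)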

# optionally, dynamic=False; but I wasn't able to get that to work either
torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)

Failure 1: With #122836 (rebased onto 7f1d5ab)

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
    self._return(inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _return
    self.output.compile_subgraph(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1083, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1300, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1372, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/__init__.py", line 1747, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
    return aot_autograd(
  File "/home/dberard/local/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 962, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 554, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 566, in inner
    dynamic_dims = {
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 567, in <setcomp>
    i for i, s in enumerate(o.shape) if not is_concrete_int(s)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 221, in is_concrete_int
    if isinstance(a.node.expr, sympy.core.numbers.Integer):
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AttributeError: 'torch._C._SymNode' object has no attribute 'expr'

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Failure 2: Based on that failure, I tried with @soulitzer's PR #124624 patched on top:

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 737, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 743, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2447, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2563, in inline_call_
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1306, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/torch.py", line 754, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1585, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1708, in wrap_fx_proxy_cls
    set_example_value(proxy.node, example_value)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 1166, in set_example_value
    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 604, in compute_unbacked_bindings
    symbol_to_path = free_unbacked_symbols_with_path(example_value, ())
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 541, in free_unbacked_symbols_with_path
    real=a.real_tensor.size() if a.real_tensor is not None else None
torch._dynamo.exc.InternalTorchDynamoError: 'NestedTensor' object has no attribute 'real_tensor'

from user code:
   File "/home/dberard/local/scripts/nt_2.py", line 25, in forward
    nt = convert_jagged_to_nested_tensor(values, offsets, 5)
  File "/home/dberard/local/scripts/nt_2.py", line 13, in convert_jagged_to_nested_tensor
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Failure 3: Based on this, I tried a quick patch (this change):

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1253, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 293, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 743, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2447, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2563, in inline_call_
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 494, in wrapper
    return inner_fn(self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1306, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 737, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/torch.py", line 754, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1585, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/variables/builder.py", line 1708, in wrap_fx_proxy_cls
    set_example_value(proxy.node, example_value)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 1166, in set_example_value
    if symbol_to_path := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 605, in compute_unbacked_bindings
    assert not pending, (
AssertionError: pending {u0} not in NestedTensor(size=(4, u1, 8), offsets=FakeTensor(..., device='cuda:0', size=(5,), dtype=torch.int64), contiguous=True) ((8*u1, 8, 1), 0)

from user code:
   File "/home/dberard/local/scripts/nt_2.py", line 25, in forward
    nt = convert_jagged_to_nested_tensor(values, offsets, 5)
  File "/home/dberard/local/scripts/nt_2.py", line 13, in convert_jagged_to_nested_tensor
    nt = nested_view_from_values_offsets_lengths(values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I haven't gotten around to investigating this yet. Maybe #126198 is related (just based on unbacked symint <-> NT issues).

Failure 4: One other attempt: I figured I'd try #124803 to see if it would fix the issue without hitting unbacked symint problems, but it runs into a different failure where we get multiple NestedInts for the same dimension. (So we should probably just go with #124624 and figure out what the unbacked symint issue is about.)

Traceback (most recent call last):
  File "/home/dberard/local/scripts/nt_2.py", line 38, in <module>
    torch.compile(MyModule().cuda(), fullgraph=True)(values, offsets)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/eval_frame.py", line 420, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 986, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 381, in _convert_frame_assert
    return _compile(
  File "/home/dberard/local/pytorch/torch/_utils_internal.py", line 70, in wrapper_function
    return function(*args, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 708, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 543, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/dberard/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1167, in transform_code_object
    transformations(instructions, code_options)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 172, in _fn
    return fn(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/convert_frame.py", line 490, in transform
    tracer.run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2234, in run
    super().run()
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 884, in run
    while self.step():
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 799, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2423, in RETURN_VALUE
    self._return(inst)
  File "/home/dberard/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2408, in _return
    self.output.compile_subgraph(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1083, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1300, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1391, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/dberard/local/pytorch/torch/_dynamo/output_graph.py", line 1372, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/_dynamo/repro/after_dynamo.py", line 127, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/dberard/local/pytorch/torch/__init__.py", line 1747, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/dberard/local/pytorch/torch/_inductor/compile_fx.py", line 1478, in compile_fx
    return aot_autograd(
  File "/home/dberard/local/pytorch/torch/_dynamo/backends/common.py", line 65, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 962, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/dberard/local/pytorch/torch/_dynamo/utils.py", line 273, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/dberard/local/pytorch/torch/_functorch/aot_autograd.py", line 554, in create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 692, in inner
    fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)
  File "/home/dberard/local/pytorch/torch/utils/_pytree.py", line 943, in tree_map
    return treespec.unflatten(map(func, *flat_args))
  File "/home/dberard/local/pytorch/torch/utils/_pytree.py", line 782, in unflatten
    leaves = list(leaves)
  File "/home/dberard/local/pytorch/torch/_functorch/_aot_autograd/functional_utils.py", line 59, in from_fun
    out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
  File "/home/dberard/local/pytorch/torch/utils/_python_dispatch.py", line 322, in transform_subclass
    assert sub.shape == outer_size, (
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Expected return value from <class 'torch.nested._internal.nested_tensor.NestedTensor'>__tensor_unflatten__() to have shape equal to torch.Size([4, j2, 8]), but got: torch.Size([4, j3, 8])

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

Versions

Described above; all of these were built at commit 7f1d5ab for H100.

cc @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer

davidberard98 added the module: nestedtensor label May 17, 2024
jbschlosser (Contributor) commented May 17, 2024

> I haven't gotten around to investigating this yet. Maybe #126198 is related (just based on unbacked symint <-> NT issues).

I ran into a similar error, which prompted the fix in the linked PR.

FWIW I was able to get your repro working without graph breaks using a combination of:

mikaylagawarecki added the triaged label May 20, 2024
davidberard98 (Contributor, Author) commented

Note: the is_concrete_int() issue seems somewhat hard to reproduce, but it appears to happen if the NJT that generated the values is still on the stack in Python when the return happens.

This doesn't always happen, but the easy way to force a repro is to just return the NJT.
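
For reference, a sketch of that forcing variant (the repro's forward() from above with only the return changed):

    def forward(self, values, offsets):
        nt = convert_jagged_to_nested_tensor(values, offsets, 5)
        q, k, v = [mod(nt) for mod in (self.q_linear, self.k_linear, self.v_linear)]
        q, k, v = [x.view(4, -1, 4, 32).transpose(1, 2) for x in (q, k, v)]
        sdpa_out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=False)
        return sdpa_out  # return the NJT itself instead of sdpa_out.values()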

jbschlosser added a commit that referenced this issue Jul 15, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).
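
As a simplified illustration of this mapping (hypothetical names; not the actual registry code):

```python
import weakref

# offsets / lengths tensor object -> nested int (e.g. "j0"), keyed by identity
_registry = weakref.WeakKeyDictionary()
_counter = 0

def nested_int_for(offsets):
    global _counter
    if offsets not in _registry:
        _registry[offsets] = f"j{_counter}"  # mint a new nested int
        _counter += 1                        # next up: j{_counter}
    return _registry[offsets]
```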

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager (see the sketch after this list).
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.
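
The `freeze_nested_int_counter()` helper mentioned above might look roughly like this (a sketch; the counter's module-level name is an assumption):

```python
import contextlib
import torch.nested._internal.nested_tensor as nt_mod

@contextlib.contextmanager
def freeze_nested_int_counter():
    saved = nt_mod._tensor_id_counter  # assumed name of the global counter
    try:
        yield
    finally:
        # restore so the second metadata-collection run mints the same ids
        nt_mod._tensor_id_counter = saved
```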

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 15, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 16, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 16, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 17, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 17, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 18, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 18, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jbschlosser added a commit that referenced this issue Jul 18, 2024
Based on #130292; playing around with different designs, etc.

**Background:** Nested ints uniquely define ragged structure and are associated with metadata tensors `offsets` / `lengths`. This association is maintained within the nested int registry, which maps `offsets` / `lengths` -> `nested int`. If a new `offsets` / `lengths` tensor object is seen during NJT construction, a new nested int (e.g. `j0`) is created to associate with it, and the nested int counter is incremented (next up: `j1`).

**Core problem:** During tracing, various fake tensors and functional tensors are created that *conceptually* refer to the same object as their real source. For fake / functional `offsets` / `lengths`, these should be considered to have the same ragged structure as the real `offsets` / `lengths`, and thus be associated with the same nested int. Further, whenever we're dealing with fake tensors, we should be dealing with symbolic nested ints. These are not done consistently today, leading to all sorts of problems (#126472, #130272, [error](https://gist.github.com/davidberard98/877a52f6ea57025cc122d64361e598da)). To avoid graph breaks or hard errors during in-graph NJT construction, this needs to be cleaned up.

**This PR:**
* During dense tensor fake-ification, mirrors `offsets -> nested int` relationships via `fake offsets -> symbolic nested int` entries within the nested int registry. Details:
    * `describe_tensor()` is updated to store any nested int associated with the real dense tensor as part of its `MetaTensorDesc`.
    * During dense tensor fake-ification, if the tensor's desc has a nested int, it is symbolicized within the relevant ShapeEnv.
    * An `output fake tensor -> symbolic nested int` entry is added to the nested int registry.
        * Note: this is done in two places right now: the usual end of dense tensor fake-ification and within `empty_create_subclass()` (which I think should call `meta_tensor()` recursively instead of what it does now).
* Updates AOTAutograd in a couple places to achieve the same conceptual grouping of `offsets` / `lengths` with associated fake intermediates:
    * During metadata collection (which is run twice), this PR emulates idempotence by restoring the nested int counter state after the fw graph run. Any `offsets` / `lengths` -> `nested int` relationships (and their fake analogues) should have the same nested int values across fw graph runs. This is accomplished via a (somewhat hacky) `freeze_nested_int_counter()` context manager.
    * As per `Note [AOT Autograd: Views to avoid tangents aliasing inputs]`, `view_avoid_dupes_with_primals()` purposefully creates an aliased output via `out = t.view(t.shape)`. Conceptually, `out` here should be associated with the same ragged structure as `t`, so this PR updates the nested int registry to maintain this relationship.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]