
Cannot serialize NJTs using torch.save #129366

@vmoens


🐛 Describe the bug

`torch.save` fails when saving NJTs (nested tensors with `layout=torch.jagged`):

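```python
import torch

# Build a jagged nested tensor (NJT) from two 1-D tensors of different lengths
layout = torch.jagged
g0 = torch.zeros(1)
g1 = torch.zeros(2)
g = torch.nested.nested_tensor([g0, g1], layout=layout)
torch.save(g, "file.p")
```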

which raises

```
  File "<ipython-input-8-962d8302f59b>", line 1, in <module>
    torch.save(g, "file.p")
  File "/Users/vmoens/venv/rl/lib/python3.11/site-packages/torch/serialization.py", line 726, in save
    _save(
  File "/Users/vmoens/venv/rl/lib/python3.11/site-packages/torch/serialization.py", line 954, in _save
    pickler.dump(obj)
TypeError: cannot pickle 'PyCapsule' object
```

cc @mruberry @mikaylagawarecki @cpuhrsch @jbschlosser @bhosmer @drisspg @soulitzer

Versions

PyTorch nightly

Labels: enhancement, module: nestedtensor, module: serialization, triaged
