
✨[Feature] Add pickle_protocol argument for torch_tensorrt.save #3294

Closed
@fortminors

Is your feature request related to a problem? Please describe.
I am trying to save an optimized CLIP model to disk, but it is too large (around 8 GB, if I remember correctly) for pickle to handle with a protocol below 4, which is what is used by default. Saving fails with OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher
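The 4 GiB ceiling comes from the pickle format itself: the string and bytes opcodes available before protocol 4 store their length in a 4-byte field, and protocol 4 (PEP 3154) added variants with an 8-byte length. A minimal sketch, using only the standard library's `pickletools` and `pickle`, that lists those opcodes and reads back which protocol a stream was written with (`opcodes_added_in` and `stream_protocol` are illustrative helpers, not part of any library):

```python
import pickle
import pickletools


def opcodes_added_in(protocol: int) -> list[str]:
    """Names of pickle opcodes first introduced in the given protocol version."""
    return [op.name for op in pickletools.opcodes if op.proto == protocol]


def stream_protocol(data: bytes) -> int:
    """Protocol recorded in the PROTO opcode at the start of a protocol >= 2 stream."""
    # Protocol 2+ streams begin with PROTO (0x80) followed by the version byte.
    if data[:1] != b"\x80":
        raise ValueError("no PROTO opcode; stream was written with protocol 0 or 1")
    return data[1]


if __name__ == "__main__":
    # BINUNICODE8 / BINBYTES8 carry an 8-byte length and only exist from protocol 4.
    print([name for name in opcodes_added_in(4) if name.endswith("8")])
    print(stream_protocol(pickle.dumps("weights", protocol=2)))  # 2
    print(stream_protocol(pickle.dumps("weights", protocol=4)))  # 4
```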

Describe the solution you'd like
It would be great if you could add a pickle_protocol argument to torch_tensorrt.save, similar to the one torch.save accepts, so that large optimized models can be saved.

Describe alternatives you've considered
I've tried torch.save, but the file it writes lacks information that torch.export.load expects, so model = torch.export.load(model_path).module() fails with the following error:

```
    model = torch.export.load(model_path).module()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 421, in load
    version = zipf.read("version").decode().split(".")
              ^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1527, in read
    with self.open(name, "r", pwd) as fp:
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1564, in open
    zinfo = self.getinfo(name)
            ^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1493, in getinfo
    raise KeyError(
KeyError: "There is no item named 'version' in the archive"
```
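The traceback shows torch.export.load starting with zipf.read("version"), so any archive without that top-level entry fails before anything else is read. A small diagnostic sketch, assuming only that the saved file is a zip archive, that checks for the entry before attempting a load (`has_export_version` is a hypothetical helper; it does not assume anything else about the layout torch.export.save writes):

```python
import io
import zipfile


def has_export_version(path) -> bool:
    """Return True if `path` (a filename or binary file object) is a zip archive
    with a top-level "version" entry, which torch.export.load reads first."""
    if not zipfile.is_zipfile(path):
        return False
    with zipfile.ZipFile(path) as zf:
        return "version" in zf.namelist()


if __name__ == "__main__":
    # In-memory archives standing in for files on disk; entry names are arbitrary.
    good = io.BytesIO()
    with zipfile.ZipFile(good, "w") as zf:
        zf.writestr("version", "1")
    other = io.BytesIO()
    with zipfile.ZipFile(other, "w") as zf:
        zf.writestr("archive/data.pkl", b"")
    print(has_export_version(good), has_export_version(other))  # True False
```

Printing zipfile.ZipFile(path).namelist() on the file produced by torch.save is a quick way to see what it contains instead.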

Additional context
MRE (requires pip install torch torch_tensorrt open_clip_torch):

```python
import torch
import torch_tensorrt
import open_clip


torch.set_float32_matmul_precision('high')


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"  # TensorRT compilation needs a CUDA GPU
    model, _, preprocess = open_clip.create_model_and_transforms("convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup")

    model = model.visual.to(device).eval()  # compile only the vision tower
    image_size = model.image_size

    image = torch.randn((1, 3, *image_size)).to(device)  # (1, 3, 256, 256)

    model = torch_tensorrt.compile(model, ir="dynamo", inputs=[image])  # dynamo frontend
    model(image)  # one forward pass to check the compiled module runs
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])  # raises OverflowError
```

Running the above script gives the following result:

```
WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:387: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

Traceback (most recent call last):
  File "/src/repro.py", line 20, in <module>
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])
  File "/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 529, in save
    torch.export.save(exp_program, file_path)
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 341, in save
    artifact: SerializedArtifact = serialize(ep, opset_version)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 2374, in serialize
    serialized_program = ExportedProgramSerializer(opset_version).serialize(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1410, in serialize
    serialize_torch_artifact(constants),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 312, in serialize_torch_artifact
    torch.save(artifact, buffer)
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 850, in save
    _save(
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 1088, in _save
    pickler.dump(obj)
OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher
```

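By default, torch.save uses pickle protocol 2 (DEFAULT_PROTOCOL = 2 in torch/serialization.py), and the torch.save(artifact, buffer) call inside serialize_torch_artifact in the traceback above does not override it.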
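Reassigning the module-level constant at runtime is not a workaround: in Python a default argument value is evaluated once, when the def statement runs, so rebinding the constant afterwards does not change the default of an already-defined function. A self-contained sketch of that pitfall, with a hypothetical fake_serialization namespace standing in for torch.serialization (that torch.save binds its default this way is an assumption to check against your installed version):

```python
import pickle
import types

# Hypothetical stand-in for a module such as torch.serialization.
fake_serialization = types.SimpleNamespace(DEFAULT_PROTOCOL=2)


def save(obj, pickle_protocol=fake_serialization.DEFAULT_PROTOCOL) -> bytes:
    """The default is captured from the constant once, when this def executes."""
    return pickle.dumps(obj, protocol=pickle_protocol)


if __name__ == "__main__":
    fake_serialization.DEFAULT_PROTOCOL = 4  # attempted monkeypatch
    print(save("weights")[1])  # 2: the default was bound at definition time
    print(save("weights", pickle_protocol=4)[1])  # 4: an explicit argument works
```

This is why the protocol has to be passed explicitly at each layer, which is what a pickle_protocol argument on torch_tensorrt.save would allow.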
