Description
Is your feature request related to a problem? Please describe.
I am trying to save an optimized CLIP model to disk, but it is too large (around 8 GB, if I remember correctly) for pickle to handle with `pickle_protocol` < 4, which is what is used by default. Saving fails with `OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher`.
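The sketch below uses only the standard library to show where the 4 GiB limit comes from. It lists the opcodes `pickle` emits for a small `bytes` payload. Under protocol 2, `bytes` has no dedicated opcode, so CPython pickles it as a latin-1 `str` through `_codecs.encode`. That string is written with `BINUNICODE`, whose length field is 4 bytes. Protocol 4 adds opcodes with 8-byte lengths (`BINBYTES8`, `BINUNICODE8`). The helper `opcode_names` is hypothetical, and the payload is only a stand-in. This illustrates CPython's `pickle` in general, not the exact object torch_tensorrt serializes.

```python
import pickle
import pickletools

# Largest length a 4-byte unsigned length field can encode: 4 GiB - 1 bytes.
MAX_4BYTE_LEN = 2**32 - 1


def opcode_names(obj, protocol):
    """Return the names of the pickle opcodes emitted for obj (hypothetical helper)."""
    data = pickle.dumps(obj, protocol=protocol)
    return [op.name for op, _arg, _pos in pickletools.genops(data)]


# Small stand-in for a large serialized blob (an assumption for illustration).
payload = b"\x00\x01engine"

p2 = opcode_names(payload, 2)  # bytes -> latin-1 str -> BINUNICODE (4-byte length)
p4 = opcode_names(payload, 4)  # bytes written natively (SHORT_BINBYTES here)
print("protocol 2:", p2)
print("protocol 4:", p4)
print("4-byte length limit:", MAX_4BYTE_LEN, "bytes")
```

In the failing run, the pickled string is far larger than `MAX_4BYTE_LEN`. The C pickler cannot write it as a 4-byte-length record, so it raises the `OverflowError` quoted in this issue.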
Describe the solution you'd like
It would be great if you could add a `pickle_protocol` argument to the `torch_tensorrt.save` method, similar to the one `torch.save` accepts, so that large optimized models can be saved.
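Here is a minimal, standard-library-only sketch of the requested behavior. A hypothetical `save_artifact` function exposes a `pickle_protocol` keyword and forwards it to the pickler, mirroring the keyword `torch.save` already accepts. This is not torch_tensorrt's API. It only illustrates threading the parameter through to the point where pickling happens. `DEFAULT_PROTOCOL = 2` is assumed here to match torch's default.

```python
import io
import pickle

DEFAULT_PROTOCOL = 2  # assumption: mirrors the default torch.save uses


def save_artifact(obj, f, pickle_protocol=DEFAULT_PROTOCOL):
    """Hypothetical save function that lets the caller choose the pickle protocol."""
    pickle.dump(obj, f, protocol=pickle_protocol)


def read_protocol(data):
    """Protocol 2+ pickles start with the PROTO opcode (0x80) and a version byte."""
    if data[:1] != b"\x80":
        raise ValueError("not a protocol 2+ pickle")
    return data[1]


artifact = {"engine": b"\x00" * 16}  # stand-in for a serialized engine

buf_default = io.BytesIO()
save_artifact(artifact, buf_default)  # keeps the old default

buf_p4 = io.BytesIO()
save_artifact(artifact, buf_p4, pickle_protocol=4)  # opt in to 8-byte lengths

print(read_protocol(buf_default.getvalue()), read_protocol(buf_p4.getvalue()))
```

Keeping the old default preserves existing behavior. Callers with very large models opt in explicitly, which is how `torch.save` exposes the same choice.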
Describe alternatives you've considered
I've tried `torch.save`, but it does not save all of the information the optimized model needs. Loading the result with `model = torch.export.load(model_path).module()` fails with the following error:
```
    model = torch.export.load(model_path).module()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 421, in load
    version = zipf.read("version").decode().split(".")
              ^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1527, in read
    with self.open(name, "r", pwd) as fp:
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1564, in open
    zinfo = self.getinfo(name)
            ^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1493, in getinfo
    raise KeyError(
KeyError: "There is no item named 'version' in the archive"
```
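The `KeyError` means the archive has no root-level `version` entry, which `torch.export.load` reads first (line 421 in the traceback). A quick way to check what a saved file actually contains is to list its ZIP entries with the standard library. `torch.save` also writes a ZIP archive, but as far as I know it nests its records under a single root folder, so a root-level `version` is absent. `zip_root_entries` is a hypothetical helper, and the entry names in the example archive are illustrative assumptions.

```python
import io
import zipfile


def zip_root_entries(path_or_file):
    """Return the entry names of a ZIP archive and whether a root-level 'version' exists."""
    with zipfile.ZipFile(path_or_file) as zf:
        names = zf.namelist()
    return names, "version" in names


# Illustrative archive laid out like a nested torch.save file (names are assumptions).
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("archive/data.pkl", b"")
    zf.writestr("archive/version", b"3\n")
buf.seek(0)

names, has_root_version = zip_root_entries(buf)
print(names, has_root_version)
```

Calling `zip_root_entries(model_path)` on the file that fails to load should show whether the entries `torch.export.load` expects are missing.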
Additional context
MRE (requires `pip install torch torch_tensorrt open_clip_torch`):
```python
import torch
import torch_tensorrt
import open_clip

torch.set_float32_matmul_precision('high')

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, preprocess = open_clip.create_model_and_transforms("convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup")
    model = model.visual.to(device).eval()
    image_size = model.image_size
    image = torch.randn((1, 3, *image_size)).to(device)  # (1, 3, 256, 256)
    model = torch_tensorrt.compile(model, ir="dynamo", inputs=[image])
    model(image)
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])
```
Running the above script gives the following result:
```
WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:387: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

Traceback (most recent call last):
  File "/src/repro.py", line 20, in <module>
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])
  File "/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 529, in save
    torch.export.save(exp_program, file_path)
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 341, in save
    artifact: SerializedArtifact = serialize(ep, opset_version)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 2374, in serialize
    serialized_program = ExportedProgramSerializer(opset_version).serialize(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1410, in serialize
    serialize_torch_artifact(constants),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 312, in serialize_torch_artifact
    torch.save(artifact, buffer)
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 850, in save
    _save(
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 1088, in _save
    pickler.dump(obj)
OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher
```
By default, `torch.save` uses pickle protocol 2, as per this line.
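Until such an argument exists, one possible workaround is to temporarily wrap `torch.save` so that any call that does not pass `pickle_protocol` gets protocol 4. This is an untested assumption, not an official torch_tensorrt feature. The traceback shows `torch/_export/serde/serialize.py` calling `torch.save(artifact, buffer)` through the `torch` module attribute, so patching that attribute at call time should reach it. `default_kwargs` and `save_trt_with_protocol_4` are hypothetical names. The generic part uses only the standard library and is exercised here against a stand-in object.

```python
import contextlib
import functools
import types


@contextlib.contextmanager
def default_kwargs(module, func_name, **defaults):
    """Temporarily replace module.<func_name> with a wrapper that fills in
    keyword arguments the caller did not pass explicitly (hypothetical helper)."""
    original = getattr(module, func_name)

    @functools.wraps(original)
    def wrapper(*args, **kwargs):
        # Caveat: if a caller passes one of these arguments positionally,
        # setdefault adds it a second time and the call raises TypeError.
        for key, value in defaults.items():
            kwargs.setdefault(key, value)
        return original(*args, **kwargs)

    setattr(module, func_name, wrapper)
    try:
        yield
    finally:
        # Always restore the original function, even if saving fails.
        setattr(module, func_name, original)


def save_trt_with_protocol_4(model, file_path, inputs):
    """Hypothetical helper: run torch_tensorrt.save while torch.save defaults
    to pickle protocol 4. Assumes the serializer looks up torch.save on the
    torch module at call time, as the traceback in this issue suggests."""
    import torch
    import torch_tensorrt

    with default_kwargs(torch, "save", pickle_protocol=4):
        torch_tensorrt.save(model, file_path, inputs=inputs)


# Self-contained check of the patching logic with a stand-in "module".
fake_torch = types.SimpleNamespace(save=lambda obj, f, pickle_protocol=2: pickle_protocol)
with default_kwargs(fake_torch, "save", pickle_protocol=4):
    patched = fake_torch.save("obj", "buf")
    explicit = fake_torch.save("obj", "buf", pickle_protocol=3)
restored = fake_torch.save("obj", "buf")
print(patched, explicit, restored)  # 4 3 2
```

With the MRE, this would replace the last line with `save_trt_with_protocol_4(model, "convnext_xxlarge_compiled.ep", [image])`. Loading should not need a matching change, since unpickling reads the protocol from the data itself. Whether the rest of the save path copes with an 8 GB payload has not been verified.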