Failed to export model to ONNX #9

Closed
ArielleBF opened this issue Dec 7, 2023 · 14 comments

Comments

@ArielleBF

ArielleBF commented Dec 7, 2023

I have tried to export the encoder of the model to ONNX, but the export fails. Can anyone who has done this give some advice?

@ximitiejiang

Same question.

@yformer
Owner

yformer commented Dec 8, 2023

@ArielleBF @ximitiejiang can you share more information on the export? I will try to export one and share it.

@ArielleBF
Author

> @ArielleBF @ximitiejiang can you share more information on the export? I will try to export one and share it.

When I try to export the image_encoder model using torch.onnx.export, the export fails. It may be caused by an incompatibility between the TorchScript (.jit) file and the ONNX exporter.

@ximitiejiang

ximitiejiang commented Dec 8, 2023

Hello, I tried to export to ONNX with the code below, which is similar to SAM's ONNX export script:

import torch


def export_onnx():
    # Load the TorchScript checkpoint released with EfficientSAM.
    model = torch.jit.load('efficientsam_s_gpu.jit')
    output = "efficientsam_s_gpu.onnx"
    output_names = ["masks_predictions", "iou_predictions"]
    # Dummy inputs: one 1024x1024 image plus 5 prompt points and their labels.
    dummy_inputs = {
        "images": torch.randn((1, 3, 1024, 1024), dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 1, 5), dtype=torch.float)
    }
    # Allow the number of prompt points to vary at inference time.
    dynamic_axes = {
        # "images": {2: "img_height", 3: "img_width"},
        "point_coords": {2: "num_points"},
        "point_labels": {2: "num_points"},
    }
    opset = 14
    with open(output, "wb") as f:
        print(f"Exporting onnx model to {output}...")
        torch.onnx.export(
            model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=opset,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )


if __name__ == "__main__":
    export_onnx()

And I got the following error:
torch.onnx.symbolic_registry.UnsupportedOperatorError: Exporting the operator ::tile to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
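To see which operators in a TorchScript file are likely to trip the exporter before calling torch.onnx.export, one option is to walk the inlined graph and list the node kinds. This is a minimal sketch, assuming a scripted or traced module; `find_ops` and `Tiny` are hypothetical names used only for illustration. For the real checkpoint you would pass `torch.jit.load("efficientsam_s_gpu.jit").inlined_graph` instead (adding `map_location="cpu"` to `torch.jit.load` if no GPU is available).

```python
import torch


def find_ops(block, substring):
    """Recursively collect node kinds (e.g. "aten::tile") whose name contains `substring`."""
    kinds = []
    for node in block.nodes():
        if substring in node.kind():
            kinds.append(node.kind())
        # Control-flow nodes (if/loop) carry nested blocks; search those too.
        for sub_block in node.blocks():
            kinds.extend(find_ops(sub_block, substring))
    return kinds


class Tiny(torch.nn.Module):
    """Hypothetical stand-in for the EfficientSAM module that uses torch.tile."""

    def forward(self, x):
        return torch.tile(x, [1, 2])


scripted = torch.jit.script(Tiny())
# inlined_graph inlines submodule method calls, so ops in nested modules are visible.
print(find_ops(scripted.inlined_graph, "tile"))
```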

@yformer
Owner

yformer commented Dec 10, 2023

@ArielleBF @ximitiejiang thanks for sharing the information. We will try to export one.

@chenin-wang

Same question.

@alanzhai219

Same question.

@fPecc

fPecc commented Dec 11, 2023

Same problem here trying to export 'efficientsam_ti_gpu.jit' to ONNX using PyTorch 2.0.1 and opset_version set to 18.

@kaka-lin

Same problem

@lxfater

lxfater commented Dec 13, 2023

Same problem

@mchaniotakis

The problem originates from torch.tile here:

        # Tile the image embedding for all queries.
        image_embeddings_tiled = torch.tile(
            image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1]
        ).view(
            batch_size * max_num_queries,
            image_embed_dim_c,
            image_embed_dim_h,
            image_embed_dim_w,
        )

This could probably be solved by replacing torch.tile() with Tensor.repeat(), or by registering a custom ONNX symbolic for tile to patch the export.
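A minimal sketch of the `Tensor.repeat()` replacement suggested above, alongside an `expand()` variant, assuming the image embeddings have shape `[batch, C, H, W]` as in the snippet. The function names are hypothetical, and the change would have to be made in the Python model source and the model re-exported, since the `.jit` file cannot be patched this way. Both `repeat` and `expand` are handled by the TorchScript-based ONNX exporter (as `Tile` and `Expand`), but this has not been checked against every PyTorch version or against the EfficientSAM source.

```python
import torch


def tile_with_repeat(image_embeddings: torch.Tensor, max_num_queries: int) -> torch.Tensor:
    batch_size, c, h, w = image_embeddings.shape
    # repeat() takes one factor per dimension; with a factor for every dim it matches torch.tile.
    # repeat() returns a new contiguous tensor, so view() is safe here.
    return image_embeddings[:, None, :, :, :].repeat(1, max_num_queries, 1, 1, 1).view(
        batch_size * max_num_queries, c, h, w
    )


def tile_with_expand(image_embeddings: torch.Tensor, max_num_queries: int) -> torch.Tensor:
    batch_size, c, h, w = image_embeddings.shape
    # expand() broadcasts the size-1 query axis without copying; the result is not
    # contiguous, so reshape() (which copies when needed) is used instead of view().
    return image_embeddings[:, None].expand(-1, max_num_queries, -1, -1, -1).reshape(
        batch_size * max_num_queries, c, h, w
    )


# Compare both variants against the original torch.tile formulation.
emb = torch.randn(2, 4, 3, 3)
reference = torch.tile(emb[:, None], [1, 5, 1, 1, 1]).view(10, 4, 3, 3)
assert torch.equal(tile_with_repeat(emb, 5), reference)
assert torch.equal(tile_with_expand(emb, 5), reference)
```

`repeat()` keeps the change closest to the original code; `expand()` avoids materializing the copy until `reshape()` needs it.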

@yformer
Owner

yformer commented Jan 12, 2024

@ArielleBF, @ximitiejiang, @chenin-wang, @alanzhai219, @fPecc, @kaka-lin, @lxfater, @mchaniotakis, EfficientSAM ONNX files are available in a Hugging Face Space. The export script and a running example are provided. Feel free to give it a try.

@yacineMTB

Thanks yformer <3
FYI for other readers: the export tooling is in this repository:

def export_onnx_esam_encoder(model, output):
    onnx_model = onnx_models.OnnxEfficientSamEncoder(model=model)
    dynamic_axes = {
        "batched_images": {0: "batch", 2: "height", 3: "width"},
    }
    dummy_inputs = {
        "batched_images": torch.randn(1, 3, 1080, 1920, dtype=torch.float),
    }
    output_names = ["image_embeddings"]
    export_onnx(
        onnx_model=onnx_model,
        output=output,
        dynamic_axes=dynamic_axes,
        dummy_inputs=dummy_inputs,
        output_names=output_names,
    )
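Once the encoder has been exported this way, it can be run with ONNX Runtime using the input and output names from the script above (`batched_images`, `image_embeddings`). This is a minimal sketch: `to_batched_images` and `run_encoder` are hypothetical helpers, and it assumes the exported encoder takes RGB values scaled to [0, 1] and normalizes them internally. Check the running example in the Hugging Face Space for the exact preprocessing.

```python
import numpy as np


def to_batched_images(image_hwc: np.ndarray) -> np.ndarray:
    """Convert an H x W x 3 uint8 RGB image into a 1 x 3 x H x W float32 batch.

    Assumption: the exported encoder expects RGB values scaled to [0, 1].
    """
    if image_hwc.ndim != 3 or image_hwc.shape[2] != 3:
        raise ValueError("expected an H x W x 3 image")
    chw = np.transpose(image_hwc, (2, 0, 1)).astype(np.float32) / 255.0
    return chw[None, ...]  # add the batch axis


def run_encoder(onnx_path: str, image_hwc: np.ndarray) -> np.ndarray:
    # Imported lazily so the preprocessing helper works without onnxruntime installed.
    import onnxruntime as ort

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # run() returns one array per requested output name.
    (embeddings,) = session.run(
        ["image_embeddings"], {"batched_images": to_batched_images(image_hwc)}
    )
    return embeddings
```

Because the export marks batch, height and width as dynamic, images of any size should be accepted without resizing, though the embedding resolution will change with the input size.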

@kaka-lin

> @ArielleBF, @ximitiejiang, @chenin-wang, @alanzhai219, @fPecc, @kaka-lin, @lxfater, @mchaniotakis, EfficientSAM onnx files are available at Hugging Face Space. The export script and running example are provided. Feel free to give it a try.

Thanks yformer ~

I also created a TensorFlow 2.x version and converted it to a TFLite model.
If anyone needs to run inference with TFLite, please check EfficientSAM-tf2-demo.

Thanks!!!

@yformer yformer closed this as completed Jan 23, 2024