In [2]:
import torch
from typing import NamedTuple
# Fetch the 'detr_resnet50' entrypoint from the 'facebookresearch/detr' hub repo.
# pretrained=True and num_classes=91 are forwarded to that entrypoint.
model = torch.hub.load('facebookresearch/detr',
                       'detr_resnet50',
                       pretrained=True,
                       num_classes=91)

# Put the module in evaluation mode before it is traced and exported below.
model.eval()
# Random example input of shape (1, 3, 800, 800); it is reused as the example
# input for both torch.jit.trace and torch.onnx.export later in this cell.
dummy_input = torch.randn(1, 3, 800, 800)

class DETROutput(NamedTuple):
    """Named pair of tensors returned by ``DETRTracedWrapper.forward``.

    The two fields are filled from the ``"pred_logits"`` and ``"pred_boxes"``
    entries of the wrapped model's output mapping, in that order.
    """
    # Filled from outputs["pred_logits"].
    pred_logits: torch.Tensor
    # Filled from outputs["pred_boxes"].
    pred_boxes: torch.Tensor

class DETRTracedWrapper(torch.nn.Module):
    """Adapter that repackages a model's mapping output as a ``DETROutput``.

    The wrapped model is called with the input unchanged; only the
    ``"pred_logits"`` and ``"pred_boxes"`` entries of its result are kept.
    """

    def __init__(self, model):
        """Store ``model`` as the ``model`` attribute of this module."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return a ``DETROutput``.

        Any other entries in the wrapped model's output are dropped.
        """
        detections = self.model(x)
        logits = detections["pred_logits"]
        boxes = detections["pred_boxes"]
        return DETROutput(logits, boxes)

# Wrap the model so its output is a DETROutput tuple, trace the wrapper with
# the example input, and write the traced module to "detr_traced.pt".
wrapped_model = DETRTracedWrapper(model)

traced_model = torch.jit.trace(
    wrapped_model,
    (dummy_input,),  # explicit 1-tuple; same as passing the bare tensor
)
traced_model.save("detr_traced.pt")

class DETROnnxWrapper(torch.nn.Module):
    """Adapter that returns a model's mapping output as a plain 2-tuple.

    The wrapped model is called with the input unchanged; its
    ``"pred_logits"`` and ``"pred_boxes"`` entries are returned, in that order.
    """

    def __init__(self, model):
        """Store ``model`` as the ``model`` attribute of this module."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return ``(pred_logits, pred_boxes)``."""
        detections = self.model(x)
        # Only these two entries are returned; pick other keys here if more
        # outputs are needed.
        logits = detections["pred_logits"]
        boxes = detections["pred_boxes"]
        return logits, boxes

# Wrap the model so it returns a plain tuple, then export the wrapper to
# "detr.onnx" using the same example input and ONNX opset 11.
onnx_model = DETROnnxWrapper(model)

torch.onnx.export(
    onnx_model,
    (dummy_input,),  # explicit 1-tuple; same as passing the bare tensor
    "detr.onnx",
    opset_version=11,
)


Using cache found in /home/slava/.cache/torch/hub/facebookresearch_detr_main


In [8]:
import onnx

# Read back the "detr.onnx" file written by the export cell.
# NOTE(review): this rebinds `onnx_model`, which the export cell bound to the
# DETROnnxWrapper instance; after this cell the name refers to the loaded model.
onnx_model = onnx.load("detr.onnx")
# Run the ONNX checker on the loaded model.
onnx.checker.check_model(onnx_model)
# Print a text rendering of the graph (inputs, initializers, nodes).
print(onnx.helper.printable_graph(onnx_model.graph))

graph main_graph (
  %samples[FLOAT, 1x3x800x800]
) initializers (
  %model.transformer.encoder.layers.0.self_attn.out_proj.weight[FLOAT, 256x256]
  %model.transformer.encoder.layers.0.self_attn.out_proj.bias[FLOAT, 256]
  %model.transformer.encoder.layers.0.linear1.bias[FLOAT, 2048]
  %model.transformer.encoder.layers.0.linear2.bias[FLOAT, 256]
  %model.transformer.encoder.layers.0.norm1.weight[FLOAT, 256]
  %model.transformer.encoder.layers.0.norm1.bias[FLOAT, 256]
  %model.transformer.encoder.layers.0.norm2.weight[FLOAT, 256]
  %model.transformer.encoder.layers.0.norm2.bias[FLOAT, 256]
  %model.transformer.encoder.layers.1.self_attn.out_proj.weight[FLOAT, 256x256]
  %model.transformer.encoder.layers.1.self_attn.out_proj.bias[FLOAT, 256]
  %model.transformer.encoder.layers.1.linear1.bias[FLOAT, 2048]
  %model.transformer.encoder.layers.1.linear2.bias[FLOAT, 256]
  %model.transformer.encoder.layers.1.norm1.weight[FLOAT, 256]
  %model.transformer.encoder.layers.1.norm1.bias[FLOAT, 256]