In [61]:
import numpy as np

import torch
import torchvision

In [62]:
in_size = 320

input_shape = (3, in_size, in_size)

In [63]:
def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace

In [64]:
def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]

In [65]:
class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inp):
        out = self.model(inp)
        bboxes, scores, labels = dict_to_tuple(out[0])
        return bboxes, scores, labels

In [66]:
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large()
model.load_state_dict(torch.load("./downloads/ssdlite320_mobilenet_v3.pth"))

<All keys matched successfully>

In [67]:
model = TraceWrapper(model)
model.eval()

TraceWrapper(
  (model): SSD(
    (backbone): SSDLiteFeatureExtractorMobileNet(
      (features): Sequential(
        (0): Sequential(
          (0): ConvBNActivation(
            (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (2): Hardswish()
          )
          (1): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      

In [54]:
with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

In [34]:
script_module

TraceWrapper(
  original_name=TraceWrapper
  (model): SSD(
    original_name=SSD
    (backbone): SSDLiteFeatureExtractorMobileNet(
      original_name=SSDLiteFeatureExtractorMobileNet
      (features): Sequential(
        original_name=Sequential
        (0): Sequential(
          original_name=Sequential
          (0): ConvBNActivation(
            original_name=ConvBNActivation
            (0): Conv2d(original_name=Conv2d)
            (1): BatchNorm2d(original_name=BatchNorm2d)
            (2): Hardswish(original_name=Hardswish)
          )
          (1): InvertedResidual(
            original_name=InvertedResidual
            (block): Sequential(
              original_name=Sequential
              (0): ConvBNActivation(
                original_name=ConvBNActivation
                (0): Conv2d(original_name=Conv2d)
                (1): BatchNorm2d(original_name=BatchNorm2d)
                (2): ReLU(original_name=ReLU)
              )
              (1): ConvBNActivation(
          

In [68]:
script_module(inp)

(tensor([[0.0000e+00, 0.0000e+00, 2.6767e+01, 1.2360e+02],
         [0.0000e+00, 1.4099e+02, 1.2757e+02, 3.2000e+02],
         [1.9929e+00, 6.6502e+00, 3.1752e+02, 3.1276e+02],
         ...,
         [1.9063e+02, 1.2368e-01, 2.6778e+02, 1.0608e+02],
         [1.9049e+02, 7.9263e+01, 1.9650e+02, 9.1153e+01],
         [2.7102e+02, 9.1529e+01, 3.2000e+02, 1.9755e+02]],
        grad_fn=<StackBackward>),
 tensor([0.9624, 0.9552, 0.9440, 0.8716, 0.6564, 0.4826, 0.4420, 0.3907, 0.3672,
         0.3627, 0.3362, 0.3287, 0.3282, 0.3186, 0.2864, 0.2682, 0.2631, 0.2579,
         0.2385, 0.2109, 0.1985, 0.1955, 0.1807, 0.1699, 0.1680, 0.1618, 0.1543,
         0.1408, 0.1398, 0.1313, 0.1312, 0.1202, 0.1184, 0.1140, 0.1056, 0.0962,
         0.0912, 0.0899, 0.0827, 0.0813, 0.0781, 0.0770, 0.0718, 0.0715, 0.0668,
         0.0667, 0.0663, 0.0633, 0.0631, 0.0627, 0.0612, 0.0597, 0.0583, 0.0565,
         0.0554, 0.0542, 0.0491, 0.0486, 0.0477, 0.0442, 0.0438, 0.0431, 0.0427,
         0.0401, 0.0399, 0.038

In [70]:
#script_module.save("./downloads/ssdlite320_mobilenet_v3_traced.pt")
script_module.save("../models/ssd_mobilenet_v3/1/model.pt")

In [58]:
loaded = torch.jit.load("./downloads/ssdlite320_mobilenet_v3_traced.pt")

In [37]:
print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=TraceWrapper
  (model): RecursiveScriptModule(
    original_name=SSD
    (backbone): RecursiveScriptModule(
      original_name=SSDLiteFeatureExtractorMobileNet
      (features): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=Sequential
          (0): RecursiveScriptModule(
            original_name=ConvBNActivation
            (0): RecursiveScriptModule(original_name=Conv2d)
            (1): RecursiveScriptModule(original_name=BatchNorm2d)
            (2): RecursiveScriptModule(original_name=Hardswish)
          )
          (1): RecursiveScriptModule(
            original_name=InvertedResidual
            (block): RecursiveScriptModule(
              original_name=Sequential
              (0): RecursiveScriptModule(
                original_name=ConvBNActivation
                (0): RecursiveScriptModule(original_name=Conv2d)
                (1): RecursiveScriptModule(original_n

In [55]:
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

In [59]:
loaded(inp)

(tensor([[0.0000e+00, 0.0000e+00, 2.6767e+01, 1.2360e+02],
         [0.0000e+00, 1.4099e+02, 1.2757e+02, 3.2000e+02],
         [1.9929e+00, 6.6502e+00, 3.1752e+02, 3.1276e+02],
         ...,
         [1.9063e+02, 1.2368e-01, 2.6778e+02, 1.0608e+02],
         [1.9049e+02, 7.9263e+01, 1.9650e+02, 9.1153e+01],
         [2.7102e+02, 9.1529e+01, 3.2000e+02, 1.9755e+02]],
        grad_fn=<StackBackward>),
 tensor([0.9624, 0.9552, 0.9440, 0.8716, 0.6564, 0.4826, 0.4420, 0.3907, 0.3672,
         0.3627, 0.3362, 0.3287, 0.3282, 0.3186, 0.2864, 0.2682, 0.2631, 0.2579,
         0.2385, 0.2109, 0.1985, 0.1955, 0.1807, 0.1699, 0.1680, 0.1618, 0.1543,
         0.1408, 0.1398, 0.1313, 0.1312, 0.1202, 0.1184, 0.1140, 0.1056, 0.0962,
         0.0912, 0.0899, 0.0827, 0.0813, 0.0781, 0.0770, 0.0718, 0.0715, 0.0668,
         0.0667, 0.0663, 0.0633, 0.0631, 0.0627, 0.0612, 0.0597, 0.0583, 0.0565,
         0.0554, 0.0542, 0.0491, 0.0486, 0.0477, 0.0442, 0.0438, 0.0431, 0.0427,
         0.0401, 0.0399, 0.038