In [20]:
import numpy as np

import torch
import torchvision

In [21]:
in_size = 320

input_shape = (3, in_size, in_size)

In [22]:
def do_trace(model, inp):
    model_trace = torch.jit.trace(model, inp)
    model_trace.eval()
    return model_trace

In [23]:
def dict_to_tuple(out_dict):
    if "masks" in out_dict.keys():
        return out_dict["boxes"], out_dict["scores"], out_dict["labels"], out_dict["masks"]
    return out_dict["boxes"], out_dict["scores"], out_dict["labels"]

In [24]:
class TraceWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model.to(device)
        
    def forward(self, inp):
        out = self.model(inp.to("cuda"))
        bboxes, scores, labels = dict_to_tuple(out[0])
        return bboxes, scores, labels

In [29]:
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large()
model.load_state_dict(torch.load("./downloads/ssdlite320_mobilenet_v3.pth"))

<All keys matched successfully>

In [30]:
device = torch.device("cuda")

In [31]:
model = TraceWrapper(model)
model.eval()

TraceWrapper(
  (model): SSD(
    (backbone): SSDLiteFeatureExtractorMobileNet(
      (features): Sequential(
        (0): Sequential(
          (0): ConvBNActivation(
            (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
            (2): Hardswish()
          )
          (1): InvertedResidual(
            (block): Sequential(
              (0): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
                (2): ReLU(inplace=True)
              )
              (1): ConvBNActivation(
                (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (1): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      

In [32]:
inp = torch.Tensor(np.random.uniform(0.0, 250.0, size=(1, 3, in_size, in_size)))

In [33]:
with torch.no_grad():
    out = model(inp)
    script_module = do_trace(model, inp)

In [34]:
script_module

TraceWrapper(
  original_name=TraceWrapper
  (model): SSD(
    original_name=SSD
    (backbone): SSDLiteFeatureExtractorMobileNet(
      original_name=SSDLiteFeatureExtractorMobileNet
      (features): Sequential(
        original_name=Sequential
        (0): Sequential(
          original_name=Sequential
          (0): ConvBNActivation(
            original_name=ConvBNActivation
            (0): Conv2d(original_name=Conv2d)
            (1): BatchNorm2d(original_name=BatchNorm2d)
            (2): Hardswish(original_name=Hardswish)
          )
          (1): InvertedResidual(
            original_name=InvertedResidual
            (block): Sequential(
              original_name=Sequential
              (0): ConvBNActivation(
                original_name=ConvBNActivation
                (0): Conv2d(original_name=Conv2d)
                (1): BatchNorm2d(original_name=BatchNorm2d)
                (2): ReLU(original_name=ReLU)
              )
              (1): ConvBNActivation(
          

In [35]:
script_module(inp)

(tensor([[  1.0560,   0.0000, 315.9628, 308.1848],
         [  0.0000, 137.2875, 151.2692, 320.0000],
         [ 17.4331,   0.0000, 167.8567,  86.0987],
         ...,
         [250.1722, 225.1313, 274.6321, 245.8570],
         [277.7100,  98.7329, 320.0000, 139.3244],
         [168.9152,  60.3280, 203.7917,  95.9146]], device='cuda:0',
        grad_fn=<StackBackward>),
 tensor([0.9612, 0.9088, 0.9040, 0.8987, 0.8844, 0.7765, 0.7206, 0.5646, 0.4750,
         0.4192, 0.3903, 0.3790, 0.3775, 0.3312, 0.2872, 0.2592, 0.2153, 0.1811,
         0.1691, 0.1642, 0.1614, 0.1420, 0.1392, 0.1373, 0.0992, 0.0940, 0.0936,
         0.0904, 0.0717, 0.0710, 0.0708, 0.0678, 0.0646, 0.0591, 0.0586, 0.0570,
         0.0564, 0.0539, 0.0522, 0.0514, 0.0513, 0.0496, 0.0496, 0.0495, 0.0493,
         0.0478, 0.0477, 0.0472, 0.0462, 0.0462, 0.0444, 0.0434, 0.0421, 0.0418,
         0.0409, 0.0401, 0.0400, 0.0394, 0.0385, 0.0381, 0.0372, 0.0370, 0.0366,
         0.0365, 0.0364, 0.0359, 0.0357, 0.0357, 0.0354, 0.03

In [36]:
#script_module.save("./downloads/ssdlite320_mobilenet_v3_traced.pt")
script_module.save("../models/ssd_mobilenet_v3/1/model.pt")

In [37]:
loaded = torch.jit.load("../models/ssd_mobilenet_v3/1/model.pt")

In [22]:
print(loaded)
print(loaded.code)

RecursiveScriptModule(
  original_name=TraceWrapper
  (model): RecursiveScriptModule(
    original_name=SSD
    (backbone): RecursiveScriptModule(
      original_name=SSDLiteFeatureExtractorMobileNet
      (features): RecursiveScriptModule(
        original_name=Sequential
        (0): RecursiveScriptModule(
          original_name=Sequential
          (0): RecursiveScriptModule(
            original_name=ConvBNActivation
            (0): RecursiveScriptModule(original_name=Conv2d)
            (1): RecursiveScriptModule(original_name=BatchNorm2d)
            (2): RecursiveScriptModule(original_name=Hardswish)
          )
          (1): RecursiveScriptModule(
            original_name=InvertedResidual
            (block): RecursiveScriptModule(
              original_name=Sequential
              (0): RecursiveScriptModule(
                original_name=ConvBNActivation
                (0): RecursiveScriptModule(original_name=Conv2d)
                (1): RecursiveScriptModule(original_n

In [38]:
loaded(inp)

(tensor([[  1.0560,   0.0000, 315.9628, 308.1848],
         [  0.0000, 137.2875, 151.2692, 320.0000],
         [ 17.4331,   0.0000, 167.8567,  86.0987],
         ...,
         [250.1722, 225.1313, 274.6321, 245.8570],
         [277.7100,  98.7329, 320.0000, 139.3244],
         [168.9152,  60.3280, 203.7917,  95.9146]], device='cuda:0',
        grad_fn=<StackBackward>),
 tensor([0.9612, 0.9088, 0.9040, 0.8987, 0.8844, 0.7765, 0.7206, 0.5646, 0.4750,
         0.4192, 0.3903, 0.3790, 0.3775, 0.3312, 0.2872, 0.2592, 0.2153, 0.1811,
         0.1691, 0.1642, 0.1614, 0.1420, 0.1392, 0.1373, 0.0992, 0.0940, 0.0936,
         0.0904, 0.0717, 0.0710, 0.0708, 0.0678, 0.0646, 0.0591, 0.0586, 0.0570,
         0.0564, 0.0539, 0.0522, 0.0514, 0.0513, 0.0496, 0.0496, 0.0495, 0.0493,
         0.0478, 0.0477, 0.0472, 0.0462, 0.0462, 0.0444, 0.0434, 0.0421, 0.0418,
         0.0409, 0.0401, 0.0400, 0.0394, 0.0385, 0.0381, 0.0372, 0.0370, 0.0366,
         0.0365, 0.0364, 0.0359, 0.0357, 0.0357, 0.0354, 0.03