In [1]:
import torch
from pathlib import Path

In [2]:
# Pickled detection model to convert (version-"1" subdirectory of the model folder).
original_model_path = Path("od-load-test-model", "1", "model.pth")

In [3]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Load the whole pickled model object (it is used directly as a module below,
# not as a state_dict), move it to `device`, and switch it to inference mode.
#
# map_location remaps any tensors that were saved on a GPU onto `device`, so
# this cell also works on a CPU-only machine instead of failing during
# deserialization.
#
# NOTE(review): torch.load unpickles arbitrary Python objects; only load model
# files from a trusted source. On recent torch releases torch.load defaults to
# weights_only=True, which rejects full model objects; there you must pass
# weights_only=False explicitly (older releases do not accept that argument,
# so it is not passed unconditionally here) — confirm for the pinned version.
model = torch.load(original_model_path, map_location=device)
model = model.to(device)
model = model.eval()

In [5]:
class JitWrapper(torch.nn.Module):
    """TorchScript-friendly wrapper that runs a detection model on a batched tensor.

    ``forward`` takes a single batched image tensor (presumably uint8 pixels in
    [0, 255] — TODO confirm against the serving config), scales it by 1/255,
    feeds the images to the wrapped model one per list entry, and returns the
    tuple ``(bboxes, labels, scores)`` with one leading row per input image.

    Different images can yield different numbers of detections, so every
    per-image tensor is zero-padded along its first dimension to the largest
    detection count in the batch before stacking. Padded entries have an
    all-zero box, label 0 and score 0, so consumers that threshold on score
    ignore them. When all images yield the same count, no padding is added.

    NOTE(review): ``forward`` unpacks the model output as ``(losses, detections)``,
    which is what torchvision detection models return when scripted; called
    eagerly such models return only the detections list — confirm this wrapper
    is only ever used after ``torch.jit.script``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def _pad_rows(self, t: torch.Tensor, length: int) -> torch.Tensor:
        """Zero-pad ``t`` along dim 0 to ``length`` rows; return ``t`` unchanged if already long enough."""
        missing = length - t.shape[0]
        if missing <= 0:
            return t
        filler_shape = list(t.shape)
        filler_shape[0] = missing
        return torch.cat([t, t.new_zeros(filler_shape)], dim=0)

    def forward(self, inp):
        inp = inp.div(255)

        # Split along the batch axis so any batch size works. squeeze(0) removes
        # only the split-off batch dim; a bare squeeze() would also drop a
        # size-1 channel or spatial dim and hand the model a malformed image.
        images = [t.squeeze(0) for t in torch.split(inp, 1)]
        _, preds = self.model(images)

        # Largest detection count in the batch; shorter results are padded to it
        # because torch.stack requires identical shapes.
        max_det = 0
        for pred in preds:
            max_det = max(max_det, pred["boxes"].shape[0])

        bboxes = torch.stack([self._pad_rows(pred["boxes"], max_det) for pred in preds])
        labels = torch.stack([self._pad_rows(pred["labels"], max_det) for pred in preds])
        scores = torch.stack([self._pad_rows(pred["scores"], max_det) for pred in preds])

        return bboxes, labels, scores

In [6]:
# Wrap the loaded model in JitWrapper so it can be compiled with torch.jit.script.
wrapped_model = JitWrapper(model=model)

In [7]:
# Compile the wrapper, including the model it holds, to TorchScript so it can be
# saved and loaded without the Python JitWrapper class definition.
jit_model = torch.jit.script(wrapped_model)

In [8]:
# Serialize the scripted module; the packaging cell archives this file as 1/model.pt.
torch.jit.save(jit_model, "resnet_fpn.pt")

In [9]:
import tarfile

# Package the scripted model plus its config.pbtxt into <model_name>.tar.gz.
# The archive layout (<version>/model.pt next to config.pbtxt) looks like a
# Triton-style model repository entry — confirm against the deployment target.
model_name = "resnet_fpn"

model_tar_path = Path(f"{model_name}.tar.gz")
# Must match the filename passed to jit_model.save(...) in the previous cell.
scripted_model_file = f"{model_name}.pt"

# The context manager closes the archive (flushing the gzip stream) even if an
# add() fails, e.g. because config.pbtxt is missing from the working directory.
with tarfile.open(model_tar_path, "w:gz") as tar:
    tar.add(scripted_model_file, arcname="1/model.pt")
    tar.add("config.pbtxt", arcname="config.pbtxt")