In [1]:
import os
import json
import argparse
import torch
import numpy as np

import sys
# Extend the module search path with the FFTRadNet checkout; presumably this is
# what makes the `model`, `dataset` and `utils` imports below resolvable.
sys.path.append("../FFTRadNet/")

# Fraction of GPU memory this process may allocate (0.25 = one quarter).
CUDA_fraktion = 0.25
# device=None leaves the device choice to PyTorch's default.
torch.cuda.set_per_process_memory_fraction(CUDA_fraktion, device=None)

from model.FFTRadNet import FFTRadNet
from dataset.dataset import RADIal
from dataset.encoder import ra_encoder
from matplotlib import pyplot as plt
from utils.util import DisplayHMI

In [2]:
# Read the training configuration that describes the network architecture.
# The context manager closes the file handle deterministically; a bare
# json.load(open(...)) leaves that to the garbage collector.
with open('../../FFTRadNet_RA_192_56_epoch78_loss_172.8239_AP_0.9813/config.json') as config_file:
    config = json.load(config_file)

# Build the network from the configuration values.
model = FFTRadNet(blocks = config['model']['backbone_block'],
                        mimo_layer  = config['model']['MIMO_output'],
                        channels = config['model']['channels'],
                        regression_layer = 2,
                        detection_head = config['model']['DetectionHead'],
                        segmentation_head = config['model']['SegmentationHead'])

model.to('cuda')

# NOTE: the name `dict` shadows the Python builtin. It is kept because a later
# cell reads dict['net_state_dict']; rename both places together if cleaning up.
dict = torch.load('../../FFTRadNet_RA_192_56_epoch78_loss_172.8239_AP_0.9813/FFTRadNet_RA_192_56_epoch78_loss_172.8239_AP_0.9813.pth')
model.load_state_dict(dict['net_state_dict'])
# Last expression of the cell: switches to eval mode and displays the module.
model.eval()

FFTRadNet(
  (FPN): FPN_BackBone(
    (pre_enc): MIMO_PreEncoder(
      (conv): Conv2d(32, 192, kernel_size=(1, 12), stride=(1, 1), dilation=(1, 16), bias=False)
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (block1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(128,

In [3]:
# Show the layer structure of the loaded network.
print(model)

FFTRadNet(
  (FPN): FPN_BackBone(
    (pre_enc): MIMO_PreEncoder(
      (conv): Conv2d(32, 192, kernel_size=(1, 12), stride=(1, 1), dilation=(1, 16), bias=False)
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (block1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(128,

In [43]:
# Prepare the real input data.
# Copy all layers from the reference model except last two layers.
# So the last layer of all_except_dh_sh is RangeAngle_Decoder:
# `copy` is imported here because the only other `import copy` lives in a
# later cell; without it this cell raises NameError on a fresh kernel.
import copy

all_except_dh_sh = copy.deepcopy(model)
# Re-apply the checkpoint weights to the copy (the deep copy already carries
# the current weights of `model`).
all_except_dh_sh.load_state_dict(dict['net_state_dict'])
# Keep every top-level child module except the last two, chained in order.
all_except_dh_sh = torch.nn.Sequential(*list(all_except_dh_sh.children())[:-2])
# Last expression of the cell: switches to eval mode and displays the module.
all_except_dh_sh.eval()

Sequential(
  (0): FPN_BackBone(
    (pre_enc): MIMO_PreEncoder(
      (conv): Conv2d(32, 192, kernel_size=(1, 12), stride=(1, 1), dilation=(1, 16), bias=False)
      (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (block1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(128, 

In [44]:
# Get the input data for the Detection_Header:
enc = ra_encoder(geometry = config['dataset']['geometry'],
                        statistics = config['dataset']['statistics'],
                        regression_layer = 2)

dataset = RADIal(root_dir = config['dataset']['root_dir'],
                        statistics= config['dataset']['statistics'],
                        encoder=enc.encode)

# Pure inference: no_grad() keeps autograd from recording a graph for every
# sample, which saves GPU memory and makes the saved tensor a plain value
# (requires_grad=False) instead of one attached to a computation graph.
with torch.no_grad():
    for data in dataset:
        # data[0] is moved from channels-last to channels-first and given a
        # leading batch dimension of 1 before being fed to the truncated model.
        inputs = torch.tensor(data[0]).permute(2,0,1).to('cuda').float().unsqueeze(0)
        y_pred = all_except_dh_sh(inputs)

# NOTE: y_pred is overwritten on every iteration, so only the output for the
# LAST sample of the dataset is written to the file below.
# save the all_except_dh_sh output / tensor to file:
torch.save(y_pred, 'output_fftradnet_upto_rangeangle_model_014606.pt')

In [46]:
# Read the tensor written by the earlier cell back from disk, then show its
# shape followed by its values (single print call, newline-separated).
y_pred_loaded = torch.load('output_fftradnet_upto_rangeangle_model_014606.pt')
print(y_pred_loaded.shape, y_pred_loaded, sep="\n")

torch.Size([1, 256, 128, 224])
tensor([[[[0.0000e+00, 2.5786e+00, 2.3156e+00,  ..., 0.0000e+00,
           0.0000e+00, 8.7500e+00],
          [8.3361e-01, 1.0027e+01, 8.7724e+00,  ..., 2.5602e-01,
           0.0000e+00, 1.0810e+01],
          [0.0000e+00, 4.6739e+00, 2.0324e+00,  ..., 0.0000e+00,
           0.0000e+00, 1.0848e+01],
          ...,
          [0.0000e+00, 2.5932e+00, 1.5007e+00,  ..., 0.0000e+00,
           2.8037e-01, 1.1255e+01],
          [0.0000e+00, 3.8899e+00, 3.3342e+00,  ..., 7.7271e-01,
           3.3343e+00, 1.2421e+01],
          [0.0000e+00, 0.0000e+00, 3.3878e-03,  ..., 0.0000e+00,
           2.8874e-01, 6.3513e+00]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 7.0688e-01,
           2.8319e+00, 7.2265e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.7522e-01,
           2.6298e+00, 1.0072e+01],
          [0.0000e+00, 0.0000e+00, 3.8028e-01,  ..., 0.0000e+00,
           7.3995e-01, 9.1913e+00],
          ...,
          [0.0000e+00, 1.441

In [3]:
# Detection header: the sub-module of FFTRadNet that the following cells
# quantize. (To list every key of the full network instead, use
# print(model.state_dict().keys()).)
DH_model = model.detection_header
DH_model_dict = DH_model.state_dict()

# Dump name, size and values of every entry, in alphabetical key order.
for key in sorted(DH_model_dict):
    parameter = DH_model_dict[key]
    print(key, parameter.size(), parameter, sep="\n")

bn1.bias
torch.Size([144])
tensor([ 0.0241, -0.0140,  0.0318,  0.0055, -0.0433, -0.0232, -0.0399, -0.0247,
         0.0386, -0.0121, -0.0513,  0.0499, -0.0461, -0.0028,  0.0389, -0.0214,
        -0.0252,  0.0398, -0.0520,  0.0248, -0.0237, -0.0544, -0.0455, -0.0498,
        -0.0046,  0.0292, -0.0573, -0.0514, -0.0181,  0.0491,  0.0282, -0.0658,
        -0.0457,  0.0176,  0.0243, -0.0161,  0.0257,  0.0450, -0.0004,  0.0265,
        -0.0367, -0.0466, -0.0155,  0.0499, -0.0300, -0.0211,  0.0347, -0.0661,
        -0.0032, -0.0265,  0.0077,  0.0948, -0.0142,  0.0069, -0.0477,  0.0751,
         0.0141,  0.0349, -0.0080,  0.0357,  0.0217,  0.0090,  0.0247, -0.0121,
         0.0455,  0.0324,  0.0509, -0.0465,  0.0391,  0.0232,  0.0143,  0.0141,
        -0.0418,  0.0272,  0.0053,  0.0139, -0.0306,  0.0127, -0.0998, -0.0238,
         0.0345,  0.0634,  0.0393,  0.0288, -0.0009,  0.0194, -0.0703, -0.0367,
         0.0174,  0.1173,  0.0056, -0.0474,  0.0518,  0.0398,  0.0028, -0.0575,
        -0.00

In [4]:
# Quantize DH_model: post-training static quantization, without fusing the
# Conv/BatchNorm pairs first.
# Note that nothing is run through the prepared model between prepare() and
# convert(), i.e. the observers never see calibration data.
backend = "qnnpack"
DH_model.to('cpu')
torch.backends.quantized.engine = backend
DH_model.qconfig = torch.quantization.get_default_qconfig(backend)
model_static_quantized = torch.quantization.convert(
    torch.quantization.prepare(DH_model, inplace=False),
    inplace=False,
)



In [5]:
# Show the converted (non-fused) detection header.
print(model_static_quantized)

Detection_Header(
  (conv1): QuantizedConv2d(256, 144, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (bn1): QuantizedBatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): QuantizedConv2d(144, 96, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (bn2): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): QuantizedConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (bn3): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): QuantizedConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (bn4): QuantizedBatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (clshead): QuantizedConv2d(96, 1, kernel_size=(3, 3), stride=(1, 1), scale=1.0, zero_point=0, padding=(1, 1))
  (reghead): QuantizedConv2d(9

In [6]:
# Weights from non-fusion model:
# int_repr() exposes the integer values stored for conv1's quantized weight.
print(model_static_quantized.conv1.weight().int_repr())

tensor([[[[ 18,  16,  -1],
          [ 11, -11,  10],
          [-13,  -9, -17]],

         [[-14,  30, -18],
          [  7,  10,   2],
          [  9,  41,  19]],

         [[ -5, -12,   6],
          [ -1,   1,  12],
          [ 10,  24,  -9]],

         ...,

         [[ 32,  15,  16],
          [ 14,  43,  15],
          [ 34,  14,  15]],

         [[ 39,   3,  29],
          [-25, -20,   3],
          [ 14,  19,  14]],

         [[-35, -13, -21],
          [ -4, -11,  -6],
          [ 13,   6,  -9]]],


        [[[ 22,  28,  42],
          [ -4,  17,  47],
          [ 31,  -3,  17]],

         [[-17, -51, -35],
          [-16, -46, -22],
          [-29, -38,  -2]],

         [[-15,   9,   4],
          [ 14,  10, -13],
          [-27, -15, -23]],

         ...,

         [[-20, -35,   5],
          [-40, -42,  -4],
          [-17, -42,   1]],

         [[ 14,  34,   3],
          [-19, -12, -23],
          [-15,   8,  -4]],

         [[ 16,  13,  12],
          [ 15,  19,   1],
 

In [5]:
# Quantize DH_model:

# Static quantization of a model consists of the following steps:
#     Fuse modules
#     Insert Quant/DeQuant Stubs
#     Prepare the fused module (insert observers before and after layers)
#     Calibrate the prepared module (pass it representative data)
#     Convert the calibrated module (replace with quantized version)


In [6]:
import copy
from torch import nn

# Work on a deep copy: the in-place fusion applied to DH_m in a later cell
# then leaves DH_model itself unchanged.
DH_m = copy.deepcopy(DH_model)
# Last expression of the cell: switches the copy to eval mode and displays it.
DH_m.eval()

Detection_Header(
  (conv1): Conv2d(256, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(144, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (clshead): Conv2d(96, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (reghead): Conv2d(96, 2, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [7]:
class Quantized(nn.Module):
    """Wrap a floating-point module between a QuantStub and a DeQuantStub.

    The stubs mark where tensors change representation in the quantized
    model: float -> quantized on the way in, quantized -> float on the way
    out. Everything in between is delegated to the wrapped module.

    Args:
        model_fp32: the floating-point module to wrap; it is stored as the
            ``model_fp32`` attribute.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Entry stub: float -> quantized, applied to the input only.
        self.quant = torch.quantization.QuantStub()
        # The wrapped floating-point model.
        self.model_fp32 = model_fp32
        # Exit stub: quantized -> float, applied to the output only.
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        """Run ``x`` through quant stub, wrapped model and dequant stub, in that order."""
        quantized_in = self.quant(x)
        wrapped_out = self.model_fp32(quantized_in)
        return self.dequant(wrapped_out)

In [8]:
"""Fuse
- Inplace fusion replaces the first module in the sequence with the fused module, and the rest with identity modules
"""
# Fuse the four Conv-BatchNorm pairs of the detection header in one call;
# fuse_modules accepts a list of name groups and fuses them one after another.
torch.quantization.fuse_modules(
    DH_m,
    [['conv1', 'bn1'], ['conv2', 'bn2'], ['conv3', 'bn3'], ['conv4', 'bn4']],
    inplace=True,
)

# Insert stubs: wrap the fused header between QuantStub and DeQuantStub.
quantized_model = Quantized(model_fp32=DH_m)

# Prepare: attach the default "fbgemm" qconfig to the wrapper.
# NOTE(review): an earlier cell sets torch.backends.quantized.engine to
# "qnnpack" while this qconfig is the "fbgemm" default - confirm which backend
# is intended.
quantized_model.qconfig = torch.quantization.get_default_qconfig("fbgemm")

# Print quantization configurations
print(quantized_model.qconfig)

# Insert the observers; the result is the model that gets calibrated next.
DH_m_prepared = torch.quantization.prepare(quantized_model)

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})




In [9]:
# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
#
# The random calibration input is drawn from a dedicated, seeded generator so
# that the resulting activation scales / zero-points - and every file exported
# from them further down - are identical on each run. Unseeded, they drift
# between runs. The global RNG state is left untouched.
# NOTE(review): (144, 256, 3, 3) is the shape of conv1's weight, whereas the
# feature tensor saved from the backbone in this notebook is
# (1, 256, 128, 224). Consider calibrating with those real features instead.
calibration_rng = torch.Generator().manual_seed(0)
input_fp32 = torch.randn(144, 256, 3, 3, generator=calibration_rng)
# Calibration is a forward pass only; no autograd graph is needed.
with torch.no_grad():
    DH_m_prepared(input_fp32)

# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
DH_m_int8 = torch.quantization.convert(DH_m_prepared)

DH_m_int8.eval()
# run the model, relevant calculations will happen in int8
with torch.no_grad():
    res_int8 = DH_m_int8(input_fp32)

In [10]:
# Observed (not yet converted) model: shows the observer attached to each layer.
print(DH_m_prepared)
#print(DH_m_int8.model_fp32.conv1.weight()) # float but has output scale and zero_factor
print(DH_m_int8.model_fp32.conv1.weight().int_repr()) # output result in int8

Quantized(
  (quant): QuantStub(
    (activation_post_process): HistogramObserver(min_val=-4.521451473236084, max_val=4.27418327331543)
  )
  (model_fp32): Detection_Header(
    (conv1): Conv2d(
      256, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=-5.190605163574219, max_val=4.576180458068848)
    )
    (bn1): Identity()
    (conv2): Conv2d(
      144, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=-3.0808920860290527, max_val=3.477532386779785)
    )
    (bn2): Identity()
    (conv3): Conv2d(
      96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=-2.301859140396118, max_val=2.0950284004211426)
    )
    (bn3): Identity()
    (conv4): Conv2d(
      96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
      (activation_post_process): HistogramObserver(min_val=-1.3431458473205566,

In [11]:
# Show the converted model, including each layer's output scale / zero_point.
print(DH_m_int8)

Quantized(
  (quant): Quantize(scale=tensor([0.0647]), zero_point=tensor([66]), dtype=torch.quint8)
  (model_fp32): Detection_Header(
    (conv1): QuantizedConv2d(256, 144, kernel_size=(3, 3), stride=(1, 1), scale=0.033345166593790054, zero_point=129, padding=(1, 1))
    (bn1): Identity()
    (conv2): QuantizedConv2d(144, 96, kernel_size=(3, 3), stride=(1, 1), scale=0.022014625370502472, zero_point=115, padding=(1, 1))
    (bn2): Identity()
    (conv3): QuantizedConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), scale=0.01560093741863966, zero_point=136, padding=(1, 1))
    (bn3): Identity()
    (conv4): QuantizedConv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), scale=0.010387929156422615, zero_point=125, padding=(1, 1))
    (bn4): Identity()
    (clshead): QuantizedConv2d(96, 1, kernel_size=(3, 3), stride=(1, 1), scale=0.08049296587705612, zero_point=230, padding=(1, 1))
    (reghead): QuantizedConv2d(96, 2, kernel_size=(3, 3), stride=(1, 1), scale=0.005809594877064228, zero_point=129,

In [34]:
# List the state_dict entry names of the quantized model; the export cells
# below branch on these names.
DH_m_int8.state_dict().keys()

odict_keys(['quant.scale', 'quant.zero_point', 'model_fp32.conv1.weight', 'model_fp32.conv1.bias', 'model_fp32.conv1.scale', 'model_fp32.conv1.zero_point', 'model_fp32.conv2.weight', 'model_fp32.conv2.bias', 'model_fp32.conv2.scale', 'model_fp32.conv2.zero_point', 'model_fp32.conv3.weight', 'model_fp32.conv3.bias', 'model_fp32.conv3.scale', 'model_fp32.conv3.zero_point', 'model_fp32.conv4.weight', 'model_fp32.conv4.bias', 'model_fp32.conv4.scale', 'model_fp32.conv4.zero_point', 'model_fp32.clshead.weight', 'model_fp32.clshead.bias', 'model_fp32.clshead.scale', 'model_fp32.clshead.zero_point', 'model_fp32.reghead.weight', 'model_fp32.reghead.bias', 'model_fp32.reghead.scale', 'model_fp32.reghead.zero_point'])

In [37]:
# Export reference parameters:
# Print arrays / tensors in full (no "..." truncation) so the dump is complete.
np.set_printoptions(threshold=sys.maxsize)
torch.set_printoptions(profile="full")

# Make sure the target folder exists; open(..., 'w') does not create it.
os.makedirs('Detection_Header_model_weights', exist_ok=True)

# Build the state_dict once instead of once per key.
ref_state = DH_m_int8.state_dict()

# Write through print(..., file=f) rather than reassigning sys.stdout: if an
# exception occurs mid-export, stdout can no longer be left pointing at a
# closed file.
with open('Detection_Header_model_weights/ref_parameters.txt', 'w') as f:
    for keys in ref_state:
        print(keys, " : ", ref_state[keys], file=f)

quant.scale  :  tensor([0.0684])
quant.zero_point  :  tensor([62])
model_fp32.conv1.weight  :  tensor([[[[ 0.0046,  0.0046,  0.0000],
          [ 0.0034, -0.0034,  0.0023],
          [-0.0034, -0.0023, -0.0046]],

         [[-0.0046,  0.0080, -0.0046],
          [ 0.0023,  0.0023,  0.0000],
          [ 0.0023,  0.0115,  0.0057]],

         [[-0.0011, -0.0034,  0.0011],
          [ 0.0000,  0.0000,  0.0034],
          [ 0.0023,  0.0069, -0.0023]],

         ...,

         [[ 0.0092,  0.0046,  0.0046],
          [ 0.0046,  0.0126,  0.0046],
          [ 0.0092,  0.0034,  0.0046]],

         [[ 0.0115,  0.0011,  0.0080],
          [-0.0069, -0.0057,  0.0011],
          [ 0.0034,  0.0057,  0.0046]],

         [[-0.0103, -0.0034, -0.0057],
          [-0.0011, -0.0034, -0.0011],
          [ 0.0034,  0.0023, -0.0023]]],


        [[[ 0.0057,  0.0080,  0.0115],
          [-0.0011,  0.0046,  0.0126],
          [ 0.0080, -0.0011,  0.0046]],

         [[-0.0046, -0.0138, -0.0092],
          [-0.00

In [12]:
# Export weights (int8, shape), biases and other parameters from fused model:
# Print arrays / tensors in full (no "..." truncation) so the dump is complete.
np.set_printoptions(threshold=sys.maxsize)
torch.set_printoptions(profile="full")

# Make sure the target folder exists; open(..., 'w') does not create it.
os.makedirs('Detection_Header_model_weights', exist_ok=True)

# Build the state_dict once instead of several times per key.
fused_state = DH_m_int8.state_dict()

# Write through print(..., file=f) rather than reassigning sys.stdout: if an
# exception occurs mid-export, stdout can no longer be left pointing at a
# closed file.
with open('Detection_Header_model_weights/parameters_float_quantized.txt', 'w') as f:
    for keys in fused_state:
        if('weight' in keys):
            # Swap the first two dimensions before exporting.
            weights_tran = fused_state[keys].transpose(0,1)
            # print(keys, "_float : ", weights_tran, file=f)
            print(keys, "_int8 : ", weights_tran.int_repr(), file=f)
            print(keys, "_shape: ", weights_tran.shape, file=f)
        elif('bias' in keys):
            name = keys[0:-4] # take the whole name up to bias
            bias_matrix = fused_state[keys]
            scale = fused_state[name+'scale']
            zero_point = fused_state[name+'zero_point']
            # print("!!! scale: ", scale, " zero_point: ", zero_point)

            # Express the bias in units of the layer's output scale
            # (truncating towards zero) and shift it by the output zero_point.
            bias_matrix = torch.div(bias_matrix, scale, rounding_mode='trunc')
            # The zero_points in this model go up to 230, so the shifted value
            # does not fit into int8 (-128..127) and the former
            # .to(torch.int8) cast corrupted it. int32 stores it losslessly.
            # NOTE(review): clamp to 0..255 / uint8 here if the consumer
            # expects the bias in the quint8 activation range - confirm.
            bias_matrix = torch.add(bias_matrix, zero_point).to(torch.int32)

            print("bias_matrix: ", bias_matrix, file=f)
        else:
            print(keys, " : ", fused_state[keys], file=f)

In [13]:
# Export only quantized weights from fused model:
# Print arrays / tensors in full (no "..." truncation) so the dump is complete.
np.set_printoptions(threshold=sys.maxsize)
torch.set_printoptions(profile="full")

# Make sure the target folder exists; open(..., 'w') does not create it.
os.makedirs('Detection_Header_model_weights', exist_ok=True)

# Build the state_dict once instead of once per key.
weight_state = DH_m_int8.state_dict()

# Write through print(..., file=f) rather than reassigning sys.stdout: if an
# exception occurs mid-export, stdout can no longer be left pointing at a
# closed file.
with open('Detection_Header_model_weights/weights_only.h', 'w') as f:
    for keys in weight_state:
        if('weight' in keys):
            # Swap the first two dimensions, then dump the stored integers.
            weights_tran = weight_state[keys].transpose(0,1)
            print(keys, "_int8 : ", weights_tran.int_repr(), file=f)

In [10]:
# Shape of conv1's quantized weight before the transpose(0,1) used for export.
print(DH_m_int8.model_fp32.conv1.weight().int_repr().shape)

torch.Size([144, 256, 3, 3])


In [18]:
# Functions for save and load a model:
def save_torchscript_model(model, model_dir, model_filename):
    """Script ``model`` with TorchScript and save it to ``model_dir/model_filename``.

    Args:
        model: module to compile with ``torch.jit.script``.
        model_dir: target directory; created (including parents) if missing.
            An empty string means the current working directory.
        model_filename: file name of the saved model inside ``model_dir``.
    """
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could
    # fail if the directory appeared between the two calls. The guard skips
    # os.makedirs(""), which would raise for the "current directory" case.
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)

def load_torchscript_model(model_filepath, device):
    """Load a TorchScript model from ``model_filepath``.

    Args:
        model_filepath: path of a file written by ``torch.jit.save``.
        device: passed to ``torch.jit.load`` as ``map_location``.

    Returns:
        The loaded TorchScript module.
    """
    return torch.jit.load(model_filepath, map_location=device)

In [17]:
# Save quantized model:
# NOTE(review): absolute, machine-specific target directory - adjust when
# running elsewhere.
save_torchscript_model(model=DH_m_int8, model_dir="/scratch2/xm0523/RADIal/RADIal/pytorch_model_quantization", model_filename="Detection_Header_quant_model")

In [23]:
# And save state_dict:
# state_dict must be *called*. Passing the bound method `DH_m_int8.state_dict`
# pickles the method object (and, through it, the whole module) instead of the
# name -> tensor mapping, so loading the file would not return a state_dict.
torch.save(DH_m_int8.state_dict(), "/scratch2/xm0523/RADIal/RADIal/pytorch_model_quantization/DH_model_dict.pth")

In [19]:
# Load quantized model:
# Reads back the TorchScript file written by the save cell, mapped onto the CPU.
DH_quant_model = load_torchscript_model(model_filepath="/scratch2/xm0523/RADIal/RADIal/pytorch_model_quantization/Detection_Header_quant_model", device="cpu")

In [24]:
# Load the state_dict of quantized
# NOTE: this rebinds DH_model_dict, the name an earlier cell used for the
# float detection header's state_dict.
DH_model_dict = torch.load("/scratch2/xm0523/RADIal/RADIal/pytorch_model_quantization/DH_model_dict.pth")

In [20]:
# Show the module tree of the reloaded TorchScript model.
print(DH_quant_model)

RecursiveScriptModule(
  original_name=Quantized
  (quant): RecursiveScriptModule(original_name=Quantize)
  (model_fp32): RecursiveScriptModule(
    original_name=Detection_Header
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (bn1): RecursiveScriptModule(original_name=Identity)
    (conv2): RecursiveScriptModule(original_name=Conv2d)
    (bn2): RecursiveScriptModule(original_name=Identity)
    (conv3): RecursiveScriptModule(original_name=Conv2d)
    (bn3): RecursiveScriptModule(original_name=Identity)
    (conv4): RecursiveScriptModule(original_name=Conv2d)
    (bn4): RecursiveScriptModule(original_name=Identity)
    (clshead): RecursiveScriptModule(original_name=Conv2d)
    (reghead): RecursiveScriptModule(original_name=Conv2d)
  )
  (dequant): RecursiveScriptModule(original_name=DeQuantize)
)
