In [1]:
import os
import tempfile
from zipfile import ZipFile

import torch
import polish.models

In [2]:
model = polish.models.DeepDenoiser(conv2d=polish.models.SepConv2d)

In [3]:
# Insert your trained model here.
model.load_state_dict(torch.load('model.pt', map_location='cpu'))

<All keys matched successfully>

In [4]:
# We will export the parameters using these keys.
#
# Seeing the list may be helpful for implementing the
# model in the Go API.
model.state_dict().keys()

odict_keys(['conv1.weight', 'conv1.bias', 'conv2.spatial.weight', 'conv2.spatial.bias', 'conv2.depthwise.weight', 'conv2.depthwise.bias', 'residuals.0.0.weight', 'residuals.0.0.bias', 'residuals.0.2.spatial.weight', 'residuals.0.2.spatial.bias', 'residuals.0.2.depthwise.weight', 'residuals.0.2.depthwise.bias', 'residuals.0.4.spatial.weight', 'residuals.0.4.spatial.bias', 'residuals.0.4.depthwise.weight', 'residuals.0.4.depthwise.bias', 'residuals.1.0.weight', 'residuals.1.0.bias', 'residuals.1.2.spatial.weight', 'residuals.1.2.spatial.bias', 'residuals.1.2.depthwise.weight', 'residuals.1.2.depthwise.bias', 'residuals.1.4.spatial.weight', 'residuals.1.4.spatial.bias', 'residuals.1.4.depthwise.weight', 'residuals.1.4.depthwise.bias', 'residuals.2.0.weight', 'residuals.2.0.bias', 'residuals.2.2.spatial.weight', 'residuals.2.2.spatial.bias', 'residuals.2.2.depthwise.weight', 'residuals.2.2.depthwise.bias', 'residuals.2.4.spatial.weight', 'residuals.2.4.spatial.bias', 'residuals.2.4.depthwi

In [5]:
# Create a zip_data variable containing the parameters
# as a zip file, with one file per array.
with tempfile.TemporaryDirectory() as temp_dir:
    zip_path = os.path.join(temp_dir, 'params.zip')
    with ZipFile(zip_path, 'w') as f:
        for k, v in model.state_dict().items():
            arr = v.detach().cpu().numpy().flatten()
            f.writestr('%s' % k, arr.tobytes())
    with open(zip_path, 'rb') as f:
        zip_data = f.read()
print('Created zip file of %d bytes.' % len(zip_data))

Created zip file of 1845282 bytes.


In [6]:
def byte_str(b):
    """Convert a byte into an escape sequence for a string."""
    if b >= 32 and b <= 126 and b != ord('\\') and b != ord('"'):
        return chr(b)
    return '\\x%02x' % b

In [8]:
variable = 'deepModelZipData'
go_code = 'package polish\n\nconst %s = "' % variable
go_code += ''.join(byte_str(x) for x in zip_data)
go_code += '"\n'
print('Created code of length %d.' % len(go_code))
with open('model_data_deep.go', 'wt+') as f:
    f.write(go_code)

Created code of length 5225861.
