In [2]:
import torch

import onnx
from onnx import numpy_helper

from StyleTransferModel_128 import StyleTransferModel

In [106]:
INSWAPPER = False # set to false if converting from old reswapper architecture weights, set to true if converting from inswapper_128.onnx
WEIGHTS_PATH = "models/reswapper-1019500.pth" # file ending .pt, .pth, .onnx
# WEIGHTS_PATH = "models/inswapper_128_batched.onnx" # file ending .pt, .pth, .onnx

transfer_weights = {}

In [107]:
if WEIGHTS_PATH.endswith('.onnx'):
    onnx_model   = onnx.load(WEIGHTS_PATH)
    INTIALIZERS  = onnx_model.graph.initializer
    transfer_weights = {}
    for initializer in INTIALIZERS:
        W = numpy_helper.to_array(initializer)
        transfer_weights[initializer.name] = W
elif WEIGHTS_PATH.endswith('.pth') or WEIGHTS_PATH.endswith('.pt'):
    transfer_weights = torch.load(WEIGHTS_PATH) 
else:
    print('File type must be of (.pt, .pth, .onnx)')

In [108]:
model = StyleTransferModel()

In [109]:
weight_shapes = []
for n, p in model.named_parameters():
    weight_shapes.append((n, '-'.join([str(x) for x in list(p.shape)])))

In [110]:
for k, v in model.state_dict().items():
    print(k)
    print(v.shape, '\n')

down.0.weight
torch.Size([128, 3, 7, 7]) 

down.0.bias
torch.Size([128]) 

down.2.weight
torch.Size([256, 128, 3, 3]) 

down.2.bias
torch.Size([256]) 

down.4.weight
torch.Size([512, 256, 3, 3]) 

down.4.bias
torch.Size([512]) 

down.6.weight
torch.Size([1024, 512, 3, 3]) 

down.6.bias
torch.Size([1024]) 

style_blocks.0.conv1.conv.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.0.conv1.conv.bias
torch.Size([1024]) 

style_blocks.0.conv1.adain.style.weight
torch.Size([2048, 512]) 

style_blocks.0.conv1.adain.style.bias
torch.Size([2048]) 

style_blocks.0.conv2.conv.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.0.conv2.conv.bias
torch.Size([1024]) 

style_blocks.0.conv2.adain.style.weight
torch.Size([2048, 512]) 

style_blocks.0.conv2.adain.style.bias
torch.Size([2048]) 

style_blocks.1.conv1.conv.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.1.conv1.conv.bias
torch.Size([1024]) 

style_blocks.1.conv1.adain.style.weight
torch.Size([2048, 512]) 

style_blocks.1.conv1.a

In [111]:
for k, v in transfer_weights.items():
    print(k)
    print(v.shape, '\n')

target_encoder.0.weight
torch.Size([128, 3, 7, 7]) 

target_encoder.0.bias
torch.Size([128]) 

target_encoder.2.weight
torch.Size([256, 128, 3, 3]) 

target_encoder.2.bias
torch.Size([256]) 

target_encoder.4.weight
torch.Size([512, 256, 3, 3]) 

target_encoder.4.bias
torch.Size([512]) 

target_encoder.6.weight
torch.Size([1024, 512, 3, 3]) 

target_encoder.6.bias
torch.Size([1024]) 

style_blocks.0.conv1.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.0.conv1.bias
torch.Size([1024]) 

style_blocks.0.conv2.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.0.conv2.bias
torch.Size([1024]) 

style_blocks.0.style1.weight
torch.Size([2048, 512]) 

style_blocks.0.style1.bias
torch.Size([2048]) 

style_blocks.0.style2.weight
torch.Size([2048, 512]) 

style_blocks.0.style2.bias
torch.Size([2048]) 

style_blocks.1.conv1.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.1.conv1.bias
torch.Size([1024]) 

style_blocks.1.conv2.weight
torch.Size([1024, 1024, 3, 3]) 

style_blocks.1.conv2.

In [112]:
if INSWAPPER:
    replacement_dict = {
        'styles': 'style_blocks',
        'conv1.1': 'conv1.conv',
        'style1.linear': 'conv1.adain.style',
        'conv2.1': 'conv2.conv',
        'style2.linear': 'conv2.adain.style',
        'up0.1': 'up.8',
        'onnx::Conv_833': 'down.0.weight',
        'onnx::Conv_834': 'down.0.bias',
        'onnx::Conv_836': 'down.2.weight',
        'onnx::Conv_837': 'down.2.bias',
        'onnx::Conv_839': 'down.4.weight',
        'onnx::Conv_840': 'down.4.bias',
        'onnx::Conv_842': 'down.6.weight',
        'onnx::Conv_843': 'down.6.bias',
        'onnx::Conv_845': 'up.1.weight',
        'onnx::Conv_846': 'up.1.bias',
        'onnx::Conv_848': 'up.4.weight',
        'onnx::Conv_849': 'up.4.bias',
        'onnx::Conv_851': 'up.6.weight',
        'onnx::Conv_852': 'up.6.bias',
        # 'initializer': 'initializer.weight'
    }
else:
    replacement_dict = {
        'target_encoder': 'down',
        'conv1': 'conv1.conv',
        'style1': 'conv1.adain.style',
        'conv2': 'conv2.conv',
        'style2': 'conv2.adain.style',
        'decoder.0': 'up.1',
        'decoderPart1.0': 'up.4',
        'decoderPart1.2': 'up.6',
        'decoderPart2.0': 'up.8'
    }

In [113]:
renamed_weights = {}

for k, v in transfer_weights.items():
    orig_k = k
    for name, replacement in replacement_dict.items():
        k = k.replace(name, replacement)

    if k == orig_k:
        shape_name = '-'.join([str(x) for x in v.shape])
        print(f'Shape of {k}: {v.shape}')
        print(f'Possible names:')
        replacements = []
        for weight_shape in weight_shapes:
            if shape_name == weight_shape[-1]:
                print(f'  {weight_shape[0]}')
                replacements.append(weight_shape[0])
        if len(replacements) == 1:
            k = replacements[0]

    if k != orig_k:
        renamed_weights[k] = v

In [114]:
state_dict = {k: torch.tensor(v) for k, v in renamed_weights.items()}

model.load_state_dict(state_dict)

  state_dict = {k: torch.tensor(v) for k, v in renamed_weights.items()}


<All keys matched successfully>

In [115]:
torch.save(model.state_dict(), 'inswapper_128_dict.pt')

In [25]:
torch.save(model, 'inswapper_128_cleaned.pt')

In [26]:
model(torch.rand((1, 3, 128, 128)), torch.rand((1, 512))).shape

torch.Size([1, 3, 256, 256])

In [None]:
# test
tgt = torch.randn(1, 3, 128, 128)
src_e = torch.randn(1, 512)

y = model(tgt, src_e)

model.eval()

torch.onnx.export(
    model, 
    (tgt, src_e), 
    "inswapper_eg.onnx", 
    export_params=True,
    opset_version=11, 
    input_names=['target', 'source'], 
    output_names=['output'], 
    dynamic_axes={'target': {0: 'batch_size'}, 'source': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)