# TFLiteモデルをPyTorch用に変換する

In [1]:
import os
import numpy as np
from collections import OrderedDict
import tensorflow as tf

## TFLiteファイルから重みを取得する

In [2]:
# Load the TFLite model and allocate its tensors so that their contents can be read.
TFLITE_MODEL_PATH = "../models/tf/FaceMesh.tflite"
interpreter = tf.lite.Interpreter(model_path=TFLITE_MODEL_PATH)
interpreter.allocate_tensors()

In [3]:
# Inspect the model: print index, name, dtype and shape of every tensor.
for detail in interpreter.get_tensor_details():
    row = (detail['index'], detail['name'], detail['dtype'], detail['shape'])
    print("%3d %30s %15s %s" % row)

  0                        input_1 <class 'numpy.float32'> [  1 192 192   3]
  1                  conv2d/Kernel <class 'numpy.float32'> [16  3  3  3]
  2                    conv2d/Bias <class 'numpy.float32'> [16]
  3                         conv2d <class 'numpy.float32'> [ 1 96 96 16]
  4                  p_re_lu/Alpha <class 'numpy.float32'> [ 1  1 16]
  5                        p_re_lu <class 'numpy.float32'> [ 1 96 96 16]
  6        depthwise_conv2d/Kernel <class 'numpy.float32'> [ 1  3  3 16]
  7          depthwise_conv2d/Bias <class 'numpy.float32'> [16]
  8               depthwise_conv2d <class 'numpy.float32'> [ 1 96 96 16]
  9                conv2d_1/Kernel <class 'numpy.float32'> [16  1  1 16]
 10                  conv2d_1/Bias <class 'numpy.float32'> [16]
 11                       conv2d_1 <class 'numpy.float32'> [ 1 96 96 16]
 12                            add <class 'numpy.float32'> [ 1 96 96 16]
 13                p_re_lu_1/Alpha <class 'numpy.float32'> [ 1  1 16]
 14    

テンソル名に基づいてテンソルインデックスを取得できるようなルックアップテーブルを作成

In [4]:
# Lookup table: tensor name -> tensor index inside the interpreter.
tensor_dict = {}
for detail in interpreter.get_tensor_details():
    tensor_dict[detail['name']] = detail['index']
tensor_dict

{'input_1': 0,
 'conv2d/Kernel': 1,
 'conv2d/Bias': 2,
 'conv2d': 3,
 'p_re_lu/Alpha': 4,
 'p_re_lu': 5,
 'depthwise_conv2d/Kernel': 6,
 'depthwise_conv2d/Bias': 7,
 'depthwise_conv2d': 8,
 'conv2d_1/Kernel': 9,
 'conv2d_1/Bias': 10,
 'conv2d_1': 11,
 'add': 12,
 'p_re_lu_1/Alpha': 13,
 'p_re_lu_1': 14,
 'depthwise_conv2d_1/Kernel': 15,
 'depthwise_conv2d_1/Bias': 16,
 'depthwise_conv2d_1': 17,
 'conv2d_2/Kernel': 18,
 'conv2d_2/Bias': 19,
 'conv2d_2': 20,
 'add_1': 21,
 'p_re_lu_2/Alpha': 22,
 'p_re_lu_2': 23,
 'depthwise_conv2d_2/Kernel': 24,
 'depthwise_conv2d_2/Bias': 25,
 'depthwise_conv2d_2': 26,
 'max_pooling2d': 27,
 'conv2d_3/Kernel': 28,
 'conv2d_3/Bias': 29,
 'conv2d_3': 30,
 'channel_padding/Paddings': 31,
 'channel_padding': 32,
 'add_2': 33,
 'p_re_lu_3/Alpha': 34,
 'p_re_lu_3': 35,
 'depthwise_conv2d_3/Kernel': 36,
 'depthwise_conv2d_3/Bias': 37,
 'depthwise_conv2d_3': 38,
 'conv2d_4/Kernel': 39,
 'conv2d_4/Bias': 40,
 'conv2d_4': 41,
 'add_3': 42,
 'p_re_lu_4/Alpha': 43

In [5]:
def is_with_weight(tensor, info):
    """Decide whether a TFLite tensor should be treated as a learned weight.

    Args:
        tensor: Array holding the tensor contents (must support ``.sum()``).
        info: Tensor detail dict; only ``info['name']`` is read.

    Returns:
        ``False`` when the name contains ``'add'`` or ``'padding'``, or when
        the tensor sums to zero and its name has no ``'/'``; ``True`` otherwise.
    """
    name = info['name']

    # Element-wise add and padding tensors are never counted as weights.
    if 'add' in name or 'padding' in name:
        return False

    # NOTE(review): presumably an all-zero tensor without a "<layer>/<kind>"
    # style name is a layer output that has not been computed yet -- confirm.
    sums_to_zero = tensor.sum() == 0
    if sums_to_zero and '/' not in name:
        return False

    return True

In [6]:
tensors_info = interpreter.get_tensor_details()

# Names of the TFLite tensors that carry weights, in interpreter order.
tflite_names = []
for info in tensors_info:
    tensor = interpreter.get_tensor(info['index'])
    if not is_with_weight(tensor, info):
        continue
    # The running position equals the number of names collected so far.
    print("%3d %30s %3d" % (len(tflite_names), info['name'], info['index']))
    tflite_names.append(info['name'])

  0                  conv2d/Kernel   1
  1                    conv2d/Bias   2
  2                  p_re_lu/Alpha   4
  3        depthwise_conv2d/Kernel   6
  4          depthwise_conv2d/Bias   7
  5                conv2d_1/Kernel   9
  6                  conv2d_1/Bias  10
  7                p_re_lu_1/Alpha  13
  8      depthwise_conv2d_1/Kernel  15
  9        depthwise_conv2d_1/Bias  16
 10                conv2d_2/Kernel  18
 11                  conv2d_2/Bias  19
 12                p_re_lu_2/Alpha  22
 13      depthwise_conv2d_2/Kernel  24
 14        depthwise_conv2d_2/Bias  25
 15                conv2d_3/Kernel  28
 16                  conv2d_3/Bias  29
 17                p_re_lu_3/Alpha  34
 18      depthwise_conv2d_3/Kernel  36
 19        depthwise_conv2d_3/Bias  37
 20                conv2d_4/Kernel  39
 21                  conv2d_4/Bias  40
 22                p_re_lu_4/Alpha  43
 23      depthwise_conv2d_4/Kernel  45
 24        depthwise_conv2d_4/Bias  46
 25                conv2d

In [7]:
def get_weights(name):
    """Return the contents of the TFLite tensor called ``name``.

    The name is resolved to a tensor index through ``tensor_dict`` and the
    tensor is then read from ``interpreter``. Raises ``KeyError`` if the name
    is not present in ``tensor_dict``.
    """
    return interpreter.get_tensor(tensor_dict[name])

In [8]:
# Spot-check: fetch one kernel by name and display its shape.
W = get_weights('conv2d_18/Kernel')
W.shape

(32, 1, 1, 128)

## PyTorchのフォーマットを確認する

In [9]:
import torch
import torch.nn as nn
from facemesh_pytorch import FaceMesh

In [10]:
# Instantiate the PyTorch model and collect its state_dict keys in order.
net = FaceMesh()
pytorch_names = list(net.state_dict())
pytorch_names[:10]

['backbone.0.weight',
 'backbone.0.bias',
 'backbone.1.weight',
 'backbone.2.convs.0.weight',
 'backbone.2.convs.0.bias',
 'backbone.2.convs.1.weight',
 'backbone.2.convs.1.bias',
 'backbone.2.act.weight',
 'backbone.3.convs.0.weight',
 'backbone.3.convs.0.bias']

tflite側とpytorchの重みの数が一致することを確認

In [11]:
# Show both entry counts side by side (PyTorch first, TFLite second).
n_pytorch, n_tflite = len(pytorch_names), len(tflite_names)
print(n_pytorch, n_tflite)

113 113


2つのモデル間のレイヤー名をマッピングするルックアップテーブルを作成

まず、テンソルが両方のモデルで同じ順番に並んでいると仮定して対応付ける。  
ただし、分岐以降（coord_head / conf_head）は順番が一致しないため、手動で対応付ける（力技）。

In [12]:
# Pair PyTorch state_dict keys with TFLite tensor names purely by position.
convert = dict(zip(pytorch_names, tflite_names))

In [13]:
# Sanity check: run a random batch of two 192x192 RGB inputs through the
# network and print the shape of the first output.
dummy_batch = torch.randn(2, 3, 192, 192)
first_output = net(dummy_batch)[0]
print(first_output.shape)

torch.Size([2, 1404])


In [14]:
# Hand-written overrides for the two heads. Keys are PyTorch state_dict names,
# values are TFLite tensor names. These entries replace the position-based
# pairs already stored in `convert` for the same keys.
manual_mapping ={
 # coord_head parameters
 'coord_head.0.convs.0.weight':'depthwise_conv2d_14/Kernel',
 'coord_head.0.convs.0.bias':'depthwise_conv2d_14/Bias',
 'coord_head.0.convs.1.weight':'conv2d_15/Kernel',
 'coord_head.0.convs.1.bias':'conv2d_15/Bias',
 'coord_head.0.act.weight':'p_re_lu_15/Alpha',
 'coord_head.1.convs.0.weight':'depthwise_conv2d_15/Kernel',
 'coord_head.1.convs.0.bias':'depthwise_conv2d_15/Bias',
 'coord_head.1.convs.1.weight':'conv2d_16/Kernel',
 'coord_head.1.convs.1.bias':'conv2d_16/Bias',
 'coord_head.1.act.weight':'p_re_lu_16/Alpha',
 'coord_head.2.convs.0.weight':'depthwise_conv2d_16/Kernel',
 'coord_head.2.convs.0.bias':'depthwise_conv2d_16/Bias',
 'coord_head.2.convs.1.weight':'conv2d_17/Kernel',
 'coord_head.2.convs.1.bias':'conv2d_17/Bias',
 'coord_head.2.act.weight':'p_re_lu_17/Alpha',
 'coord_head.3.weight':'conv2d_18/Kernel',
 'coord_head.3.bias':'conv2d_18/Bias',
 'coord_head.4.weight':'p_re_lu_18/Alpha',
 'coord_head.5.convs.0.weight':'depthwise_conv2d_17/Kernel',
 'coord_head.5.convs.0.bias':'depthwise_conv2d_17/Bias',
 'coord_head.5.convs.1.weight':'conv2d_19/Kernel',
 'coord_head.5.convs.1.bias':'conv2d_19/Bias',
 'coord_head.5.act.weight':'p_re_lu_19/Alpha',
 'coord_head.6.weight':'conv2d_20/Kernel',
 'coord_head.6.bias':'conv2d_20/Bias',
 # conf_head parameters
 'conf_head.0.convs.0.weight':'depthwise_conv2d_22/Kernel',
 'conf_head.0.convs.0.bias':'depthwise_conv2d_22/Bias',
 'conf_head.0.convs.1.weight':'conv2d_27/Kernel',
 'conf_head.0.convs.1.bias':'conv2d_27/Bias',
 'conf_head.0.act.weight':'p_re_lu_25/Alpha',
 'conf_head.1.weight':'conv2d_28/Kernel',
 'conf_head.1.bias':'conv2d_28/Bias',
 'conf_head.2.weight':'p_re_lu_26/Alpha',
 'conf_head.3.convs.0.weight':'depthwise_conv2d_23/Kernel',
 'conf_head.3.convs.0.bias':'depthwise_conv2d_23/Bias',
 'conf_head.3.convs.1.weight':'conv2d_29/Kernel',
 'conf_head.3.convs.1.bias':'conv2d_29/Bias',
 'conf_head.3.act.weight':'p_re_lu_27/Alpha',
 'conf_head.4.weight':'conv2d_30/Kernel',
 'conf_head.4.bias':'conv2d_30/Bias'}
# Overwrite the order-based entries with the manual ones.
convert.update(manual_mapping)

## 重みをレイヤーにコピー

PyTorchとTFLiteでは重みの順序が異なるので、転置する必要がある

Convolution weights:

    TFLiteの場合 (out_channels, kernel_height, kernel_width, in_channels)
    PyTorchの場合(out_channels, in_channels, kernel_height, kernel_width)

Depthwise convolution weights

    TFLiteの場合 (1, kernel_height, kernel_width, channels)
    PyTorchの場合(channels, 1, kernel_height, kernel_width)
    
PReLU:

    TFLiteの場合  (1, 1, num_channels)
    PyTorchの場合 (num_channels, )

In [15]:
# Convert every mapped TFLite tensor to the layout of its PyTorch destination
# and collect the results in a state_dict that `net.load_state_dict` can load.
new_state_dict = OrderedDict()
target_state = net.state_dict()  # built once instead of on every iteration

for dst, src in convert.items():
    W = get_weights(src)
    target_shape = tuple(target_state[dst].shape)
    print(dst, src, W.shape, target_state[dst].shape)

    if W.ndim == 4:
        # A leading dimension of 1 is ambiguous: depthwise kernels are stored
        # as (1, kh, kw, channels), but a regular convolution with a single
        # output channel is (1, kh, kw, in_channels). The destination shape
        # settles it, so no layer name has to be special-cased.
        as_depthwise = W.transpose((3, 0, 1, 2))  # -> (channels, 1, kh, kw)
        if W.shape[0] == 1 and as_depthwise.shape == target_shape:
            W = as_depthwise
        else:
            W = W.transpose((0, 3, 1, 2))  # regular conv -> (out, in, kh, kw)
    elif W.ndim == 3:
        W = W.reshape(-1)  # PReLU alpha: flatten to one value per channel

    # Fail right here, naming the offending pair, rather than later inside
    # load_state_dict when a mapping entry points at the wrong tensor.
    if tuple(W.shape) != target_shape:
        raise ValueError(
            "Shape mismatch for %s <- %s: converted %s, expected %s"
            % (dst, src, tuple(W.shape), target_shape)
        )

    new_state_dict[dst] = torch.from_numpy(W)

backbone.0.weight conv2d/Kernel (16, 3, 3, 3) torch.Size([16, 3, 3, 3])
backbone.0.bias conv2d/Bias (16,) torch.Size([16])
backbone.1.weight p_re_lu/Alpha (1, 1, 16) torch.Size([16])
backbone.2.convs.0.weight depthwise_conv2d/Kernel (1, 3, 3, 16) torch.Size([16, 1, 3, 3])
backbone.2.convs.0.bias depthwise_conv2d/Bias (16,) torch.Size([16])
backbone.2.convs.1.weight conv2d_1/Kernel (16, 1, 1, 16) torch.Size([16, 16, 1, 1])
backbone.2.convs.1.bias conv2d_1/Bias (16,) torch.Size([16])
backbone.2.act.weight p_re_lu_1/Alpha (1, 1, 16) torch.Size([16])
backbone.3.convs.0.weight depthwise_conv2d_1/Kernel (1, 3, 3, 16) torch.Size([16, 1, 3, 3])
backbone.3.convs.0.bias depthwise_conv2d_1/Bias (16,) torch.Size([16])
backbone.3.convs.1.weight conv2d_2/Kernel (16, 1, 1, 16) torch.Size([16, 16, 1, 1])
backbone.3.convs.1.bias conv2d_2/Bias (16,) torch.Size([16])
backbone.3.act.weight p_re_lu_2/Alpha (1, 1, 16) torch.Size([16])
backbone.4.convs.0.weight depthwise_conv2d_2/Kernel (1, 3, 3, 16) torch.S

In [16]:
# Load the converted weights; strict=True makes missing or unexpected keys an error.
net.load_state_dict(new_state_dict, strict=True)

<All keys matched successfully>

In [17]:
# Save the regular PyTorch model (its state_dict).
torch.save(net.state_dict(), '../models/pytorch/FaceMesh.pth')

In [18]:
# Put the network in evaluation mode, then trace it with a random example
# input to obtain a TorchScript module.
net.eval()
example = torch.rand(1, 3, 192, 192)
traced_script_module = torch.jit.trace(net, example)

In [19]:
# Write the traced TorchScript module to the libtorch model directory.
traced_script_module.save('../models/libtorch/FaceMesh.pth')