
load pretrained model from .pth #18

Closed
@HaoRan-hash

Description


I wrote a model in PyTorch and saved its state_dict() to a .pth file. Now I want to rewrite it with TensorLayerX so that other people (using TensorFlow, etc.) can use this model.
My model definition is the same in PyTorch and TensorLayerX, but I can't load the pretrained .pth weights in TensorLayerX.
Below is my code. (A simple model is used here for clarity; the actual model is more complex than this.)

"""
a_torch.py
"""
import torch
from torch import nn

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=1)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

if __name__ == '__main__':
    a = A()
    torch.save(a.state_dict(), 'a.pth')

"""
a_tlx.py
"""
import tensorlayerx as tlx
import torch
from tensorlayerx import nn

class A(nn.Module):
    def __init__(self):
        super(A, self).__init__()
        self.conv = nn.Conv2d(out_channels=16, kernel_size=1, in_channels=3, data_format='channels_first')
        self.bn = nn.BatchNorm2d(num_features=16, data_format='channels_first')
        self.relu = nn.activation.ReLU()
    
    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

def pth2npz(pth_path):
    temp = torch.load(pth_path)   # an OrderedDict mapping parameter names to tensors
    tlx.files.save_npz_dict(temp.items(), pth_path.split('.')[0] + '.npz')

if __name__ == '__main__':
    a = A()
    pth2npz('a.pth')
    tlx.files.load_and_assign_npz_dict('a.npz', a)

First run a_torch.py, then run a_tlx.py.
The error is below.

Using PyTorch backend.
Traceback (most recent call last):
  File "test/test_03.py", line 25, in <module>
    tlx.files.load_and_assign_npz_dict('test/a.npz', a)
  File "/home/mchen/anaconda3/envs/kpconv/lib/python3.8/site-packages/tensorlayerx/files/utils.py", line 2208, in load_and_assign_npz_dict
    raise RuntimeError(
RuntimeError: Weights named 'conv.weight' not found in network. Hint: set argument skip=Ture if you want to skip redundant or mismatch weights

Then I debugged and looked at the tlx.files.load_and_assign_npz_dict() source code. I found that the TensorLayerX parameter names differ from the PyTorch ones, which causes the key mismatch when loading the pretrained model.
In the following two figures, the first shows the PyTorch parameter names and the second shows the TensorLayerX parameter names.
(Screenshot: PyTorch parameter names)
(Screenshot: TensorLayerX parameter names)
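
For reference, the names on each side can be listed directly. This is only a minimal sketch; it assumes a TensorLayerX Module exposes its parameters through all_weights and that each parameter has a .name attribute, and that a_tlx refers to the file above:

import torch
from a_tlx import A   # the TensorLayerX model defined above

torch_names = list(torch.load('a.pth').keys())
print(torch_names)   # PyTorch-side keys, e.g. 'conv.weight', 'conv.bias', 'bn.weight', ...

a = A()
tlx_names = [w.name for w in a.all_weights]
print(tlx_names)     # TensorLayerX-side names, which do not match the keys above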
The only solution I can think of is to write a key mapping table by hand, but that is hard to do for a large model. Can you suggest a simpler solution? (Same model definition in PyTorch and TensorLayerX, loading the pretrained model from .pth.) 😁
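
Would something along these lines be a reasonable direction? It builds the key map automatically by pairing the parameters in declaration order instead of writing the table by hand. This is a rough, untested sketch: it assumes all_weights (with .name) is available, that both frameworks declare the parameters in the same order, that PyTorch-only buffers such as num_batches_tracked are dropped, and it ignores possible weight-layout differences between backends.

"""
pth2npz_mapped.py  (rough sketch, untested)
"""
import torch
import tensorlayerx as tlx
from a_tlx import A   # the TensorLayerX model defined above

def pth2npz_mapped(pth_path, tlx_model, npz_path):
    torch_state = torch.load(pth_path)   # OrderedDict: PyTorch name -> tensor
    # Buffers with no TensorLayerX counterpart probably need to be dropped first.
    torch_state = {k: v for k, v in torch_state.items() if 'num_batches_tracked' not in k}

    tlx_names = [w.name for w in tlx_model.all_weights]   # TensorLayerX-side names
    assert len(torch_state) == len(tlx_names), 'parameter count mismatch'

    # Pair the two sides by declaration order instead of writing a key map by hand.
    renamed = {tlx_name: tensor for tlx_name, tensor in zip(tlx_names, torch_state.values())}
    tlx.files.save_npz_dict(renamed.items(), npz_path)

if __name__ == '__main__':
    a = A()
    pth2npz_mapped('a.pth', a, 'a_mapped.npz')
    tlx.files.load_and_assign_npz_dict('a_mapped.npz', a)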
