Corrupted RNN parameters in models exported on GPU (torch.jit) #15271

@Wizaron

Description

🐛 Bug

Step 1

I build a CNN + RNN model, keep one copy on the CPU and move a deep copy to the GPU, then trace both using torch.jit.trace as below:

model_cpu = Architecture()
model_gpu = copy.deepcopy(model_cpu).to(torch.device("cuda"))

dummy_input = torch.randn(1, 3, 224, 224)

traced_model_cpu = torch.jit.trace(model_cpu, dummy_input)
traced_model_gpu = torch.jit.trace(model_gpu, dummy_input.to(torch.device('cuda')))

At this step, the parameters of all models (model_cpu, model_gpu, traced_model_cpu, traced_model_gpu) are exactly the same.
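This can be spot-checked by comparing state dicts, e.g. (a minimal sketch using the variables from the snippet above; the compare_models helper in the full script below does the same check for every pair of models):

# Sketch: every parameter of the GPU trace should still match the CPU original.
gpu_state = traced_model_gpu.state_dict()
for key, cpu_params in model_cpu.state_dict().items():
    assert torch.equal(cpu_params.to("cuda"), gpu_state[key]), key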

Step 2

Then I save and load them using torch.jit as below:

torch.jit.save(traced_model_cpu, "model_cpu.pth")
torch.jit.save(traced_model_gpu, "model_gpu.pth")

traced_model_cpu_loaded = torch.jit.load("model_cpu.pth")
traced_model_gpu_loaded = torch.jit.load("model_gpu.pth")

The parameters of the RNN layers in traced_model_gpu_loaded are totally different from those of model_cpu, model_gpu, traced_model_cpu, traced_model_gpu, and traced_model_cpu_loaded.

To Reproduce

To reproduce the behavior, run the script below:

import copy
import torch
import torch.nn as nn
from torchvision import models


class ReNet(nn.Module):

    def __init__(self, n_input, n_units):
        super(ReNet, self).__init__()

        self.rnn = nn.GRU(n_input, n_units,
                          num_layers=1, batch_first=False,
                          bidirectional=True)

    def rnn_forward(self, x):

        b, n_height, n_width, n_filters = x.size()

        x = x.view(b * n_height, n_width, n_filters)
        x = x.permute(1, 0, 2)
        x, _ = self.rnn(x)
        x = x.permute(1, 0, 2)
        x = x.view(b, n_height, n_width, -1)

        return x

    def forward(self, x):
                                       #b, nf, h, w
        x = x.permute(0, 2, 3, 1)      #b, h, w, nf
        x = self.rnn_forward(x)        #b, h, w, nf
        x = x.permute(0, 3, 1, 2)      #b, nf, h, w

        return x


class Architecture(nn.Module):

    def __init__(self):
        super(Architecture, self).__init__()

        self.cnn = models.resnet50(pretrained=True)
        self.cnn = nn.Sequential(*list(self.cnn.children())[:-5])

        self.renet1 = ReNet(256, 50)

    def forward(self, x):
        x = self.cnn(x)
        x = self.renet1(x)

        return x


def compare_models(cpu_model, gpu_model):

    is_identical = True

    cpu_model_state_dict = cpu_model.state_dict()
    gpu_model_state_dict = gpu_model.state_dict()

    device = torch.device("cuda")
    for param_key, cpu_params in cpu_model_state_dict.items():
        gpu_params = gpu_model_state_dict[param_key]
        # move both copies to the same device so that models living on the
        # CPU and on the GPU can be compared without a device-mismatch error
        _identical = torch.all(gpu_params.to(device) == cpu_params.to(device))
        if not _identical.item():
            print("\n\t# PARAMETER : ", param_key)
            print("\t* GPU : ", gpu_params.view(-1)[:5])
            print("\t* CPU : ", cpu_params.view(-1)[:5])
            is_identical = False

    return is_identical


def trace(model, usegpu):
    with torch.set_grad_enabled(False):
        model.eval()

        dummy_input = torch.randn(1, 3, 224, 224)
        
        if usegpu:
            dummy_input = dummy_input.to(torch.device('cuda'))

        traced_model = torch.jit.trace(model, dummy_input)

    return traced_model


torch.manual_seed(13)

model_cpu = Architecture()
model_gpu = copy.deepcopy(model_cpu).to(torch.device("cuda"))

print("STEP 1 : ", compare_models(model_cpu, model_gpu))

traced_model_cpu = trace(model_cpu, False)
traced_model_gpu = trace(model_gpu, True)
print("STEP 2 : ", compare_models(traced_model_cpu, traced_model_gpu))
print("STEP 2 : ", compare_models(traced_model_gpu, model_gpu))

torch.jit.save(traced_model_cpu, "model_cpu.pth")
torch.jit.save(traced_model_gpu, "model_gpu.pth")
print("STEP 3 : ", compare_models(traced_model_cpu, traced_model_gpu))
print("STEP 3 : ", compare_models(traced_model_gpu, model_gpu))

traced_model_cpu_loaded = torch.jit.load("model_cpu.pth")
traced_model_gpu_loaded = torch.jit.load("model_gpu.pth")
print("\nSTEP 4 : ", compare_models(traced_model_cpu_loaded, model_cpu))
print("\nSTEP 4 : ", compare_models(traced_model_gpu_loaded, model_cpu))

This script prints the names (and the first few values) of the parameters that differ between the CPU and GPU models.
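Until this is fixed, a possible workaround (my assumption, not verified in this report) is to serialize only the CPU trace and remap it to the GPU at load time via the map_location argument of torch.jit.load, so that the GPU-resident RNN weights never pass through serialization:

# Possible workaround (sketch, unverified): save the CPU trace and remap
# its parameters to the GPU when loading.
torch.jit.save(traced_model_cpu, "model_cpu.pth")
loaded_on_gpu = torch.jit.load("model_cpu.pth", map_location=torch.device("cuda"))

Whether the CPU-traced graph then dispatches to the cuDNN RNN kernels on the GPU would still need to be checked.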

Expected behavior

Every compare_models call should print True, since tracing, saving, and loading should leave the parameters untouched. Instead, the STEP 4 comparison against traced_model_gpu_loaded prints the following output, which shows that the RNN parameters of the exported GPU model are corrupted:

        # PARAMETER :  renet1.rnn.weight_hh_l0
        * GPU :  tensor([-0.5195, -0.7641,  0.7705,  0.6834,  0.0681], device='cuda:0')
        * CPU :  tensor([ 0.0051,  0.0621,  0.0859, -0.0506, -0.1000])

        # PARAMETER :  renet1.rnn.bias_ih_l0
        * GPU :  tensor([0., 0., 0., 0., 0.], device='cuda:0')
        * CPU :  tensor([-0.0004,  0.0263,  0.0537, -0.0810,  0.0930])

        # PARAMETER :  renet1.rnn.bias_hh_l0
        * GPU :  tensor([0., 0., 0., 0., 0.], device='cuda:0')
        * CPU :  tensor([ 0.0439, -0.0063,  0.0250,  0.0784,  0.0408])

        # PARAMETER :  renet1.rnn.weight_ih_l0_reverse
        * GPU :  tensor([0., 0., 0., 0., 0.], device='cuda:0')
        * CPU :  tensor([-0.0755, -0.0730, -0.0435, -0.0522, -0.0979])

        # PARAMETER :  renet1.rnn.weight_hh_l0_reverse
        * GPU :  tensor([0., 0., 0., 0., 0.], device='cuda:0')
        * CPU :  tensor([ 0.0104,  0.1401, -0.0695,  0.0870, -0.0896])

        # PARAMETER :  renet1.rnn.bias_ih_l0_reverse
        * GPU :  tensor([0.0165, 0.0165, 0.0383, 0.0383, 0.0000], device='cuda:0')
        * CPU :  tensor([ 0.0596,  0.1129, -0.1282,  0.0738, -0.0051])

        # PARAMETER :  renet1.rnn.bias_hh_l0_reverse
        * GPU :  tensor([ 0.0419,  0.0845, -0.0466, -0.1143, -0.0606], device='cuda:0')
        * CPU :  tensor([-0.0844,  0.0941, -0.1149,  0.0188, -0.0276])
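To narrow down whether the corruption happens while saving or while loading, the GPU archive can also be loaded back onto the CPU and the same parameters inspected (a diagnostic sketch; the parameter key comes from the output above):

# Diagnostic sketch: if the values are zero here as well, the saved
# archive itself already contains the corrupted weights.
m = torch.jit.load("model_gpu.pth", map_location=torch.device("cpu"))
print(m.state_dict()["renet1.rnn.bias_ih_l0"].view(-1)[:5])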

Environment

PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.1 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 980M
Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so
/usr/local/cuda-8.0/lib64/libcudnn.so.6
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.1.4
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn_static.a

Versions of relevant libraries:
[pip] Could not collect
[conda] blas                      1.0                         mkl  
[conda] mkl                       2019.1                      144  
[conda] mkl_fft                   1.0.6            py37hd81dba3_0  
[conda] mkl_random                1.0.2            py37hd81dba3_0  
[conda] pytorch                   1.0.0           py3.7_cuda9.0.176_cudnn7.4.1_1    pytorch
[conda] torchvision               0.2.1                      py_2    pytorch

Labels

oncall: jit
