-
Notifications
You must be signed in to change notification settings - Fork 24.7k
Description
🐛 Bug
Step 1
I build a CNN + RNN model, put one copy of it on the CPU and one on the GPU, and then trace both with `torch.jit.trace` as below:
# Build one model, then an independent CUDA copy with identical weights.
cuda = torch.device("cuda")
model_cpu = Architecture()
model_gpu = copy.deepcopy(model_cpu).to(cuda)
# Trace each copy with the same random input, placed on the matching device.
dummy_input = torch.randn(1, 3, 224, 224)
traced_model_cpu = torch.jit.trace(model_cpu, dummy_input)
traced_model_gpu = torch.jit.trace(model_gpu, dummy_input.to(cuda))
At this step, the parameters of all models (`model_cpu`, `model_gpu`, `traced_model_cpu`, `traced_model_gpu`) are exactly the same.
Step 2
Then I save and load them using `torch.jit` as below:
# Serialize both traced modules, then read each one back from disk.
for traced, path in ((traced_model_cpu, "model_cpu.pth"),
                     (traced_model_gpu, "model_gpu.pth")):
    torch.jit.save(traced, path)
traced_model_cpu_loaded = torch.jit.load("model_cpu.pth")
traced_model_gpu_loaded = torch.jit.load("model_gpu.pth")
The parameters of the RNN layers in `traced_model_gpu_loaded` are totally different from those in `model_cpu`, `model_gpu`, `traced_model_cpu`, `traced_model_gpu` and `traced_model_cpu_loaded`.
To Reproduce
In order to reproduce the behavior, run the script given below:
import copy
import torch
import torch.nn as nn
from torchvision import models
class ReNet(nn.Module):
    """Sweep a bidirectional GRU along the width axis of a feature map.

    Every row of the feature map is treated as a sequence of length
    ``width`` whose elements are the per-pixel channel vectors. A single
    bidirectional GRU layer runs over each row, and the outputs of both
    directions are concatenated along the channel axis.

    Args:
        n_input: number of input channels (the GRU input size).
        n_units: GRU hidden size per direction; the output therefore has
            ``2 * n_units`` channels.
    """

    def __init__(self, n_input, n_units):
        super().__init__()
        self.rnn = nn.GRU(n_input, n_units,
                          num_layers=1, batch_first=False,
                          bidirectional=True)

    def rnn_forward(self, x):
        """Run the GRU along the width axis of a channels-last tensor.

        Args:
            x: tensor of shape ``(b, height, width, n_filters)``.

        Returns:
            Tensor of shape ``(b, height, width, 2 * n_units)``.
        """
        b, n_height, n_width, n_filters = x.size()
        # ``forward`` hands us a permuted (non-contiguous) view. ``view``
        # cannot merge the batch and height axes of such a tensor when
        # b > 1, so use ``reshape``, which copies only when it must.
        x = x.reshape(b * n_height, n_width, n_filters)
        # GRU is sequence-first here (batch_first=False): (width, rows, nf).
        x = x.permute(1, 0, 2)
        x, _ = self.rnn(x)
        x = x.permute(1, 0, 2)
        x = x.reshape(b, n_height, n_width, -1)
        return x

    def forward(self, x):
        """Apply the horizontal GRU sweep to a channels-first feature map.

        Args:
            x: tensor of shape ``(b, n_filters, height, width)``.

        Returns:
            Tensor of shape ``(b, 2 * n_units, height, width)``.
        """
        x = x.permute(0, 2, 3, 1)  # b, h, w, nf
        x = self.rnn_forward(x)    # b, h, w, 2 * n_units
        x = x.permute(0, 3, 1, 2)  # b, 2 * n_units, h, w
        return x
class Architecture(nn.Module):
    """Truncated pretrained ResNet-50 followed by one horizontal ReNet sweep.

    Only the leading children of torchvision's ``resnet50`` are kept (the
    last five are dropped). The ``256`` passed to ``ReNet`` must equal the
    channel count that truncated stack produces; the ReNet output then has
    ``2 * 50`` channels.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        stem_layers = list(backbone.children())[:-5]
        self.cnn = nn.Sequential(*stem_layers)
        self.renet1 = ReNet(256, 50)

    def forward(self, x):
        features = self.cnn(x)
        return self.renet1(features)
def compare_models(cpu_model, gpu_model):
    """Report whether two modules hold identical state-dict entries.

    Entries are matched by name between ``cpu_model.state_dict()`` and
    ``gpu_model.state_dict()``. Each differing entry is printed with the
    first five values from each side. Each missing or extra name is printed
    on its own line.

    Despite the argument names, the models may live on any devices. Every
    ``cpu_model`` tensor is moved to the device of its ``gpu_model``
    counterpart before comparing, so CPU/CPU and GPU/GPU pairs work as well
    as CPU/GPU pairs.

    Returns:
        True if both state dicts have the same names and every entry has
        the same shape and values, False otherwise.
    """
    is_identical = True
    cpu_model_state_dict = cpu_model.state_dict()
    gpu_model_state_dict = gpu_model.state_dict()
    for param_key, cpu_params in cpu_model_state_dict.items():
        gpu_params = gpu_model_state_dict.get(param_key)
        if gpu_params is None:
            print("\n\t# MISSING PARAMETER : ", param_key)
            is_identical = False
            continue
        # torch.equal returns False (rather than broadcasting or raising)
        # when the shapes differ.
        if not torch.equal(gpu_params, cpu_params.to(gpu_params.device)):
            print("\n\t# PARAMETER : ", param_key)
            print("\t* GPU : ", gpu_params.reshape(-1)[:5])
            print("\t* CPU : ", cpu_params.reshape(-1)[:5])
            is_identical = False
    # Names present only in the second model also make them differ.
    extra_keys = [k for k in gpu_model_state_dict if k not in cpu_model_state_dict]
    for param_key in extra_keys:
        print("\n\t# EXTRA PARAMETER : ", param_key)
        is_identical = False
    return is_identical
def trace(model, usegpu, input_shape=(1, 3, 224, 224)):
    """Switch ``model`` to eval mode and trace it with a random dummy input.

    Args:
        model: module to trace; it is put into eval mode in place.
        usegpu: if true, the dummy input is moved to the CUDA device, so
            ``model`` is expected to already live there.
        input_shape: shape of the random dummy input. Defaults to
            ``(1, 3, 224, 224)``, the value previously hard-coded here.

    Returns:
        The module returned by ``torch.jit.trace``.
    """
    with torch.no_grad():
        model.eval()
        # Draw on the CPU generator first (as before), then move if needed.
        dummy_input = torch.randn(*input_shape)
        if usegpu:
            dummy_input = dummy_input.to(torch.device("cuda"))
        traced_model = torch.jit.trace(model, dummy_input)
    return traced_model
if __name__ == "__main__":
    torch.manual_seed(13)
    model_cpu = Architecture()
    model_gpu = copy.deepcopy(model_cpu).to(torch.device("cuda"))
    print("STEP 1 : ", compare_models(model_cpu, model_gpu))

    traced_model_cpu = trace(model_cpu, False)
    traced_model_gpu = trace(model_gpu, True)
    print("STEP 2 : ", compare_models(traced_model_cpu, traced_model_gpu))
    print("STEP 2 : ", compare_models(traced_model_gpu, model_gpu))

    torch.jit.save(traced_model_cpu, "model_cpu.pth")
    torch.jit.save(traced_model_gpu, "model_gpu.pth")
    print("STEP 3 : ", compare_models(traced_model_cpu, traced_model_gpu))
    print("STEP 3 : ", compare_models(traced_model_gpu, model_gpu))

    traced_model_cpu_loaded = torch.jit.load("model_cpu.pth")
    traced_model_gpu_loaded = torch.jit.load("model_gpu.pth")
    # compare_models treats its first argument as the CPU-side model and its
    # second as the GPU-side model, moving the former to match the latter,
    # so the CUDA-resident model must be passed second. model_gpu was shown
    # identical to model_cpu in STEP 1, so it is the CUDA-side reference for
    # the CPU-loaded model.
    print("\nSTEP 4 : ", compare_models(traced_model_cpu_loaded, model_gpu))
    print("\nSTEP 4 : ", compare_models(model_cpu, traced_model_gpu_loaded))
This script prints the names (and the first five values) of the parameters that differ between the CPU and GPU models.
Expected behavior
The loaded GPU model should have the same parameters as the original model. Instead, the script prints the following output, which shows that the parameters of the RNN layers differ between CPU and GPU:
# PARAMETER : renet1.rnn.weight_hh_l0
* GPU : tensor([-0.5195, -0.7641, 0.7705, 0.6834, 0.0681], device='cuda:0')
* CPU : tensor([ 0.0051, 0.0621, 0.0859, -0.0506, -0.1000])
# PARAMETER : renet1.rnn.bias_ih_l0
* GPU : tensor([0., 0., 0., 0., 0.], device='cuda:0')
* CPU : tensor([-0.0004, 0.0263, 0.0537, -0.0810, 0.0930])
# PARAMETER : renet1.rnn.bias_hh_l0
* GPU : tensor([0., 0., 0., 0., 0.], device='cuda:0')
* CPU : tensor([ 0.0439, -0.0063, 0.0250, 0.0784, 0.0408])
# PARAMETER : renet1.rnn.weight_ih_l0_reverse
* GPU : tensor([0., 0., 0., 0., 0.], device='cuda:0')
* CPU : tensor([-0.0755, -0.0730, -0.0435, -0.0522, -0.0979])
# PARAMETER : renet1.rnn.weight_hh_l0_reverse
* GPU : tensor([0., 0., 0., 0., 0.], device='cuda:0')
* CPU : tensor([ 0.0104, 0.1401, -0.0695, 0.0870, -0.0896])
# PARAMETER : renet1.rnn.bias_ih_l0_reverse
* GPU : tensor([0.0165, 0.0165, 0.0383, 0.0383, 0.0000], device='cuda:0')
* CPU : tensor([ 0.0596, 0.1129, -0.1282, 0.0738, -0.0051])
# PARAMETER : renet1.rnn.bias_hh_l0_reverse
* GPU : tensor([ 0.0419, 0.0845, -0.0466, -0.1143, -0.0606], device='cuda:0')
* CPU : tensor([-0.0844, 0.0941, -0.1149, 0.0188, -0.0276])
Environment
PyTorch version: 1.0.0
Is debug build: No
CUDA used to build PyTorch: 9.0.176
OS: Ubuntu 16.04.1 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce GTX 980M
Nvidia driver version: 396.44
cuDNN version: Probably one of the following:
/usr/local/cuda-8.0/lib64/libcudnn.so
/usr/local/cuda-8.0/lib64/libcudnn.so.6
/usr/local/cuda-8.0/lib64/libcudnn.so.6.0.21
/usr/local/cuda-8.0/lib64/libcudnn_static.a
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn.so.7.1.4
/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudnn_static.a
Versions of relevant libraries:
[pip] Could not collect
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.0 py3.7_cuda9.0.176_cudnn7.4.1_1 pytorch
[conda] torchvision 0.2.1 py_2 pytorch