Description
🐛 Bug
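We usually use a "for m in self.modules()" loop to initialize weights and biases. However, isinstance() does not work as expected when the submodule is JIT-supported. That covers every torch.nn module except those listed under "Unsupported torch.nn Modules" at https://pytorch.org/docs/stable/jit.html#builtin-functions. Such submodules show up as WeakScriptModuleProxy objects instead of their original classes, so a check like isinstance(m, nn.Linear) returns False.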
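The effect can be modeled without PyTorch. The toy sketch below uses hypothetical stand-in classes (Linear, Conv2d, Proxy; none of them are the real torch classes). For illustration only, it assumes the JIT replaces a submodule with a wrapper that forwards attribute reads but is a different class. A weight-initialization loop that dispatches on isinstance() then silently skips every wrapped submodule.

```python
class Linear:
    """Hypothetical stand-in for nn.Linear (not the real PyTorch class)."""
    def __init__(self):
        self.initialized = False


class Conv2d:
    """Hypothetical stand-in for nn.Conv2d (not the real PyTorch class)."""
    def __init__(self):
        self.initialized = False


class Proxy:
    """Hypothetical stand-in for a wrapper such as WeakScriptModuleProxy.
    Attribute reads are forwarded to the wrapped object, but the wrapper
    is a separate class, so isinstance() does not see the original type."""
    def __init__(self, original):
        self._original = original

    def __getattr__(self, name):
        # Only called when normal lookup fails (e.g. for "initialized").
        return getattr(self._original, name)


def init_weights(modules):
    """Mimics the usual `for m in self.modules(): if isinstance(...)` idiom.
    Returns the type names of the modules it matched and initialized."""
    matched = []
    for m in modules:
        if isinstance(m, (Linear, Conv2d)):
            m.initialized = True  # placeholder for a real init function
            matched.append(type(m).__name__)
    return matched


plain = [Linear(), Conv2d()]
wrapped = [Proxy(Linear()), Proxy(Conv2d())]

print(init_weights(plain))    # ['Linear', 'Conv2d']
print(init_weights(wrapped))  # []: every wrapped module is skipped
print(type(wrapped[0]).__name__, isinstance(wrapped[0], Linear))  # Proxy False
```

The loop itself is correct; it fails only because the objects it iterates over are no longer instances of the classes it tests for.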
To Reproduce
Steps to reproduce the behavior:
```python
# test.py
import torch.nn as nn
import torch.jit as jit


class SomeScriptModule(jit.ScriptModule):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)
        self.conv2d = nn.Conv2d(3, 8, 3)
        self.conv3d = nn.Conv3d(3, 8, 3)
        self.gru = nn.GRU(16, 16)
        self.lstm = nn.LSTM(16, 16)

        # The usual initialization loop: dispatch on the submodule's type.
        for m in self.modules():
            print(m, type(m))
            if isinstance(m, nn.Linear):
                print('m is Linear')
                continue
            if isinstance(m, nn.Conv2d):
                print('m is Conv2d')
                continue
            if isinstance(m, nn.Conv3d):
                print('m is Conv3d')
                continue
            if isinstance(m, nn.GRU):
                print('m is GRU')
                continue
            if isinstance(m, nn.LSTM):
                print('m is LSTM')
                continue
            print('??????')


SomeScriptModule()
```

```
$ python test.py
SomeScriptModule(
  (linear): WeakScriptModuleProxy()
  (conv2d): WeakScriptModuleProxy()
  (conv3d): WeakScriptModuleProxy()
  (gru): GRU(16, 16)
  (lstm): LSTM(16, 16)
) <class '__main__.SomeScriptModule'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
GRU(16, 16) <class 'torch.nn.modules.rnn.GRU'>
m is GRU
LSTM(16, 16) <class 'torch.nn.modules.rnn.LSTM'>
m is LSTM
```
Expected behavior
isinstance() should return True for the Linear, Conv2d, and Conv3d submodules in the code above. If that cannot be supported, I think torch should provide a torch.isinstance() or some other way to check the original type of a WeakScriptModuleProxy.
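One shape such a helper could take is sketched below in plain Python. The name proxy_isinstance is hypothetical (it is not a PyTorch API), and the classes are toy stand-ins. The sketch also assumes the wrapper keeps a reference to the original module in an attribute, here called _original; whether WeakScriptModuleProxy exposes such a reference is an assumption, not something this report confirms.

```python
class Linear:
    """Hypothetical stand-in for nn.Linear (not the real PyTorch class)."""


class Proxy:
    """Hypothetical stand-in for WeakScriptModuleProxy.
    ASSUMPTION: the wrapped module is kept in the `_original` attribute."""
    def __init__(self, original):
        self._original = original


def proxy_isinstance(obj, cls):
    """Hypothetical helper in the spirit of the proposed torch.isinstance():
    check the object itself first, then fall back to the module it wraps."""
    if isinstance(obj, cls):
        return True
    original = getattr(obj, "_original", None)
    return original is not None and isinstance(original, cls)


m = Proxy(Linear())
print(isinstance(m, Linear))               # False
print(proxy_isinstance(m, Linear))         # True
print(proxy_isinstance(Linear(), Linear))  # True: unwrapped modules still match
```

Checking the object itself first keeps the helper a drop-in replacement for isinstance() on ordinary modules. An alternative design is for the wrapper to report the original class through its __class__ attribute, which Python's isinstance() also consults; existing loops would then work unchanged, at the cost of hiding the wrapper's real type.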
Environment
$ python collect_env.py
Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2080
Nvidia driver version: 410.48
cuDNN version: /usr/local/cuda-10.0/lib64/libcudnn.so.7.4.1
Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.0.1.post2
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 py3.7_cuda10.0.130_cudnn7.4.2_2 pytorch