[JIT] isinstance(m, nn.Linear) returns False in ScriptModules #19348

@qbx2

Description

🐛 Bug

We usually use a for m in self.modules() loop to initialize weights and biases. However, isinstance() does not work properly when a submodule is JIT-supported (that is, any module except the Unsupported torch.nn Modules listed at https://pytorch.org/docs/stable/jit.html#builtin-functions).
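
For context, the initialization loop in question typically looks something like the sketch below (a generic example, not taken from any specific codebase; the particular init functions are arbitrary):

import torch.nn as nn


def init_weights(model):
    # Walk every submodule and pick an initializer based on its type.
    # This is the pattern that breaks when submodules are wrapped by the JIT.
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)
        elif isinstance(m, (nn.Conv2d, nn.Conv3d)):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.)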

To Reproduce

Steps to reproduce the behavior:

import torch.nn as nn
import torch.jit as jit


class SomeScriptModule(jit.ScriptModule):
    def __init__(self):
        super().__init__()

        self.linear = nn.Linear(16, 16)
        self.conv2d = nn.Conv2d(3, 8, 3)
        self.conv3d = nn.Conv3d(3, 8, 3)
        self.gru = nn.GRU(16, 16)
        self.lstm = nn.LSTM(16, 16)

        for m in self.modules():
            print(m, type(m))

            # In this loop, isinstance() returns False for the wrapped
            # submodules (linear, conv2d, conv3d); only GRU and LSTM match.
            if isinstance(m, nn.Linear):
                print('m is Linear')
                continue

            if isinstance(m, nn.Conv2d):
                print('m is Conv2d')
                continue

            if isinstance(m, nn.Conv3d):
                print('m is Conv3d')
                continue

            if isinstance(m, nn.GRU):
                print('m is GRU')
                continue

            if isinstance(m, nn.LSTM):
                print('m is LSTM')
                continue

            print('??????')


SomeScriptModule()

$ python test.py
SomeScriptModule(
  (linear): WeakScriptModuleProxy()
  (conv2d): WeakScriptModuleProxy()
  (conv3d): WeakScriptModuleProxy()
  (gru): GRU(16, 16)
  (lstm): LSTM(16, 16)
) <class '__main__.SomeScriptModule'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
WeakScriptModuleProxy() <class 'torch.jit.WeakScriptModuleProxy'>
??????
GRU(16, 16) <class 'torch.nn.modules.rnn.GRU'>
m is GRU
LSTM(16, 16) <class 'torch.nn.modules.rnn.LSTM'>
m is LSTM

Expected behavior

isinstance() should return True in the code above. That does not seem to be possible, however, because the submodules are wrapped in WeakScriptModuleProxy. In that case, I think torch should provide torch.isinstance() or some other way to check the underlying type of a WeakScriptModuleProxy.
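
In the meantime, one possible workaround (just a sketch; the class name below is made up) is to build and initialize the submodules while they are still plain nn.Module instances and only assign them to the ScriptModule afterwards, since they appear to be wrapped at assignment time:

import torch.nn as nn
import torch.jit as jit


class InitBeforeAssignModule(jit.ScriptModule):
    def __init__(self):
        super().__init__()

        # Create and initialize the submodules before assigning them to
        # the ScriptModule, so the isinstance() checks still see plain
        # nn.Module instances.
        linear = nn.Linear(16, 16)
        conv2d = nn.Conv2d(3, 8, 3)

        for m in (linear, conv2d):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0.)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.)

        # Assignment is the point where the modules get wrapped.
        self.linear = linear
        self.conv2d = conv2d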

Environment

$ python collect_env.py
Collecting environment information...
PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration: GPU 0: GeForce RTX 2080
Nvidia driver version: 410.48
cuDNN version: /usr/local/cuda-10.0/lib64/libcudnn.so.7.4.1

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.0.1.post2
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] pytorch 1.0.1 py3.7_cuda10.0.130_cudnn7.4.2_2 pytorch

Additional context

Labels

low priority, oncall: jit