[TensorBoard] Graph with objects other than torch.nn.Module can not be visualized. #30459

Open
yangsenius opened this issue Nov 26, 2019 · 28 comments
Labels
module: tensorboard, triaged

Comments

@yangsenius
Contributor

🐛 Bug

I use tensorboardX to call add_graph(model, (dump_input,)), but it fails with an AssertionError:

    writer_dict['writer'].add_graph(model, (dump_input, ))
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/tensorboardX/writer.py", line 774, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, **kwargs))
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/tensorboardX/pytorch_graph.py", line 275, in graph
    trace = torch.jit.trace(model, args)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 772, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 898, in trace_module
    module = make_module(mod, _module_class, _compilation_unit)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 669, in make_module
    return _module_class(mod, _compilation_unit=_compilation_unit)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1881, in __init__
    self._modules[name] = TracedModule(submodule, id_set)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1881, in __init__
    self._modules[name] = TracedModule(submodule, id_set)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1881, in __init__
    self._modules[name] = TracedModule(submodule, id_set)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1881, in __init__
    self._modules[name] = TracedModule(submodule, id_set)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1881, in __init__
    self._modules[name] = TracedModule(submodule, id_set)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1386, in init_then_register
    original_init(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py", line 1855, in __init__
    assert(isinstance(orig, torch.nn.Module))
AssertionError

To Reproduce

Steps to reproduce the behavior:

import torch
from tensorboardX import SummaryWriter

writer = SummaryWriter()
model = mymodel()
dump_input = torch.rand( (1, 3,256, 256)  )
writer.add_graph(model, (dump_input, ))

Expected behavior

Environment

Collecting environment information...
PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 418.43
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] numpydoc==0.8.0
[pip3] torch==1.2.0
[pip3] torchvision==0.4.0
[conda] blas 1.0 mkl
[conda] mkl 2019.1 144
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.10 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.2.0 pypi_0 pypi
[conda] torchvision 0.3.0 pypi_0 pypi

Additional context

The same problem occurs when using torch.utils.tensorboard.

@pietern added the module: tensorboard and triaged labels Nov 27, 2019
@lanpa
Collaborator

lanpa commented Dec 9, 2019

@yangsenius How is mymodel defined? I use

class mymodel(torch.nn.Module):
    def __init__(self):
        super(mymodel, self).__init__()
        pass
    def forward(self, x):
        return x

as a mock and there is no error.

@yangsenius
Contributor Author

yangsenius commented Dec 11, 2019

@lanpa
Collaborator

lanpa commented Dec 15, 2019

@yangsenius I noticed that you are using torch==1.2.0. You also have a mixed installation of pip and conda packages. Could you try again with a cleaner and newer environment? Thanks.
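
One way to check for the mixed pip/conda installation mentioned above is to compare the [pip3] and [conda] lines of the environment report. The following is only a minimal sketch and not part of any PyTorch tooling: the helper name find_version_conflicts and the trimmed ENV_REPORT string are assumptions made for illustration (the lines themselves are copied from the report in the first post).

```python
import re

# A few lines copied from the environment report in the first post.
ENV_REPORT = """\
[pip3] numpy==1.16.2
[pip3] torch==1.2.0
[pip3] torchvision==0.4.0
[conda] torch 1.2.0 pypi_0 pypi
[conda] torchvision 0.3.0 pypi_0 pypi
"""


def find_version_conflicts(report):
    """Return {package: (pip_version, conda_version)} for mismatched packages.

    Hypothetical helper: it only understands the "[pip3] name==version" and
    "[conda] name version ..." line shapes shown in the report above.
    """
    pip_versions = {}
    conda_versions = {}
    for line in report.splitlines():
        pip_match = re.match(r"\[pip3?\]\s+(\S+)==(\S+)", line)
        conda_match = re.match(r"\[conda\]\s+(\S+)\s+(\S+)", line)
        if pip_match:
            pip_versions[pip_match.group(1)] = pip_match.group(2)
        elif conda_match:
            conda_versions[conda_match.group(1)] = conda_match.group(2)
    return {
        name: (pip_versions[name], conda_versions[name])
        for name in pip_versions
        if name in conda_versions and pip_versions[name] != conda_versions[name]
    }


conflicts = find_version_conflicts(ENV_REPORT)
print(conflicts)
```

On the report above this flags torchvision, which pip lists as 0.4.0 and conda lists as 0.3.0. That may be what prompted the suggestion to rebuild the environment, though as the rest of the thread shows, it was not the cause of the AssertionError.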

@yangsenius
Contributor Author

Thanks

@tshrjn

tshrjn commented Jan 3, 2020

@yangsenius How did you resolve this issue? I'm facing the same issue even with a newly created conda env.

@yangsenius
Contributor Author

No, I couldn't resolve this bug, even in a cleaner and newer environment. @tshrjn @lanpa

Environment

Collecting environment information...
PyTorch version: 1.4.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.12) 5.4.0 20160609
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.105
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 418.43
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.5.0
/usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7

Versions of relevant libraries:
[pip3] numpy==1.16.4
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] pytorch 1.4.0 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch-nightly
[conda] torchvision 0.5.0.dev20200103 py37_cu101 pytorch-nightly

I guess that the current version of PyTorch or the TensorBoard library does not support some specific computation graphs. I cannot figure out what causes this issue.

When I called add_graph with a simple model, no error occurred.

import torch 
from torchvision.models.resnet import resnet50
from torch.utils.tensorboard import SummaryWriter
#from tensorboardX import SummaryWriter
net = resnet50()
inputs = torch.randn(1,3,256,256)
o = net(inputs)
graph = SummaryWriter()
graph.add_graph(net, (inputs,) )

I don't think this problem is a special case; many people using newer versions of PyTorch (>1.0) have been confused by it.

@yangsenius yangsenius reopened this Jan 4, 2020
@lanpa changed the title from "add_graph error: assert(isinstance(orig, torch.nn.Module)) when using tensorboardX or torch.utils.tensorboard" to "[TensorBoard] Graph with objects other than torch.nn.Module can not be visualized." Jan 12, 2020
@lanpa
Collaborator

lanpa commented Jan 15, 2020

@Muhtasham

Muhtasham commented Mar 4, 2020

@yangsenius Were you able to solve the problem? If so, could you please share the solution?

@yangsenius
Contributor Author

I think it is just the None appearing in the forward function that breaks the JIT, as lanpa pointed out. A common neural network without None has no problems. @Muhtasham

@Kevin0624

Kevin0624 commented Mar 9, 2020

@yangsenius
Contributor Author

yangsenius commented Mar 11, 2020

This is because nn.ModuleList() requires its elements to be nn.Module subclasses or None.

So I tried defining a placeholder none_module:

class none_module(nn.Module):
    def __init__(self,):
        super(none_module, self).__init__()
        self.none_module_property = True

Then use _none = none_module() and transition_layers.append(_none).
In forward, use if not hasattr(self.transition_layers[i], "none_module_property"): instead of if self.transition[i] is not None:

@yangsenius
Contributor Author

@Kevin0624 Does it work?

@andres-fr

This comment has been minimized.

@yangsenius
Contributor Author

yangsenius commented Apr 28, 2020

I don't think the None appearing in BasicBlock and Bottleneck still triggers the same assert(isinstance(orig, torch.nn.Module)). They are not appended to the nn.ModuleList and just function as flags. Your modification is more elegant, thanks. I also modified my comment above: I forgot to add the not; it should be if not .... @andres-fr

EDIT 2: I don't think the comment below worked; I probably mixed things up when quickly testing. The reason is probably that BasicBlock, Bottleneck and _make_fuse_layers also make use of None as an nn.Module attribute, so patching make_transition_layers alone still triggers the same assert(isinstance(orig, torch.nn.Module)) assertion. If this is true, patching these as well could lead to the solution (but I can't try that out myself yet). Let me know if someone tries this out!

@yangsenius I confirm that it works. I made a slight modification, but it shouldn't matter much:

class NoOpModule(nn.Module):
    """
    https://github.com/pytorch/pytorch/issues/30459#issuecomment-597679482
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        return args

Then in make_transition_layer: transition_layers.append(NoOpModule())

And then in forward, for each respective stage (e.g. stage 3 here):

            if not isinstance(self.transition3[i], NoOpModule):
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

Nevertheless, I still have the problem that the plotted graph is hidden behind some 803 wrapper that I can't dig into. Does anyone know the reason or a solution? EDIT: I suspect it has to do with the mixed-precision (fp16) wrapper, since it also happens with a different network that only shares the stem with the HHRNet.

[image: valid_test_log_2020-04-27-21-54]

@yangsenius
Contributor Author

yangsenius commented Apr 28, 2020


class NoOpModule(nn.Module):
    """
    https://github.com/pytorch/pytorch/issues/30459#issuecomment-597679482
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        return args

Then in make_transition_layer: transition_layers.append(NoOpModule());
And in _make_fuse_layers: fuse_layer.append(NoOpModule())

And then in forward, for each respective stage (e.g. stage 3 here):

            if not isinstance(self.transition3[i], NoOpModule):
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

@H19012

H19012 commented Jun 16, 2020


class NoOpModule(nn.Module):
    """
    https://github.com/pytorch/pytorch/issues/30459#issuecomment-597679482
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, *args, **kwargs):
        return args

Then in make_transition_layer: transition_layers.append(NoOpModule())

And then in forward, for each respective stage (e.g. stage 3 here):

            if not isinstance(self.transition3[i], NoOpModule):
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])

I'm still getting assert(isinstance(orig, torch.nn.Module))
AssertionError

BTW I put transition_layers.append(NoOpModule()) in place of transition_layers.append(None)

@yangsenius
Contributor Author

You may need to modify every place where None is appended to transition_layers.

@H19012

H19012 commented Jun 18, 2020

There seems to be only one line where None is appended to transition_layers: https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation/blob/master/lib/models/pose_higher_hrnet.py#L410

@yangsenius
Contributor Author

There is also the fuse_layer.append(None): https://github.com/HRNet/HigherHRNet-Human-Pose-Estimation/blob/b4610aecaa5cf3de3cd69bfb13c7c79c8d514c7c/lib/models/pose_higher_hrnet.py#L198. And every if self.transition_xx is not None check needs to be updated as well.
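
To find every remaining place where None was registered as a submodule, the module tree can be walked and its registered children inspected. This is a rough sketch under a few assumptions: find_none_submodules and HypotheticalNet are names made up for illustration, and the helper reads the private _modules dictionary of nn.Module, which is an implementation detail that could change between PyTorch versions.

```python
import torch
from torch import nn


def find_none_submodules(model):
    """Return dotted names of registered child slots that hold None.

    Hypothetical helper. named_modules() skips None entries, so each
    module's (private) _modules dict is inspected directly.
    """
    missing = []
    for prefix, module in model.named_modules():
        for name, child in module._modules.items():
            if child is None:
                missing.append(f"{prefix}.{name}" if prefix else name)
    return missing


class HypotheticalNet(nn.Module):
    """Toy stand-in for a model that appends None to an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        self.transition = nn.ModuleList([nn.Conv2d(3, 3, kernel_size=1)])
        self.transition.append(None)  # the pattern discussed in this thread


model = HypotheticalNet()
missing = find_none_submodules(model)
print(missing)
```

Each reported name points at a slot that could be replaced with a placeholder module before calling add_graph.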

@H19012

H19012 commented Jun 23, 2020

@yangsenius Thank you very much. It is working now.

@newwhitecheng

@yangsenius May I ask which versions of tensorboardX and PyTorch you use?

@yangsenius
Contributor Author

PyTorch == 1.2.0 and TensorboardX==2.0 (likely)

@newwhitecheng

Hi @yangsenius, thanks for your quick reply! I tried your method with PyTorch==1.2 and tensorboardX==2.0, but I got the following error. Any idea?

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/hc218/workspace/github/HRNet/HigherHRNet-Human-Pose-Estimation/tools/dist_train.py", line 196, in main_worker
    writer_dict['writer'].add_graph(model, (dump_input, ))
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/tensorboardX/writer.py", line 804, in add_graph
    self._get_file_writer().add_graph(graph(model, input_to_model, verbose, profile_with_cuda, **kwargs))
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/tensorboardX/pytorch_graph.py", line 335, in graph
    raise e
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/tensorboardX/pytorch_graph.py", line 326, in graph
    trace = torch.jit.trace(model, args)
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 772, in trace
    check_tolerance, _force_outplace, _module_class)
  File "/home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/jit/__init__.py", line 904, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
RuntimeError: Only tensors or tuples of tensors can be output from traced functions (getOutput at /tmp/pip-req-build-58y_cjjl/torch/csrc/jit/tracer.cpp:208)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x6d (0x7f72f950c1cd in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libc10.so)
frame #1: torch::jit::tracer::TracingState::getOutput(c10::IValue const&) + 0x444 (0x7f72fce6b7f4 in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #2: torch::jit::tracer::exit(std::vector<c10::IValue, std::allocator<c10::IValue> > const&) + 0x4d (0x7f72fce6bb8d in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch.so)
frame #3: <unknown function> + 0x4cda00 (0x7f731bea7a00 in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #4: <unknown function> + 0x507b22 (0x7f731bee1b22 in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0x1c7126 (0x7f731bba1126 in /home/hc218/anaconda3/envs/torch/lib/python3.7/site-packages/torch/lib/libtorch_python.so)
frame #6: _PyMethodDef_RawFastCallKeywords + 0x264 (0x55a837429c94 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #7: _PyCFunction_FastCallKeywords + 0x21 (0x55a837429db1 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #8: _PyEval_EvalFrameDefault + 0x51d1 (0x55a8374959a1 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #9: _PyEval_EvalCodeWithName + 0x2f9 (0x55a8373d92b9 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #10: _PyFunction_FastCallKeywords + 0x325 (0x55a837429435 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #11: _PyEval_EvalFrameDefault + 0x416 (0x55a837490be6 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #12: _PyEval_EvalCodeWithName + 0x2f9 (0x55a8373d92b9 in /home/hc218/anaconda3/envs/torch/bin/python)
frame #13: _PyFunction_FastCallKeywords + 0x325 (0x55a837429435 in /home/hc218/anaconda3/envs/torch/bin/python)
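
The RuntimeError above says that the tracer in that PyTorch version accepts only tensors or tuples of tensors as outputs. The thread does not show what the model returned, so the following is a sketch under the assumption that forward returns a Python list of tensors. TupleOutput and ListHead are hypothetical names used only for illustration; wrapping the model like this is one possible workaround, not the fix that was confirmed in this thread.

```python
import torch
from torch import nn


class ListHead(nn.Module):
    """Toy model (assumption) whose forward returns a Python list."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)

    def forward(self, x):
        return [self.conv(x), x]


class TupleOutput(nn.Module):
    """Hypothetical wrapper that converts list outputs to tuples for tracing."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        out = self.inner(x)
        # The tracer error quoted above allows tensors or tuples of tensors,
        # so a list result is converted before it leaves the traced module.
        if isinstance(out, (list, tuple)):
            return tuple(out)
        return out


example = torch.rand(1, 3, 8, 8)
wrapped = TupleOutput(ListHead()).eval()
traced = torch.jit.trace(wrapped, (example,))
outputs = traced(example)
print(type(outputs), len(outputs))
```

The wrapped model, rather than the original one, would then be passed to add_graph.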

@yangsenius
Contributor Author

@newwhitecheng Can you try tensorboardX==1.4? I may have updated the version.

@newwhitecheng

Sure, thanks a lot. I think it's just a version mismatch problem. Thanks for your valuable solution!

@cpbotha

cpbotha commented Aug 15, 2020

These days you can also use torch.nn.Identity instead of adding your own NoOpModule; as a placeholder it serves the same purpose.

See cpbotha/deep-high-resolution-net.pytorch@b1a7fdb for how to apply this fix on the pose HRNet code.
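
As a rough illustration of the nn.Identity approach, here is a minimal sketch on a toy model. TinyBranches is a hypothetical class, not the HRNet code, and the branch logic is simplified; it only shows an nn.ModuleList that holds nn.Identity where None would otherwise be appended, with forward testing for the placeholder by type.

```python
import torch
from torch import nn


class TinyBranches(nn.Module):
    """Toy model (assumption): nn.Identity marks branches without a transition."""

    def __init__(self):
        super().__init__()
        self.transition = nn.ModuleList([
            nn.Conv2d(3, 3, kernel_size=3, padding=1),
            nn.Identity(),  # used where the original code appended None
        ])

    def forward(self, x):
        outputs = []
        for layer in self.transition:
            # Replaces the original "is not None" check.
            if not isinstance(layer, nn.Identity):
                outputs.append(layer(x))
            else:
                outputs.append(x)
        return tuple(outputs)


model = TinyBranches().eval()
example = torch.rand(1, 3, 8, 8)
# add_graph relies on torch.jit.trace (see the traceback in the first post),
# so a successful trace is a quick check before writing the graph.
traced = torch.jit.trace(model, (example,))
outputs = traced(example)
print(len(outputs), outputs[0].shape)
```

Because every entry in the ModuleList is now an nn.Module, tracing no longer encounters a None submodule.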

@yangsenius
Contributor Author

@cpbotha Thanks very much for your suggestion. torch.nn.Identity is the placeholder operation that we need.

@annie-surla

@newwhitecheng Did it work for you with tensorboardX==1.4?
