Closed
Labels
oncall: jit
Description
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- Go to the docs for torch.jit.trace_module and check the code example
- Copy/paste it into a fresh venv with PyTorch 1.2 installed
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight

example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()
# the following two calls are equivalent
module = torch.jit.trace_module(n, example_forward_input)  # raises AttributeError
module = torch.jit.trace_module(n.forward, example_forward_input)
Error message
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-bd56084c5306> in <module>
14 n = Net()
15 # the following two calls are equivalent
---> 16 module = torch.jit.trace_module(n, example_forward_input)
17 module = torch.jit.trace_module(n.forward, example_forward_input)
~/miniconda3/envs/ai/lib/python3.7/site-packages/torch/jit/__init__.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, _force_outplace, _module_class, _compilation_unit)
894
895 if not isinstance(inputs, dict):
--> 896 raise AttributeError("expected a dictionary of (method_name, input) pairs")
897
898 module = make_module(mod, _module_class, _compilation_unit)
AttributeError: expected a dictionary of (method_name, input) pairs
Expected behavior
The doc example should work :)
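For comparison, here is a minimal sketch of calls that do run, assuming PyTorch 1.2 or later. Going by the error message, `torch.jit.trace_module` wants a dict mapping method names to example inputs rather than a bare tensor; the two single-tensor "equivalent" calls in the doc appear to match `torch.jit.trace` instead. The dict keys below are just the method names defined on `Net`; this is an illustration, not the official corrected docstring.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

    def weighted_kernel_sum(self, weight):
        return weight * self.conv.weight


example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
n = Net()

# trace_module takes {method_name: example_input}; each listed method
# is traced and attached to the resulting ScriptModule.
inputs = {
    "forward": example_forward_input,
    "weighted_kernel_sum": example_weight,
}
module = torch.jit.trace_module(n, inputs)

# To trace only `forward`, torch.jit.trace accepts the module directly.
forward_only = torch.jit.trace(n, example_forward_input)

# A 3x3 input through a 3x3 conv with no padding gives a 1x1 output.
print(module(example_forward_input).shape)
print(module.weighted_kernel_sum(example_weight).shape)
```

Passing the dict is what lets `trace_module` trace methods other than `forward` in a single call.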
Environment
PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130
OS: Ubuntu 18.04.3 LTS
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
CMake version: version 3.10.2
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
Nvidia driver version: 430.14
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.0
Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.2.0
[pip] torchsummary==1.5.1
[pip] torchvision==0.4.0a0+6b959ee
[conda] Could not collect
driazati