
[JIT] Improve error messaging for using a tensor attribute in ScriptModule #16284

Closed
sidazhang opened this issue Jan 23, 2019 · 1 comment
Labels
oncall: jit Add this issue/PR to JIT oncall triage queue


sidazhang commented Jan 23, 2019

🐛 Bug

A ScriptModule cannot have a tensor as an attribute, whether or not the attribute is declared in __constants__.

To Reproduce

Not Constant

import torch

class ConstantTensor(torch.jit.ScriptModule):
    def __init__(self):
        super(ConstantTensor, self).__init__()
        # Plain tensor attribute, not listed in __constants__
        self.cx = torch.ones(100, 100, dtype=torch.float, device='cuda')

    @torch.jit.script_method
    def forward(self, x):
        return x + self.cx


c = ConstantTensor()
print(c.graph)

attribute 'cx' of type 'Tensor' is not usable in a script method (did you forget to add it __constants__?):

Constant

import torch

class ConstantTensor(torch.jit.ScriptModule):
    # Declaring the tensor attribute as a constant
    __constants__ = ['cx']

    def __init__(self):
        super(ConstantTensor, self).__init__()
        self.cx = torch.ones(100, 100, dtype=torch.float, device='cuda')

    @torch.jit.script_method
    def forward(self, x):
        return x + self.cx


c = ConstantTensor()
print(c.graph)
TypeError: 'Tensor' object for attribute 'cx' is not a valid constant.
Valid constants are:
  1. a nn.ModuleList
  2. a value of type {bool, float, int, str, NoneType, function, device, layout, dtype}
  3. a list or tuple of (2)
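To see why the tensor is rejected, the three rules in the error message can be written out as a small check. This is a hypothetical helper (is_valid_constant is not a PyTorch API), and mapping "function" to types.FunctionType is an assumption; it only approximates what TorchScript actually validates.

```python
import types

import torch
import torch.nn as nn

# Hypothetical: the primitive types listed in rule (2) of the TypeError.
# "function" is approximated as types.FunctionType (an assumption).
_PRIMITIVE_TYPES = (bool, float, int, str, type(None), types.FunctionType,
                    torch.device, torch.layout, torch.dtype)


def is_valid_constant(value):
    # Rule 1: an nn.ModuleList
    if isinstance(value, nn.ModuleList):
        return True
    # Rule 2: a value of one of the primitive types
    if isinstance(value, _PRIMITIVE_TYPES):
        return True
    # Rule 3: a list or tuple whose elements all satisfy rule 2
    if isinstance(value, (list, tuple)):
        return all(isinstance(v, _PRIMITIVE_TYPES) for v in value)
    return False


# A Tensor matches none of the rules, hence the TypeError for 'cx'.
print(is_valid_constant(torch.ones(100, 100)))
print(is_valid_constant([1, 2.0, "a"]))
print(is_valid_constant(torch.float))
```

The check makes the gap visible: a Tensor is neither a ModuleList, a primitive, nor a list of primitives, so __constants__ can never accept it.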


UPDATE:

It seems the right way to use a tensor attribute is to register it as a buffer:

self.register_buffer('cx', torch.ones(100, 100, dtype=torch.float, device='cuda'))

Please improve the error message so that it says this explicitly when a user tries to use a tensor as an attribute.
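A minimal sketch of the buffer-based version of the repro, assuming the same legacy torch.jit.ScriptModule / torch.jit.script_method API used above. The buffer is created on the CPU here (an assumption, so the example runs without a GPU); because buffers move with the module, c.to('cuda') would move it to the GPU.

```python
import torch


class ConstantTensor(torch.jit.ScriptModule):
    def __init__(self):
        super(ConstantTensor, self).__init__()
        # Registered as a buffer rather than a plain attribute or constant,
        # so the script method can read it as self.cx.
        # CPU device is an assumption so the sketch runs without a GPU.
        self.register_buffer('cx', torch.ones(100, 100, dtype=torch.float))

    @torch.jit.script_method
    def forward(self, x):
        return x + self.cx


c = ConstantTensor()
x = torch.zeros(100, 100)
out = c(x)
print(out.sum().item())  # 10000.0
```

A buffer is also saved in the module's state and moved by .to()/.cuda(), which a constant would not be.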

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jan 23, 2019
@sidazhang sidazhang changed the title ScriptModule is unable to have a tensor as attribute (constant or not) [JIT] ScriptModule is unable to have a tensor as attribute (constant or not) Jan 23, 2019
sidazhang (Author) commented

It seems the correct thing to do is to register the tensor as a buffer:

self.register_buffer('cx', torch.ones(100, 100, dtype=torch.float, device='cuda'))

If so, could you make the error message say this explicitly?
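One possible shape for the requested message, as a hypothetical sketch: attribute_error_message is an illustrative name, not a PyTorch function, and the wording is only a suggestion. It keeps the existing __constants__ hint for other types but points tensor attributes at register_buffer, since tensors are not valid constants.

```python
def attribute_error_message(name, type_name):
    # Hypothetical helper; not PyTorch's actual error-reporting code.
    base = f"attribute '{name}' of type '{type_name}' is not usable in a script method"
    if type_name == "Tensor":
        # Tensors are not valid __constants__, so suggest a buffer instead.
        return (f"{base}: a Tensor cannot be added to __constants__; "
                f"register it as a buffer instead, e.g. "
                f"self.register_buffer('{name}', tensor)")
    # Other types keep the existing hint.
    return f"{base} (did you forget to add it to __constants__?)"


print(attribute_error_message("cx", "Tensor"))
```

Branching on the attribute's type keeps the current hint useful for ints, strings, and the like, while tensors get a suggestion that can actually succeed.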

@sidazhang sidazhang changed the title [JIT] ScriptModule is unable to have a tensor as attribute (constant or not) [JIT] Improve error messaging for using a tensor attribute in ScriptModule Jan 24, 2019
@driazati driazati self-assigned this Jan 24, 2019