KeyError shape,stack,cos on pennylane quantum circuit #93624

Closed
msaroufim opened this issue Aug 13, 2022 · 2 comments
Assignees
Labels
dynamo-must-fix (These bugs affect TorchDynamo reliability.), module: dynamo, oncall: pt2, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

msaroufim (Member) commented Aug 13, 2022

@anijain2305 would you rather I create an "interesting model" tracker, or should I keep creating unique issues for each kind of model?

Repro

python -m pip install pennylane

import torch
import pennylane as qml
import torchdynamo

dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev, interface='torch')
def circuit4(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RZ(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(phi, theta):
    return torch.abs(circuit4(phi, theta) - 0.5)**2

phi = torch.tensor([0.011, 0.012], requires_grad=True)
theta = torch.tensor(0.05, requires_grad=True)

opt = torch.optim.Adam([phi, theta], lr = 0.1)

steps = 200

def closure():
    opt.zero_grad()
    loss = cost(phi, theta)
    loss.backward()
    return loss

with torchdynamo.optimize("eager"):
    for i in range(steps):
        opt.step(closure)

Logs

https://gist.github.com/msaroufim/ce9ec004536e762fb5c94eb3ab2670f1

cc @ezyang @wconstab @bdhirsh @anijain2305 @zou3519 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng @soumith @ngimel

@anijain2305 anijain2305 self-assigned this Aug 16, 2022
@malfet malfet transferred this issue from pytorch/torchdynamo Feb 1, 2023
@albanD albanD added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 7, 2023
ydwu4 (Contributor) commented Nov 22, 2023

Not sure if the error is caused by a package upgrade in PennyLane, but here is the repro:

import torch
import pennylane as qml

dev = qml.device('default.qubit', wires=2)

@qml.qnode(dev, interface='torch')
def circuit4(phi, theta):
    qml.RX(phi[0], wires=0)
    qml.RZ(phi[1], wires=1)
    qml.CNOT(wires=[0, 1])
    qml.RX(theta, wires=0)
    return qml.expval(qml.PauliZ(0))

def cost(phi, theta):
    return torch.abs(circuit4(phi, theta) - 0.5)**2

phi = torch.tensor([0.011, 0.012], requires_grad=True)
theta = torch.tensor(0.05, requires_grad=True)

opt = torch.optim.Adam([phi, theta], lr = 0.1)

steps = 200

def closure():
    opt.zero_grad()
    loss = cost(phi, theta)
    loss.backward()
    return loss

def f():
    for i in range(steps):
        opt.step(closure)

torch.compile(f, backend="eager")()

Error log:

Traceback (most recent call last):
  File "/home/yidi/local/pytorch/repro6.py", line 34, in <module>
    torch.compile(f, backend="eager")()
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 488, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/repro6.py", line 32, in f
    opt.step(closure)
  File "/home/yidi/local/pytorch/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/yidi/local/pytorch/torch/optim/adam.py", line 146, in step
    loss = closure()
  File "/home/yidi/local/pytorch/repro6.py", line 26, in closure
    loss = cost(phi, theta)
  File "/home/yidi/local/pytorch/repro6.py", line 15, in cost
    return torch.abs(circuit4(phi, theta) - 0.5)**2
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/qnode.py", line 970, in __call__
    self.construct(args, kwargs)
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/qnode.py", line 856, in construct
    self._qfunc_output = self.func(*args, **kwargs)
  File "/home/yidi/local/pytorch/repro6.py", line 8, in circuit4
    qml.RX(phi[0], wires=0)
  File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 654, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 721, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 664, in _compile
    raise InternalTorchDynamoError(str(e)).with_traceback(
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 645, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2123, in run
    super().run()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 470, in wrapper
    return inner_fn(self, inst)
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 1213, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 652, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 479, in call_function
    args = [v.realize() for v in args]
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builtin.py", line 479, in <listcomp>
    args = [v.realize() for v in args]
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 56, in realize
    self._cache.realize(self.parents_tracker)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/lazy.py", line 22, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 239, in __call__
    vt = self._wrap(value).clone(**self.options())
  File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 693, in _wrap
    elif trace_rules.lookup(value) is not None:
  File "/home/yidi/local/pytorch/torch/_dynamo/trace_rules.py", line 180, in lookup
    if not hashable(obj):
  File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 420, in hashable
    hash(x)
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/operation.py", line 726, in __hash__
    return self.hash
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/operation.py", line 715, in hash
    str(self.name),
  File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/operation.py", line 901, in name
    return self._name
torch._dynamo.exc.InternalTorchDynamoError: 'RX' object has no attribute '_name'

from user code:
   File "/home/yidi/local/miniconda3/envs/pytorch-3.10/lib/python3.10/site-packages/pennylane/ops/qubit/parametric_ops_single_qubit.py", line 75, in __init__
    super().__init__(phi, wires=wires, id=id)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
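
As the error message above suggests, one possible stopgap (not a fix) is to suppress Dynamo errors so the frames that fail to compile fall back to eager execution. A minimal sketch applied to the repro above:

import torch
import torch._dynamo

# Stopgap sketch only: lets Dynamo swallow the InternalTorchDynamoError and
# run the offending frames (the PennyLane qnode) uncompiled, in eager mode.
torch._dynamo.config.suppress_errors = True

torch.compile(f, backend="eager")()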

ezyang (Contributor) commented Nov 22, 2023

The problem is that we're hashing a user-defined object, which can result in arbitrary code execution. We should not do this:

  File "/home/yidi/local/pytorch/torch/_dynamo/trace_rules.py", line 180, in lookup
    if not hashable(obj):
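
To illustrate the failure mode, here is a minimal sketch using a hypothetical class (FakeOp is not PennyLane code): __hash__ runs user code that reads state assigned in __init__, so probing hashability while Dynamo is still tracing __init__ fails exactly like the RX '_name' error above.

class FakeOp:
    # Hypothetical stand-in for a user-defined op whose __hash__ depends on
    # state assigned in __init__ (similar in spirit to pennylane's Operation).
    def __init__(self, name):
        self._name = name

    def __hash__(self):
        # User code runs here; it fails if the object is not fully built yet.
        return hash(self._name)

# Dynamo intercepts the op's __init__ (see the user-code frame above), so the
# hashability check can fire before _name exists on the instance.
half_built = FakeOp.__new__(FakeOp)  # allocated, but __init__ has not run
try:
    hash(half_built)
except AttributeError as e:
    print(e)  # 'FakeOp' object has no attribute '_name'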

@anijain2305 anijain2305 removed their assignment Jan 31, 2024
@anijain2305 anijain2305 added the dynamo-must-fix These bugs affect TorchDynamo reliability. label Jan 31, 2024
@williamwen42 williamwen42 self-assigned this May 10, 2024
williamwen42 added a commit that referenced this issue May 10, 2024
ghstack-source-id: b2ed907afd741b74ecb414794bc789be19af5d35
Pull Request resolved: #125945
williamwen42 added a commit that referenced this issue May 13, 2024
ghstack-source-id: 371f501ba5bb73203abc94d4859f44c46b784405
Pull Request resolved: #125945
williamwen42 added a commit that referenced this issue May 13, 2024
ghstack-source-id: 8ad813745794adeb66b40e958cd91649c590f476
Pull Request resolved: #125945
williamwen42 added a commit that referenced this issue May 14, 2024
Fixes #93624 but also requires jcmgray/autoray#20 to be fixed.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang

[ghstack-poisoned]
williamwen42 added a commit that referenced this issue May 14, 2024
ghstack-source-id: 7c79ff2750e04e7a791a7cbb65eb2a541e0b9cdd
Pull Request resolved: #125945
williamwen42 added a commit that referenced this issue May 15, 2024
ghstack-source-id: d31b5c57619890cf24b36ddc1306d7a93b310386
Pull Request resolved: #125945
ZelboK pushed a commit to ZelboK/pytorch that referenced this issue May 19, 2024
Fixes pytorch#93624 but also requires jcmgray/autoray#20 to be fixed.

Pull Request resolved: pytorch#125945
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#125882, pytorch#125943
7 participants