Closed
Labels: dynamo-triage-jan2025, module: dynamo, oncall: pt2, triaged
Description
🐛 Describe the bug
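Wrong result when a generator expression is passed to `list.extend` under torch.compile. The expected (eager) result is `result0`, but the torch.compile result `result1` differs. See the code and its comments below for details.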
import sys
import torch
from torch import nn
import numpy as np


class Bottleneck(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x + 1


class TempOpModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.bottlenecks = nn.ModuleList([Bottleneck() for _ in range(3)])

    def forward(self, x):
        y = list([x, x])
        # Generator expression. With x = [0, 0]: y = [[0,0], [0,0], [1,1], [2,2], [3,3]]
        y.extend(m(y[-1]) for m in self.bottlenecks)
        # Equivalent explicit loop:
        # out = y[-1]
        # for m in self.bottlenecks:
        #     out = m(out)
        #     y.append(out)

        # List comprehension. With x = [0, 0]: y = [[0,0], [0,0], [1,1], [1,1], [1,1]]
        # y.extend([m(y[-1]) for m in self.bottlenecks])
        # Equivalent explicit loop:
        # input = y[-1]
        # for m in self.bottlenecks:
        #     out = m(input)
        #     y.append(out)
        return y


def test_pytorch():
    torch._dynamo.reset()
    net = TempOpModel()
    net.eval()
    with torch.no_grad():
        net_compile = torch.compile(net)
        indata0 = torch.zeros((2))
        result0 = net(indata0)          # eager (expected result)
        result1 = net_compile(indata0)  # torch.compile result
        print(" net \n", result0, "\n net compile:\n", result1)


if __name__ == '__main__':
    test_pytorch()
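The difference between the generator and list-comprehension variants in the comments can be reproduced in plain Python, without PyTorch. This is a minimal sketch assuming CPython's `list.extend`, which appends items one at a time as it pulls them from a generator, so each `y[-1]` sees the element appended in the previous step. Integers stand in for the tensors, and the function names are illustrative only.

```python
# Pure-Python model of TempOpModel.forward: integers stand in for tensors and
# each "bottleneck" adds 1, like Bottleneck.forward.
bottlenecks = [lambda v: v + 1 for _ in range(3)]


def extend_with_generator(x):
    y = [x, x]
    # The generator is consumed lazily by extend: each y[-1] is read after the
    # previous result has already been appended to y.
    y.extend(m(y[-1]) for m in bottlenecks)
    return y


def extend_with_list_comprehension(x):
    y = [x, x]
    # The list comprehension is fully evaluated before extend runs, so every
    # y[-1] refers to the original last element.
    y.extend([m(y[-1]) for m in bottlenecks])
    return y


def extend_with_materialized_generator(x):
    y = [x, x]
    # Materializing the generator first (list(...)) behaves like the list
    # comprehension, because nothing is appended until it is exhausted.
    y.extend(list(m(y[-1]) for m in bottlenecks))
    return y


print(extend_with_generator(0))           # [0, 0, 1, 2, 3]
print(extend_with_list_comprehension(0))  # [0, 0, 1, 1, 1]
```

Eager mode (`result0`) follows the `extend_with_generator` pattern; per the report, the compiled `result1` does not match it.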
Versions
torch: 2.4.0+cpu, installed with:
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
numpy: 1.24.1
python: 3.8.10
cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames