Skip to content

Commit

Permalink
Update on "[dynamo] fix module buffers call"
Browse files Browse the repository at this point in the history
This PR fixes module buffers call and extract module.buffers similar to
module.parameters

cc voznesenskym penguinwu anijain2305 EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx

[ghstack-poisoned]
  • Loading branch information
wanchaol committed May 25, 2023
2 parents 1cc50f2 + 65b28ba commit 381ea73
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions torch/_dynamo/variables/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,20 +637,26 @@ def call_function(
value, mutable_local=MutableLocal(), **options
)
elif (
inspect.getmodule(self.value) is itertools
self.value is itertools.product
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
if self.value is itertools.product:
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item), **options))
elif self.value is itertools.chain:
for item in itertools.chain(*seqs):
items.append(item)
else:
unimplemented(f"call_function {self.value} with {args}")
for item in itertools.product(*seqs):
items.append(variables.TupleVariable(list(item), **options))
return variables.ListIteratorVariable(
items, mutable_local=MutableLocal(), **options
)
elif (
self.value is itertools.chain
and not kwargs
and all(arg.has_unpack_var_sequence(tx) for arg in args)
):
seqs = [arg.unpack_var_sequence(tx) for arg in args]
items = []
for item in itertools.chain(*seqs):
items.append(item)
return variables.ListIteratorVariable(
items, mutable_local=MutableLocal(), **options
)
Expand Down

0 comments on commit 381ea73

Please sign in to comment.