Skip to content

Commit

Permalink
[dynamo] fix module buffers call (#102251)
Browse files Browse the repository at this point in the history
This PR fixes module buffers call and extract module.buffers similar to
module.parameters

Pull Request resolved: #102251
Approved by: https://github.com/wconstab
  • Loading branch information
wanchaol authored and pytorchmergebot committed May 25, 2023
1 parent d40f4f1 commit c1db235
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/dynamo/test_modules.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]

import collections
import itertools
import traceback
import types
import unittest
Expand Down Expand Up @@ -1386,6 +1387,16 @@ def forward(self, x):
# Check parameteres and buffers
for p1, p2 in zip(mod.parameters(), opt_mod.parameters()):
self.assertTrue(id(p1) == id(p2))
for b1, b2 in zip(mod.buffers(), opt_mod.buffers()):
self.assertTrue(id(b1) == id(b2))

def get_parameter_dtype(mod: torch.nn.Module):
parameters_and_buffers = itertools.chain(mod.parameters(), mod.buffers())
return next(parameters_and_buffers).dtype

opt_mod = torch._dynamo.optimize("eager")(get_parameter_dtype)
out_dtype = opt_mod(mod)
self.assertEqual(out_dtype, torch.float32)

def test_recursion(self):
mod = MockModule()
Expand Down
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,8 @@ def gen_source(source, name):
return wrap_values(module.named_modules())
elif name == "parameters":
return wrap_values(module.named_parameters(**get_kwargs("recurse")))
elif name == "buffers":
return wrap_values(module.named_buffers(**get_kwargs("recurse")))
elif name == "keys":
assert not (args or kwargs)
result = []
Expand Down

0 comments on commit c1db235

Please sign in to comment.