Skip to content

Commit

Permalink
Add append method for nn.Sequential (#71326)
Browse files Browse the repository at this point in the history
Summary:
Partially addresses #71249, and potentially supersedes #20274.

Pull Request resolved: #71326

Reviewed By: cpuhrsch

Differential Revision: D33855047

Pulled By: jbschlosser

fbshipit-source-id: a3a682e206f93b4c52bc3405e2f7b26aea6635ea
(cherry picked from commit c0b27bb)
  • Loading branch information
jaketae authored and pytorchmergebot committed Jan 31, 2022
1 parent 72c972e commit ca61292
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,6 +1521,17 @@ def test_Sequential_delitem(self):
del n[1::2]
self.assertEqual(n, nn.Sequential(l1, l3))

def test_Sequential_append(self):
l1 = nn.Linear(10, 20)
l2 = nn.Linear(20, 30)
l3 = nn.Linear(30, 40)
l4 = nn.Linear(40, 50)
n = nn.Sequential(l1, l2, l3)
n2 = n.append(l4)
self.assertEqual(n, nn.Sequential(l1, l2, l3, l4))
self.assertEqual(n2, nn.Sequential(l1, l2, l3, l4))
self.assertEqual(nn.Sequential(l1).append(l2).append(l4), nn.Sequential(l1, l2, l4))

def test_ModuleList(self):
modules = [nn.ReLU(), nn.Linear(5, 5)]
module_list = nn.ModuleList(modules)
Expand Down
9 changes: 9 additions & 0 deletions torch/nn/modules/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ def forward(self, input):
input = module(input)
return input

def append(self, module: Module) -> 'Sequential':
r"""Appends a given module to the end.
Args:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self


class ModuleList(Module):
r"""Holds submodules in a list.
Expand Down

0 comments on commit ca61292

Please sign in to comment.