Skip to content

Commit

Permalink
add a unit test for large node error (#48938)
Browse files Browse the repository at this point in the history
Summary:
add a unit test to test the situation where a node is too large to fit into any device

Pull Request resolved: #48938

Reviewed By: zhangguanheng66

Differential Revision: D25402967

Pulled By: scottxu0730

fbshipit-source-id: a2e2a3dc70d139fa678865ef03e67fa57eff4a1d
  • Loading branch information
scottxu0730 authored and facebook-github-bot committed Dec 8, 2020
1 parent 5960581 commit 6000481
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion test/test_fx_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,37 @@ def forward(self, a, b):
catch_runtime_error = True
assert catch_runtime_error

def test_large_node_error(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(4, 4)

def forward(self, a):
linear = self.linear(a)
add = linear + a
return add

m = TestModule()
traced = symbolic_trace(m)
a = torch.rand(4)
graph_manipulation.get_size_of_all_nodes(traced, [a])
partitioner = Partitioner()
devices = [
Device("dev_0", 40, 0),
Device("dev_1", 40, 0),
Device("dev_2", 40, 0),
Device("dev_3", 40, 0),
Device("dev_4", 40, 0)
]
partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
catch_runtime_error = False
try:
ret = partitioner.partition_graph(traced, m, partitioner_config)
except RuntimeError:
catch_runtime_error = True
assert catch_runtime_error

def test_partition_node_manipulation(self):
class TestModule(torch.nn.Module):
def forward(self, a, b):
Expand All @@ -187,7 +218,6 @@ def forward(self, a, b):
partition.remove_node(selected_node)
assert(partition.used_mem_bytes == 80)


def test_size_based_partition(self):
class TestModule(torch.nn.Module):
def __init__(self):
Expand Down

0 comments on commit 6000481

Please sign in to comment.