Skip to content

Commit

Permalink
fix flop counter for trivial module
Browse files Browse the repository at this point in the history
Summary: otherwise it fails.

Reviewed By: haooooooqi

Differential Revision: D29981975

fbshipit-source-id: dab1c8e404f4dfa7ec8eb373077009f435259dbf
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Jul 29, 2021
1 parent f768f7a commit 9958c1d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 1 addition & 1 deletion fvcore/nn/print_model_statistics.py
Expand Up @@ -155,7 +155,7 @@ def _remove_zero_statistics(
with submodules removed if they have zero for all statistics.
"""
out_stats: Dict[str, Dict[str, int]] = {}
_force_keep: Set[str] = force_keep if force_keep else set()
_force_keep: Set[str] = force_keep if force_keep else set() | {""}

def keep_stat(name: str) -> None:
prefix = name + ("." if name else "")
Expand Down
9 changes: 9 additions & 0 deletions tests/test_print_model_statistics.py
Expand Up @@ -535,3 +535,12 @@ def test_flop_count_str(self) -> None:
# " )\n"
# " )\n"
# ")"

def test_flop_count_empty(self) -> None:
model = nn.ReLU()
inputs = (torch.randn((1, 10)),)
table = flop_count_table(FlopCountAnalysis(model, inputs))
self.assertGreater(len(table), 0)

out = flop_count_str(FlopCountAnalysis(model, inputs))
self.assertGreater(len(out), 0)

0 comments on commit 9958c1d

Please sign in to comment.