Add a test to make sure all modules in the codebase are importable (#110598)

As per the title, running `import` on any of these files leads to a crash.
I'm very curious how the code in them is used!
Pull Request resolved: #110598
Approved by: https://github.com/janeyx99, https://github.com/malfet
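The new test boils down to the following minimal, standalone sketch (it mirrors the loop added in the diff below; the trailing `print` is a stand-in for the test's assertion):

import pkgutil
from importlib import import_module

import torch

failures = []
for _, modname, _ in pkgutil.walk_packages(path=torch.__path__, prefix=torch.__name__ + "."):
    # TODO from the PR: torch/utils/model_dump/__main__.py calls sys.exit() on import
    if "__main__" in modname:
        continue
    try:
        import_module(modname)
    except Exception as e:  # some current failures are not ImportError
        failures.append((modname, type(e)))

print(f"{len(failures)} modules failed to import")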
albanD authored and pytorchmergebot committed Oct 8, 2023
1 parent 230a124 commit 1824ea3
Showing 2 changed files with 141 additions and 5 deletions.
144 changes: 140 additions & 4 deletions test/test_public_bindings.py
@@ -9,6 +9,7 @@
import json
import os
import unittest
from importlib import import_module


class TestPublicBindings(TestCase):
@@ -221,6 +222,142 @@ def test_no_new_bindings(self):
msg = f"torch._C had bindings that are not present in the allowlist:\n{difference}"
self.assertTrue(torch_C_bindings.issubset(torch_C_allowlist_superset), msg)

@staticmethod
def _is_mod_public(modname):
split_strs = modname.split('.')
for elem in split_strs:
if elem.startswith("_"):
return False
return True

def test_modules_can_be_imported(self):
failures = []
for _, modname, _ in pkgutil.walk_packages(path=torch.__path__, prefix=torch.__name__ + '.'):
try:
# TODO: fix "torch/utils/model_dump/__main__.py"
# which calls sys.exit() when we try to import it
if "__main__" in modname:
continue
import_module(modname)
except Exception as e:
# Some current failures are not ImportError
failures.append((modname, type(e)))

# It is ok to add new entries here but please be careful that these modules
# do not get imported by public code.
private_allowlist = {
"torch._inductor.codegen.cuda.cuda_kernel",
"torch.onnx._internal.fx._pass",
"torch.onnx._internal.fx.analysis",
"torch.onnx._internal.fx.diagnostics",
"torch.onnx._internal.fx.fx_onnx_interpreter",
"torch.onnx._internal.fx.fx_symbolic_graph_extractor",
"torch.onnx._internal.fx.onnxfunction_dispatcher",
"torch.onnx._internal.fx.op_validation",
"torch.onnx._internal.fx.passes",
"torch.onnx._internal.fx.type_utils",
"torch.testing._internal.common_distributed",
"torch.testing._internal.common_fsdp",
"torch.testing._internal.dist_utils",
"torch.testing._internal.distributed._shard.sharded_tensor",
"torch.testing._internal.distributed._shard.test_common",
"torch.testing._internal.distributed._tensor.common_dtensor",
"torch.testing._internal.distributed.ddp_under_dist_autograd_test",
"torch.testing._internal.distributed.distributed_test",
"torch.testing._internal.distributed.distributed_utils",
"torch.testing._internal.distributed.fake_pg",
"torch.testing._internal.distributed.multi_threaded_pg",
"torch.testing._internal.distributed.nn.api.remote_module_test",
"torch.testing._internal.distributed.pipe_with_ddp_test",
"torch.testing._internal.distributed.rpc.dist_autograd_test",
"torch.testing._internal.distributed.rpc.dist_optimizer_test",
"torch.testing._internal.distributed.rpc.examples.parameter_server_test",
"torch.testing._internal.distributed.rpc.examples.reinforcement_learning_rpc_test",
"torch.testing._internal.distributed.rpc.faulty_agent_rpc_test",
"torch.testing._internal.distributed.rpc.faulty_rpc_agent_test_fixture",
"torch.testing._internal.distributed.rpc.jit.dist_autograd_test",
"torch.testing._internal.distributed.rpc.jit.rpc_test",
"torch.testing._internal.distributed.rpc.jit.rpc_test_faulty",
"torch.testing._internal.distributed.rpc.rpc_agent_test_fixture",
"torch.testing._internal.distributed.rpc.rpc_test",
"torch.testing._internal.distributed.rpc.tensorpipe_rpc_agent_test_fixture",
"torch.testing._internal.distributed.rpc_utils",
"torch.utils.tensorboard._caffe2_graph",
"torch._inductor.codegen.cuda.cuda_template",
"torch._inductor.codegen.cuda.gemm_template",
"torch._inductor.triton_helpers",
"torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
"torch.backends._coreml.preprocess",
"torch.contrib._tensorboard_vis",
"torch.distributed._composable",
"torch.distributed._functional_collectives",
"torch.distributed._functional_collectives_impl",
"torch.distributed._shard",
"torch.distributed._sharded_tensor",
"torch.distributed._sharding_spec",
"torch.distributed._spmd.api",
"torch.distributed._spmd.batch_dim_utils",
"torch.distributed._spmd.comm_tensor",
"torch.distributed._spmd.data_parallel",
"torch.distributed._spmd.distribute",
"torch.distributed._spmd.experimental_ops",
"torch.distributed._spmd.parallel_mode",
"torch.distributed._tensor",
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
"torch.distributed.algorithms._optimizer_overlap",
"torch.distributed.rpc._testing.faulty_agent_backend_registry",
"torch.distributed.rpc._utils",
}

# No new entries should be added to this list.
# All public modules should be importable on all platforms.
public_allowlist = {
"torch.distributed.algorithms.ddp_comm_hooks",
"torch.distributed.algorithms.model_averaging.averagers",
"torch.distributed.algorithms.model_averaging.hierarchical_model_averager",
"torch.distributed.algorithms.model_averaging.utils",
"torch.distributed.checkpoint",
"torch.distributed.constants",
"torch.distributed.distributed_c10d",
"torch.distributed.elastic.agent.server",
"torch.distributed.elastic.rendezvous",
"torch.distributed.fsdp",
"torch.distributed.launch",
"torch.distributed.launcher",
"torch.distributed.nn",
"torch.distributed.nn.api.remote_module",
"torch.distributed.optim",
"torch.distributed.optim.optimizer",
"torch.distributed.pipeline.sync",
"torch.distributed.rendezvous",
"torch.distributed.rpc.api",
"torch.distributed.rpc.backend_registry",
"torch.distributed.rpc.constants",
"torch.distributed.rpc.internal",
"torch.distributed.rpc.options",
"torch.distributed.rpc.rref_proxy",
"torch.distributed.elastic.rendezvous.etcd_rendezvous",
"torch.distributed.elastic.rendezvous.etcd_rendezvous_backend",
"torch.distributed.elastic.rendezvous.etcd_store",
"torch.distributed.rpc.server_process_global_profiler",
"torch.distributed.run",
"torch.distributed.tensor.parallel",
"torch.distributed.utils",
}

errors = []
for mod, excep_type in failures:
if mod in public_allowlist:
# TODO: Ensure this is the right error type
continue

if mod in private_allowlist:
continue

errors.append(f"{mod} failed to import with error {excep_type}")

self.assertEqual("", "\n".join(errors))

# AttributeError: module 'torch.distributed' has no attribute '_shard'
@unittest.skipIf(IS_WINDOWS or IS_JETSON, "Distributed Attribute Error")
def test_correct_module_names(self):
@@ -247,7 +384,6 @@ def test_correct_module_names(self):
allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[modname]

def test_module(modname):
-            split_strs = modname.split('.')
try:
if "__main__" in modname:
return
@@ -256,9 +392,9 @@ def test_module(modname):
# It is ok to ignore here as we have a test above that ensures
# this should never happen
return
-            for elem in split_strs:
-                if elem.startswith("_"):
-                    return

+            if not self._is_mod_public(modname):
+                return

# verifies that each public API has the correct module name and naming semantics
def check_one_element(elem, modname, mod, *, is_public, is_all):
2 changes: 1 addition & 1 deletion torch/testing/_internal/common_utils.py
@@ -2295,7 +2295,7 @@ class TestCase(expecttest.TestCase):
# Always use difflib to print diffs on multi line equality.
# Undocumented feature in unittest
_diffThreshold = sys.maxsize
-    maxDiff = sys.maxsize
+    maxDiff = None

# checker to early terminate test suite if unrecoverable failure occurs.
def _should_stop_test_suite(self):
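A note on the `common_utils.py` change, assuming standard `unittest` semantics: `maxDiff = None` removes the cap on the length of `assertEqual` failure diffs (the default truncates at 80 * 8 characters), so the full list of failing modules is printed when the new test fails. A minimal sketch of the pattern (hypothetical test, not part of the PR):

import unittest

class DemoImportFailures(unittest.TestCase):
    # None disables unittest's truncation of long assertEqual diffs;
    # the default cap (640 characters) would elide most of the list.
    maxDiff = None

    def test_reports_every_failure(self):
        # Same pattern as test_modules_can_be_imported: comparing the joined
        # error list against "" makes the failure message show every entry.
        errors = [f"module_{i} failed to import" for i in range(50)]
        self.assertEqual("", "\n".join(errors))  # fails, printing the full list

if __name__ == "__main__":
    unittest.main()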
