diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index d890962696eee..af8fd2dfae285 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1006,9 +1006,9 @@ def trace_module(mod, Arguments: mod (torch.nn.Module): A ``torch.nn.Module`` containing methods whose names are - specified in ``example_inputs``. The given methods will be compiled + specified in ``inputs``. The given methods will be compiled as a part of a single `ScriptModule`. - example_inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. + inputs (dict): A dict containing sample inputs indexed by method names in ``mod``. The inputs will be passed to methods whose names correspond to inputs' keys while tracing. ``{ 'forward' : example_forward_input, 'method2': example_method2_input}``