-
Notifications
You must be signed in to change notification settings - Fork 24.6k
Initial ModuleInfo implementation #61935
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful links
💊 CI failures summary and remediationsAs of commit 3c4ecc8 (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patternsThe following CI failures do not appear to be due to upstream breakages:
|
|
||
# === Compare outputs to a reference if one is specified. === | ||
# TODO: Handle precision | ||
reference_fn = module_input.reference_fn |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting that the reference_fn is on the module input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ModuleInput
is possibly a bad name- it started out as a module analogue for SampleInput
. We want the ability to have a reference function per (constructor args, forward args), so that was added there. I'm good with changing the name or whatever makes it less weird :p
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Names can be sorted later, this seems fine
MODULE_CLASS_NAMES[module_cls] = f'{namespace_name}.{module_name}' | ||
|
||
|
||
class modules(_TestParametrizer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although "common_device_type.py" is an increasingly silly name, we might want this to go next to the ops decorator for clarity. Might as well add a comment, even if it's only one line:
PROTOTYPE @modules decorator that instantiates tests for each module
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, in an attempt to not increase the silliness of common_device_type.py
further, I do prefer the @modules
decorator being here. This is also the most obvious place I'd think to find it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure
module_inputs = [ | ||
ModuleInput(constructor_input=FunctionInput(10, 8), | ||
forward_input=FunctionInput(make_input((4, 10))), | ||
reference_fn=lambda m, p, i: torch.mm(i, p[0].t()) + p[1].view(1, -1).expand(4, 8)), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. In the future I wonder if the reference can just be a functional form of the module so you don't need to write a reference for each input?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ability to write a reference per input is actually a feature. For a given (constructor args, forward args), the reference function may be as simple or complex as needed to verify the output. I think it'd be non-ideal to require that the module is essentially reimplemented within the reference, and if we just call the existing functional form of the module, that doesn't really test much. I'd also like to use the references from module_tests
/ new_module_tests
/ criterion_tests
as much as possible, which is where these Linear ones came from.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, pros/cons, and probably not too hard to change
This is an awesome start; just need to fix lint and add the test file to run_tests.py Line 85 in db11619
|
FYI the moduleinfo test errors appear real but the polygamma test failures are in the base. |
43bc84c
to
3c4ecc8
Compare
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! An exciting start to ModuleInfos
@jbschlosser merged this pull request in a0309f8. |
Summary: Follow up to #61935 This PR adds inplace checks to `test_modules`. This version checks the constructor for `inplace` and performs the check automatically. Pull Request resolved: #63739 Reviewed By: saketh-are Differential Revision: D30737774 Pulled By: jbschlosser fbshipit-source-id: 8813534511e9296c8424d1ca878412726ddd4043
Summary: Follow up to #61935 This PR: 1. Adds test for non-contiguous tensors 2. Fixes bug in `NLLLoss` that was catch by the test. The reason this was not catch in `common_nn` is because `CriterionTest` overrides `test_cuda` but does not call `test_nonconfig`. cc albanD mruberry jbschlosser walterddr Pull Request resolved: #64954 Reviewed By: zou3519 Differential Revision: D31174149 Pulled By: jbschlosser fbshipit-source-id: a16073e59b40ccc01c82ede016b63a8db2e810f5
Summary: Follow up to #61935 This PR adds device to device transfer test into `ModuleInfo`. cc albanD mruberry jbschlosser walterddr Pull Request resolved: #65488 Reviewed By: mruberry Differential Revision: D32063662 Pulled By: jbschlosser fbshipit-source-id: 0868235a0ae7e5b6a3e4057c23fe70753c0946d2
Summary: Follow up to #61935 This PR: 1. Adds test for non-contiguous tensors 2. Fixes bug in `NLLLoss` that was catch by the test. The reason this was not catch in `common_nn` is because `CriterionTest` overrides `test_cuda` but does not call `test_nonconfig`. cc albanD mruberry jbschlosser walterddr Pull Request resolved: #64954 Reviewed By: zou3519 Differential Revision: D31174149 Pulled By: jbschlosser fbshipit-source-id: a16073e59b40ccc01c82ede016b63a8db2e810f5 (cherry picked from commit 0d3bf97) Signed-off-by: Eli Uriegas <eliuriegas@fb.com>
) * TST Adds test for non-contiguous tensors (#64954) Summary: Follow up to #61935 This PR: 1. Adds test for non-contiguous tensors 2. Fixes bug in `NLLLoss` that was catch by the test. The reason this was not catch in `common_nn` is because `CriterionTest` overrides `test_cuda` but does not call `test_nonconfig`. cc albanD mruberry jbschlosser walterddr Pull Request resolved: #64954 Reviewed By: zou3519 Differential Revision: D31174149 Pulled By: jbschlosser fbshipit-source-id: a16073e59b40ccc01c82ede016b63a8db2e810f5 (cherry picked from commit 0d3bf97) Signed-off-by: Eli Uriegas <eliuriegas@fb.com> * Cherry-pick changes from #64444 Namely, `make_weight` partial into `module_inputs_torch_nn_NLLLoss` Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Nikita Shulga <nshulga@fb.com>
This PR contains the initial version of
ModuleInfo
for use in testing modules. The design philosophy taken here is to start small and simple and build out / refactor as needed when more test coverage orModuleInfo
entries are added. As such, it's not intended for general usage yet. The PR contains the following:torch/testing/_internal/common_modules.py
ModuleInfo
definition - metadata for each module to use in testingmodule_db
- the actualModuleInfo
database; currently contains entries for two modulesModuleInput
- analogous toSampleInput
from OpInfo; containsFunctionInput
s for both constructor and forward pass inputsModuleInput
because they are likely correlatedFunctionInput
- just contains args and kwargs to pass to a function (is there a nicer way to do this?)@modules
decorator - analogous to@ops
; specifies a set of modules to run a test overMODULE_NAMESPACES
- list of all namespaces containing modulesMODULE_CLASSES
- list of all module class objectsMODULE_CLASS_NAMES
- dict from module class object to nice name (e.g. torch.nn.Linear -> "nn.Linear")test/test_modules.py
test_forward
, which instantiates a module, runs its forward pass, and compares it to a reference, if one is defined