When a model in TorchBenchmark raises `NotImplementedError` in its `get_module()` function, all of its unit tests are skipped!
This happens because test.py catches `NotImplementedError` and treats it as a skip. Every test ends up skipped for the following reasons:
1. check_example
2. check_device
   - check_example and check_device call `model.get_module()` directly, which may raise `NotImplementedError`.
3. train
4. eval
   - The train and eval unit tests call `set_train()` and `set_eval()` respectively before running the method.
   - `set_train()` and `set_eval()` call `_set_mode()`, which runs `(model, _) = self.get_module()` and may therefore raise `NotImplementedError`.
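The skip path for the train test can be reproduced in a minimal sketch. `ToyModel`, `ToyTest`, and `run_toy_test` are hypothetical names, and the `try`/`except` around the test body is an assumption about how test.py handles the exception, not a copy of its code.

```python
import unittest


class ToyModel:
    """Hypothetical stand-in for a benchmark model without get_module()."""

    def get_module(self):
        raise NotImplementedError("get_module() is not implemented")

    def _set_mode(self, train):
        # Mirrors the call described above: get_module() runs first,
        # so the exception is raised before the mode can be changed.
        (model, _) = self.get_module()
        model.train(train)

    def set_train(self):
        self._set_mode(True)

    def train(self):
        return "trained"


class ToyTest(unittest.TestCase):
    def test_train(self):
        model = ToyModel()
        try:
            model.set_train()  # raises before train() is ever reached
            model.train()
        except NotImplementedError:
            # Assumed handling: the exception is converted into a skip.
            self.skipTest("Method not implemented, skipping")


def run_toy_test():
    """Run ToyTest and return the unittest.TestResult."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ToyTest)
    result = unittest.TestResult()
    suite.run(result)
    return result
```

In this sketch `train()` itself is fully implemented, yet the test is reported as skipped because `set_train()` fails first.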
For 1 and 2, skipping may make sense, since their functionality requires `get_module()`. However, for 3 and 4, we could choose one of the following:
- Override `_set_mode()` with something that works (or with `pass`).
- Skip `_set_mode()` and run train and eval directly.
- Require every model to implement `get_module()`. (This would also fix 1 and 2.)
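As an illustration of the first option, a `_set_mode()` that tolerates a missing `get_module()` might look like the sketch below. `BaseModel`, `FakeModule`, `WithModule`, and the boolean return value are hypothetical, and `FakeModule.train(mode)` only mimics the `torch.nn.Module.train(mode)` call so that the example runs without PyTorch.

```python
class FakeModule:
    """Hypothetical stand-in that mimics torch.nn.Module.train(mode)."""

    def __init__(self):
        self.training = True

    def train(self, mode=True):
        self.training = mode
        return self


class BaseModel:
    """Hypothetical base class; not the actual TorchBenchmark implementation."""

    def get_module(self):
        raise NotImplementedError("get_module() is not implemented")

    def _set_mode(self, train):
        try:
            (model, _) = self.get_module()
        except NotImplementedError:
            # No module to switch: do nothing and let train/eval run as-is.
            return False
        model.train(train)
        return True

    def set_train(self):
        return self._set_mode(True)

    def set_eval(self):
        return self._set_mode(False)


class WithModule(BaseModel):
    """Hypothetical model that does implement get_module()."""

    def __init__(self):
        self.module = FakeModule()

    def get_module(self):
        # The second tuple element is ignored by _set_mode().
        return self.module, None
```

With a change along these lines, a model without `get_module()` would proceed to train and eval instead of being skipped, at the cost of no guarantee that the mode was actually set.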