Add support of MViTv2 video variants (#6373)
* Extending to support MViTv2

* Fix docs, mypy and linter

* Refactor the relative positional code.

* Code refactoring.

* Rename vars.

* Update docs.

* Replace assert with exception.

* Update docs.

* Minor refactoring.

* Remove the square input limitation.

* Moving methods around.

* Modify the shortcut in the attention layer.

* Add ported weights.

* Introduce a `residual_cls` config on the attention layer.

* Make the patch_embed kernel/padding/stride configurable.

* Apply changes from code-review.

* Remove stale todo.
datumbox committed Aug 10, 2022
1 parent 6908129 commit 7e8186e
Showing 4 changed files with 369 additions and 34 deletions.
3 changes: 2 additions & 1 deletion docs/source/models/video_mvit.rst
@@ -12,7 +12,7 @@ The MViT model is based on the
Model builders
--------------

-The following model builders can be used to instantiate a MViT model, with or
+The following model builders can be used to instantiate a MViT v1 or v2 model, with or
without pre-trained weights. All the model builders internally rely on the
``torchvision.models.video.MViT`` base class. Please refer to the `source
code
@@ -24,3 +24,4 @@ more details about this class.
:template: function.rst

mvit_v1_b
+mvit_v2_s
Binary file added test/expect/ModelTester.test_mvit_v2_s_expect.pkl
Binary file not shown.
3 changes: 3 additions & 0 deletions test/test_models.py
@@ -309,6 +309,9 @@ def _check_input_backprop(model, inputs):
"mvit_v1_b": {
"input_shape": (1, 3, 16, 224, 224),
},
+"mvit_v2_s": {
+"input_shape": (1, 3, 16, 224, 224),
+},
}
# speeding up slow models:
slow_models = [
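
The `input_shape` entries in this hunk override a shared default for individual models. A minimal sketch of that override pattern follows. `DEFAULT_PARAMS`, `MODEL_PARAMS`, and `resolve_params` are hypothetical names, and the default values are assumed for illustration rather than taken from this diff.

```python
# Assumed fallback values used when a model has no specific entry (illustrative only).
DEFAULT_PARAMS = {"input_shape": (1, 3, 4, 112, 112), "num_classes": 50}

# Per-model overrides, mirroring the two MViT entries shown in the diff.
MODEL_PARAMS = {
    "mvit_v1_b": {"input_shape": (1, 3, 16, 224, 224)},
    "mvit_v2_s": {"input_shape": (1, 3, 16, 224, 224)},
}


def resolve_params(model_name):
    """Merge the defaults with any per-model overrides; overrides win."""
    return {**DEFAULT_PARAMS, **MODEL_PARAMS.get(model_name, {})}


print(resolve_params("mvit_v2_s"))
```

Keeping overrides in a dictionary means a new variant such as `mvit_v2_s` only needs one extra entry, and models without an entry keep the defaults.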