Add Video SwinTransformer #6521
Conversation
I got the hang of SwinTransformer2d. Soon I will take this up.
This was a mega learning experience and a refresher for me, as I haven't worked on models / attention at all. OK, so after a long wait, I have completely read both papers, the image Swin and the video Swin. I'm currently working to align the image and video codebases and will then write down the model configurations. I have a few comments, though.
I think a couple of things should be added:
- Logging.
- torch.fx.wrap()
I would love to understand what torch.fx.wrap() does. 😄
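For context, a minimal sketch of what torch.fx.wrap() is for (the helper name here is made up for illustration, not taken from the PR): it registers a free function as a leaf so that torch.fx symbolic tracing records a single call_function node instead of tracing into the body. This matters for helpers with data-dependent arithmetic on tensor shapes, which would otherwise confuse tracing.

```python
import torch
import torch.fx
import torch.nn.functional as F

# Hypothetical helper: pads the last dim up to a multiple of `window`.
# The data-dependent arithmetic on x.shape would confuse symbolic tracing
# if FX traced into the body.
def _pad_to_window(x: torch.Tensor, window: int) -> torch.Tensor:
    pad = (window - x.shape[-1] % window) % window
    return F.pad(x, (0, pad))

# torch.fx.wrap() registers the function as a leaf: tracing records one
# opaque call_function node instead of tracing through the body.
# It must be called at module top level.
torch.fx.wrap("_pad_to_window")

class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _pad_to_window(x, 7)

gm = torch.fx.symbolic_trace(Toy())
print(gm.graph)  # shows a single call to _pad_to_window
```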
    return pad_size[0], pad_size[1], pad_size[2]


def _compute_attention_mask_3d(
I think there is no such function in the 2d case. This is just in 3d.
Just FYI, there is an equivalent in these lines: https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py#L193-L207
But yeah, we don't wrap it as a separate function in the 2d implementation.
I think we can refactor the 2d implementation a bit; the 3d one looks much better here 😄
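For readers following along, here is a rough, self-contained sketch of what such a 3d attention-mask computation typically does, assuming nonzero shift sizes and input dims already padded to multiples of the window size (an illustration, not the PR's exact code):

```python
import torch

def compute_attention_mask_3d_sketch(size_dhw, window_size, shift_size):
    # Label each voxel with the region it falls into after the cyclic shift;
    # tokens from different regions must not attend to each other.
    d, h, w = size_dhw
    mask = torch.zeros(d, h, w)
    dim_slices = [
        (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        for ws, ss in zip(window_size, shift_size)
    ]
    region = 0.0
    for ds in dim_slices[0]:
        for hs in dim_slices[1]:
            for ws_ in dim_slices[2]:
                mask[ds, hs, ws_] = region
                region += 1.0
    # Partition into non-overlapping windows and flatten each window.
    wd, wh, ww = window_size
    mask = mask.view(d // wd, wd, h // wh, wh, w // ww, ww)
    mask = mask.permute(0, 2, 4, 1, 3, 5).reshape(-1, wd * wh * ww)
    # Pairwise region comparison per window: a large negative value kills
    # cross-region attention after the softmax.
    attn_mask = mask.unsqueeze(1) - mask.unsqueeze(2)
    return attn_mask.masked_fill(attn_mask != 0, -100.0)

print(compute_attention_mask_3d_sketch((4, 8, 8), (2, 4, 4), (1, 2, 2)).shape)
# torch.Size([8, 32, 32]): one (window_vol x window_vol) mask per window
```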
    shift_size[i] = 0

window_vol = window_size[0] * window_size[1] * window_size[2]
relative_position_bias = self.relative_position_bias_table[
And this too can be rewritten.
Yeah, for V2 we should separate this out into its own function.
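A minimal sketch of the kind of helper being suggested, with assumed tensor names and shapes (illustrative only):

```python
import torch

def get_relative_position_bias_sketch(
    bias_table: torch.Tensor,  # (num_relative_positions, num_heads)
    bias_index: torch.Tensor,  # (window_vol, window_vol), long indices
    window_vol: int,
) -> torch.Tensor:
    # Look up one learned bias per head for every pair of tokens in a
    # window, then reshape to broadcast over the attention logits.
    bias = bias_table[bias_index.reshape(-1)]
    bias = bias.reshape(window_vol, window_vol, -1)
    return bias.permute(2, 0, 1).unsqueeze(0)  # (1, num_heads, vol, vol)
```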
if norm_layer is None:
    norm_layer = partial(nn.LayerNorm, eps=1e-5)

if patch_embed is None:
Somehow I am not convinced by this style of writing, in both the 2d and 3d versions. Why can't we give defaults in the function definition? Is there any reason?
Actually I followed the 2d version's style for this. I guess it kind of makes sense for the previous one since we use partial(...), but I agree that we might as well put the default for patch_embed directly in the function definition. I can't really think of an advantage to this style now.
I hope we can refactor the implementation, but it might be BC-breaking (a minor BC break, though).
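To make the two styles under discussion concrete, a minimal side-by-side sketch (the class names are made up):

```python
from functools import partial
from typing import Callable, Optional

import torch.nn as nn

# Style currently used in both the 2d and 3d code: a None sentinel that is
# resolved inside the constructor.
class BlockA(nn.Module):
    def __init__(self, dim: int, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)
        self.norm = norm_layer(dim)

# Alternative raised above: put the default directly in the signature. This
# is safe here because partial(...) is immutable, so the classic mutable
# default-argument pitfall does not apply.
class BlockB(nn.Module):
    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-5)):
        super().__init__()
        self.norm = norm_layer(dim)
```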
# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
Just curious, what is the equivalent of this in the 2d case, and where is the code block for it?
This is actually implemented directly in the main class in the 2d version:
https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py#L548-L557
In the 3d version we need to separate this out into its own class, because when Omnivore uses this swin3d as an encoder, it needs a different way of creating the patch embedding: https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py#L144
I wish we could separate it in the 2d version too. It would make the implementation match the paper's block diagram a bit more clearly.
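For intuition, a minimal sketch of a standalone 3d patch embedding, assuming the usual (B, C, T, H, W) video layout (an illustration, not the PR's exact class):

```python
import torch
import torch.nn as nn

class PatchEmbed3dSketch(nn.Module):
    def __init__(self, patch_size=(2, 4, 4), in_channels=3, embed_dim=96):
        super().__init__()
        # A Conv3d with kernel_size == stride splits the clip into
        # non-overlapping (t, h, w) patches and projects each to embed_dim.
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)                 # (B, embed_dim, T', H', W')
        return x.permute(0, 2, 3, 4, 1)  # (B, T', H', W', embed_dim)

# e.g. a 2-clip batch of 16-frame 224x224 videos:
tokens = PatchEmbed3dSketch()(torch.randn(2, 3, 16, 224, 224))
print(tokens.shape)  # torch.Size([2, 8, 56, 56, 96])
```

Keeping this as its own nn.Module is what lets a downstream model like Omnivore swap in a different embedding without touching the transformer body.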
Hi @oke-aditya, nice work! In this first review, I address some of your questions and concerns. I will do a further review tomorrow :)
size_dhw = (d, h, w)
window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
# Handle case where window_size is larger than the input tensor
for i in range(3):
I think the only way to remove the loop is to do the comparisons one by one (3 if statements).
Btw, take note that this specific piece of logic actually differs between the video Swin Transformer and the 2d one, as the sketch below shows.
In our 2d implementation, we don't change the window size; we only set the shift to 0.
In the original 3d implementation, we actually resize window_size, like how it is implemented here.
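To illustrate the difference, a small self-contained sketch of the 3d behavior (the function name is made up):

```python
from typing import List, Sequence, Tuple

def clamp_window_and_shift(
    size_dhw: Sequence[int],
    window_size: Sequence[int],
    shift_size: Sequence[int],
) -> Tuple[List[int], List[int]]:
    window_size, shift_size = list(window_size), list(shift_size)
    for i in range(3):
        if size_dhw[i] <= window_size[i]:
            window_size[i] = size_dhw[i]  # 3d: shrink the window to the input
            shift_size[i] = 0             # no cyclic shift needed along this dim
    return window_size, shift_size

# A 4-frame clip with an (8, 7, 7) window: the temporal window is clamped.
print(clamp_window_and_shift((4, 56, 56), (8, 7, 7), (4, 3, 3)))
# -> ([4, 7, 7], [0, 3, 3])
```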
if norm_layer is not None:
    self.norm = norm_layer(embed_dim)
else:
    self.norm = None
Maybe we should use nn.Identity to make mypy happy?
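A sketch of the suggestion (with an assumed signature): nn.Identity keeps self.norm typed as nn.Module in both branches, so mypy never sees an Optional attribute.

```python
from typing import Callable, Optional

import torch.nn as nn

class EmbedSketch(nn.Module):
    def __init__(self, embed_dim: int, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        # nn.Identity is a no-op module, so self.norm is always an
        # nn.Module and callers never need a None check.
        self.norm: nn.Module = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()
```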
OK, so this seems almost done. I guess we are ready to port weights. Can you verify the configurations, @YosuaMichael?
Just a quick update. cc @datumbox
@datumbox
Thanks @oke-aditya for the ported weights. I modified the code from this branch a bit by adding the weights and transforms to mimic the original paper. Then I ran the weights through testing on the Kinetics-400 validation data using the following command:
Note: I use clips-per-video==12 to simulate the paper, which uses 4 spatial x 3 temporal clips, since as of now we don't have the spatial-clips capability. Thanks @datumbox for suggesting this. The pairs of model_name and weight_name with their corresponding accuracies are as follows:
@YosuaMichael Looks good. We should add the info about the input parameters passed to the reference script to the docs, similar to what we do for MViT: vision/torchvision/models/video/mvit.py, lines 616 to 619 in cffb7f7
Nice, I will wrap this up over the weekend!
I kind of disagree with the naming of the Swag one. I would suggest that for KINETICS we do … and …
Edit: Also, for consistency, I suggest the default weights for …
Btw also how do I get these 3 fields?
@oke-aditya Thanks for the update.
Sounds good to me. For … For … Finally for …
Looks good @oke-aditya, I agree with the naming.
So I checked min_size and the min temporal size. It seems even (1, 1) would work fine, and 1 for the temporal size is fine too.
@YosuaMichael @datumbox I have added the FLOPs and weights size by running the script in #6936. I think we should be good to merge.
@oke-aditya there seems to be a lint error about the typing: https://app.circleci.com/pipelines/github/pytorch/vision/21954/workflows/1179cdc7-f7db-4700-87c7-e45ef42af46c/jobs/1777451
This should be good to go. The CI failure is unrelated.
Thanks for the hard work @oke-aditya, it looks good now!
The test failures are unrelated, so we merged.
Summary:
- Just start adding mere copy paste
- Replace d with t and D with T
- Align swin transformer video to image a bit
- Rename d -> t
- align with 2d impl
- align with 2d impl
- Add helpful comments and config for 3d
- add docs
- Add docs
- Add configurations
- Add docs
- Fix bugs
- Fix wrong edit
- Fix wrong edit
- Fix bugs
- Fix bugs
- Fix as per fx suggestions
- Update torchvision/models/video/swin_transformer.py
- Fix as per fx suggestions
- Fix expect files and code
- Update the expect files
- Modify video swin
- Add min size and min temporal size, num params
- Add flops and size
- Fix types
- Fix url recipe

Reviewed By: YosuaMichael
Differential Revision: D41376277
fbshipit-source-id: 00ec3c40b12dff7d6404af7c327e6fc209fc6618
Co-authored-by: Yosua Michael Maranatha <yosuamichael@fb.com>
Closes #6499
I think this is ready for an initial round of review. Would love to understand a bit more and then continue.
I will mostly wrap this up before the end of the month. A little bit more work is pending 😄