
Add Video SwinTransformer #6521

Merged
merged 36 commits into from Nov 17, 2022

Conversation

@oke-aditya (Contributor) commented Aug 30, 2022

Closes #6499

I think this is ready for an initial round of review. I would love to understand a bit more and then continue.

I will mostly wrap this up before the end of the month. A little bit more work is pending 😄

@oke-aditya (Contributor Author)

I have got the hang of SwinTransformer2d. I will take this up soon.

@oke-aditya (Contributor Author)

This was a great learning experience and refresher for me, as I haven't worked on models / attention at all.
(My ML skills are probably waning, as I don't work on them much these days.)

OK, so after a long wait, I have completely read both papers (Swin image and Swin video). I'm currently working to align the image and video codebases and will then write down the model configurations. I have a few comments, though.

@oke-aditya (Contributor Author) left a comment

I think a couple of things should be added:

  1. Logging.
  2. torch.fx.wrap()

I would love to understand what torch.fx.wrap() does. 😄
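
For context, here is a minimal sketch of what torch.fx.wrap() does (the helper name below is hypothetical, not the one in this PR): it registers a free function as a tracing leaf, so torch.fx.symbolic_trace records a single call_function node for the call instead of tracing through the function body. That is useful for helpers that do Python-level size math on shapes, like the padding utilities in this file.

import torch
import torch.fx
import torch.nn.functional as F


def _pad_to_multiple(x: torch.Tensor, multiple: int) -> torch.Tensor:
    # Python-level size computation; once the function is wrapped, the
    # tracer never looks inside this body.
    pad = (multiple - x.shape[-1] % multiple) % multiple
    return F.pad(x, (0, pad))


# Register the helper as a tracing leaf: symbolic tracing records one opaque
# call_function node for it instead of tracing through its body.
torch.fx.wrap("_pad_to_multiple")


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _pad_to_multiple(x, 4)


traced = torch.fx.symbolic_trace(TinyModel())
print(traced.graph)  # shows a call_function node for _pad_to_multiple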

return pad_size[0], pad_size[1], pad_size[2]


def _compute_attention_mask_3d(
Contributor Author:

I think there is no such function in the 2d case; this is only in the 3d one.

Contributor:

Just FYI, there is an equivalent in these lines: https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py#L193-L207

But yeah, we don't wrap it as a separate function in the 2d implementation.

Contributor Author:

I think we can refactor the 2d implementation a bit; the 3d version looks much better here 😄

shift_size[i] = 0

window_vol = window_size[0] * window_size[1] * window_size[2]
relative_position_bias = self.relative_position_bias_table[
Contributor Author:

And this too can be rewritten

Contributor:

Yeah, for V2 we should separate this out into a separate function.
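
For illustration, a rough sketch (hypothetical helper name, not code from this PR) of what a factored-out lookup could look like in the 3d case:

from typing import List

import torch


def _get_relative_position_bias_3d(
    relative_position_bias_table: torch.Tensor,  # ((2*Wd-1)*(2*Wh-1)*(2*Ww-1), num_heads)
    relative_position_index: torch.Tensor,       # (Wd*Wh*Ww, Wd*Wh*Ww)
    window_size: List[int],
) -> torch.Tensor:
    window_vol = window_size[0] * window_size[1] * window_size[2]
    # One learned bias per head for every pair of positions in the 3d window.
    bias = relative_position_bias_table[relative_position_index[:window_vol, :window_vol].flatten()]
    bias = bias.view(window_vol, window_vol, -1)            # (N, N, num_heads)
    return bias.permute(2, 0, 1).contiguous().unsqueeze(0)  # (1, num_heads, N, N)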

if norm_layer is None:
    norm_layer = partial(nn.LayerNorm, eps=1e-5)

if patch_embed is None:
Contributor Author:

Somehow I am not convinced by this style of writing, both in 2d and 3d. Why can't we give the defaults in the function definition? Is there any reason?

Contributor:

Actually I followed the 2d version's style for this. I guess for the previous one it kind of makes sense since we use partial(...), but I agree that we might as well put this directly as a default in the function definition for patch_embed. I can't really think of an advantage to this style now.

Contributor Author:

I hope we can refactor the implementation, but it might be BC-breaking (only a minor BC break, though).
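
For reference, a small sketch contrasting the two styles under discussion (class names are hypothetical): the current pattern defaults optional arguments to None and fills in the real default inside __init__, while the alternative puts the default directly into the signature.

from functools import partial
from typing import Callable, Optional

import torch.nn as nn


# Current style: default to None, resolve inside __init__.
class StyleCurrent(nn.Module):
    def __init__(self, embed_dim: int, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)
        self.norm = norm_layer(embed_dim)


# Suggested style: put the default in the signature. A plain class such as a
# patch-embedding module can be the default directly; partial(...) is still
# needed when the default carries extra kwargs such as eps.
class StyleSignatureDefault(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-5),
    ):
        super().__init__()
        self.norm = norm_layer(embed_dim)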


# Modified from:
# https://github.com/SwinTransformer/Video-Swin-Transformer/blob/master/mmaction/models/backbones/swin_transformer.py
class PatchEmbed3d(nn.Module):
Contributor Author:

Just curious, what is the equivalent of this in the 2d case, and where is the code block for it?

Contributor:

This is actually implemented directly in the main class in the 2d version:
https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py#L548-L557

In the 3d version we need to split this out into a separate class, because when Omnivore uses this swin3d as an encoder, it needs a different way of creating the patch embedding: https://github.com/facebookresearch/multimodal/blob/main/torchmultimodal/models/omnivore.py#L144
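
To make this concrete, here is a rough, self-contained sketch of a 3d patch-embedding module in that style (a simplified illustration, not the exact class added in this PR): pad T/H/W to multiples of the patch size, project with a strided Conv3d, then move channels last.

from typing import Callable, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchEmbed3dSketch(nn.Module):
    def __init__(
        self,
        patch_size: Tuple[int, int, int] = (2, 4, 4),
        in_channels: int = 3,
        embed_dim: int = 96,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ):
        super().__init__()
        self.patch_size = patch_size
        # The strided Conv3d does the patch partition and linear projection in one step.
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, t, h, w = x.shape
        pt, ph, pw = self.patch_size
        # F.pad takes (W_left, W_right, H_left, H_right, T_left, T_right) for a 5d input.
        x = F.pad(x, (0, (pw - w % pw) % pw, 0, (ph - h % ph) % ph, 0, (pt - t % pt) % pt))
        x = self.proj(x)              # (B, embed_dim, T', H', W')
        x = x.permute(0, 2, 3, 4, 1)  # channels-last: (B, T', H', W', embed_dim)
        return self.norm(x)


print(PatchEmbed3dSketch()(torch.randn(1, 3, 5, 9, 9)).shape)  # torch.Size([1, 3, 3, 3, 96])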

@oke-aditya (Contributor Author) commented Sep 21, 2022

I wish we could separate it in the 2d version too. It would make the paper implementation a bit clearer, as the block diagram illustrates.

@oke-aditya oke-aditya marked this pull request as ready for review September 18, 2022 19:41
@YosuaMichael (Contributor) left a comment

Hi @oke-aditya, nice work! In this first review I address some of your questions and concerns. I will do a further review tomorrow :)

size_dhw = (d, h, w)
window_size, shift_size = self.window_size.copy(), self.shift_size.copy()
# Handle case where window_size is larger than the input tensor
for i in range(3):
Contributor:

I think the only way to remove the loop is to do the comparisons one by one (3 if statements).

Btw, take note that this specific logic actually differs between the video swin transformer and the 2d one: in our 2d implementation we don't change the window size but only set the shift to 0, whereas the original 3d implementation actually resizes window_size, as implemented here. A standalone sketch of the 3d behavior follows.
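
A sketch of that 3d adjustment (hypothetical function name); the 2d implementation would only zero the shift and leave window_size unchanged:

from typing import List, Tuple


def _adjust_window_and_shift_3d(
    size_dhw: Tuple[int, int, int],
    window_size: List[int],
    shift_size: List[int],
) -> Tuple[List[int], List[int]]:
    window_size, shift_size = list(window_size), list(shift_size)
    for i in range(3):
        # If the input is not larger than the window along this axis,
        # shrink the window to the input and drop the cyclic shift there.
        if size_dhw[i] <= window_size[i]:
            window_size[i] = size_dhw[i]
            shift_size[i] = 0
    return window_size, shift_size


print(_adjust_window_and_shift_3d((4, 56, 56), [8, 7, 7], [4, 3, 3]))  # ([4, 7, 7], [0, 3, 3])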


if norm_layer is not None:
    self.norm = norm_layer(embed_dim)
else:
    self.norm = None
Contributor:

Maybe we should use nn.Identity to make mypy happy?
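
A small sketch of the suggestion (hypothetical class name): defaulting to nn.Identity keeps self.norm typed as nn.Module instead of Optional[nn.Module], so forward can call it unconditionally and mypy stays happy.

from typing import Callable, Optional

import torch
import torch.nn as nn


class NormOrIdentity(nn.Module):
    def __init__(self, embed_dim: int, norm_layer: Optional[Callable[..., nn.Module]] = None):
        super().__init__()
        # nn.Identity() is a no-op module, so self.norm is always an nn.Module.
        self.norm: nn.Module = norm_layer(embed_dim) if norm_layer is not None else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x)  # no `if self.norm is not None` branch needed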


@oke-aditya (Contributor Author)

OK, so this seems almost done. I guess we are ready to port weights. Can you verify the configurations, @YosuaMichael?

@oke-aditya oke-aditya changed the title [WIP] Add Video SwinTransformer Add Video SwinTransformer Sep 24, 2022
@oke-aditya (Contributor Author)

Just a quick update.
We plan to continue working on this after the release. This will give us sufficient time to port all the weights!

cc @datumbox

@oke-aditya (Contributor Author)

@datumbox Quick update: @YosuaMichael and I have resumed work on this. @YosuaMichael might kick off a run to check an initial set of weights. Hopefully we are in luck.

@YosuaMichael (Contributor)

Thanks @oke-aditya for the ported weights. I modified the code from this branch a bit, adding the weights and transforms to mimic the original paper. Then I tested the weights on the Kinetics-400 validation data using the following command:

python -u ~/script/run_with_submitit.py \
    --timeout 3000 --ngpus 8 --nodes 1 \
    --data-path="/datasets01/kinetics/070618/400/" \
    --batch-size=16 --test-only \
    --clip-len 32 --frame-rate 15 --clips-per-video 12 \
    --cache-dataset \
    --model {model_name} --weights="{weight_name}"

Note: I use clips-per-video=12 to simulate the paper, which uses 4 spatial x 3 temporal clips, since as of now we don't have the spatial-clips capability. Thanks to @datumbox for suggesting this.

Here are the model_name / weight_name pairs with their corresponding accuracy:

  1. model_name=swin3d_t, weight_name=Swin3D_T_Weights.KINETICS400_V1
     * Video Acc@1 77.715, Video Acc@5 93.519
  2. model_name=swin3d_s, weight_name=Swin3D_S_Weights.KINETICS400_V1
     * Video Acc@1 79.521, Video Acc@5 94.158
  3. model_name=swin3d_b, weight_name=Swin3D_B_Weights.KINETICS400_V1
     * Video Acc@1 79.427, Video Acc@5 94.386
  4. model_name=swin3d_b, weight_name=Swin3D_B_Weights.IN22K_KINETICS400_V1
     * Video Acc@1 81.643, Video Acc@5 95.574

@datumbox (Contributor) commented Nov 3, 2022

@YosuaMichael Looks good. We should add the info about the input parameters passed to the reference script to the docs, similar to what we do for MViT:

"_docs": (
"The weights were ported from the paper. The accuracies are estimated on video-level "
"with parameters `frame_rate=7.5`, `clips_per_video=5`, and `clip_len=16`"
),

@oke-aditya (Contributor Author)

Nice, I will wrap this up over the weekend!

@oke-aditya (Contributor Author) commented Nov 5, 2022

I kind of disagree with the name IN22K_KINETICS400_V1, since for the SWAG models we put the name of the other pretraining dataset last.
E.g. the default one:

 IMAGENET1K_V2 = Weights(

The SWAG one:

IMAGENET1K_SWAG_E2E_V1 = Weights(

I would suggest that for Kinetics we use

KINETICS400_V1 

and

KINETICS400_IMAGENET22K_V1 = Weights(

Edit:

Also, for consistency, I suggest that the default weights for swin_b should be the ones pretrained on IMAGENET_1K rather than IMAGENET_22K.

@oke-aditya (Contributor Author)

Btw, how do I get these 3 fields?

 "min_size": (224, 224),
 "min_temporal_size": 16,
"num_params": 36610672,

@datumbox (Contributor) commented Nov 7, 2022

@oke-aditya Thanks for the update.

KINETICS400_IMAGENET22K_V1

Sounds good to me.

For min_size, it's the minimum spatial size that can be sent to your architecture without throwing an error. Sometimes max-pooling or conv layers impose a minimum size below which the network returns an error. I usually try a couple of values until I find the one that breaks the architecture and then set min_size to the next permitted value.

For min_temporal_size it's exactly the same, but for the temporal T dimension.

Finally, for num_params, just do sum(p.numel() for p in model.parameters()) to estimate it.
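
As an illustration, here is a quick sketch (using the swin3d_t builder added in this PR; the probe sizes are arbitrary) of computing num_params and empirically probing the minimum spatial/temporal sizes:

import torch
from torchvision.models.video import swin3d_t

model = swin3d_t()

# num_params: total number of learnable parameters.
print(sum(p.numel() for p in model.parameters()))

# Probe candidate minimum input sizes: shrink the input until forward fails.
model.eval()
with torch.no_grad():
    for t, h, w in [(8, 56, 56), (2, 4, 4), (1, 1, 1)]:
        try:
            model(torch.randn(1, 3, t, h, w))
            print(f"T={t}, H={h}, W={w}: OK")
        except RuntimeError as err:
            print(f"T={t}, H={h}, W={w}: fails ({err})")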

@YosuaMichael (Contributor)

Looks good @oke-aditya, I agree with the naming KINETICS400_IMAGENET22K_V1!
For min_size, min_temporal_size, and num_params it is as @datumbox explained.
I think the 3d swin transformer doesn't really have a min_size or min_temporal_size, since we do padding in PatchEmbed3d (which means min_size=(1, 1) and min_temporal_size=1), but please give it a try first.

@oke-aditya (Contributor Author)

So I checked min_size and min_temporal_size. It seems even (1, 1) works fine, and 1 for the temporal size is fine too:

>>> import torch
>>> from torchvision.models.video import swin3d_b
>>> v_tensor = torch.randn(1, 3, 1, 1, 1)  # b c t h w
>>> v_tensor = v_tensor.cuda()
>>> model = swin3d_b().cuda()
>>> out = model(v_tensor)

@oke-aditya (Contributor Author)

@YosuaMichael @datumbox I have added the FLOPs and weight sizes by running the script in #6936.

I think we should be good to merge.

@oke-aditya (Contributor Author)

This should be good to go. The CI failure is unrelated.

@YosuaMichael (Contributor) left a comment

Thanks for the hard work @oke-aditya, it looks good now!

@YosuaMichael YosuaMichael merged commit b1054cb into pytorch:main Nov 17, 2022
@YosuaMichael (Contributor)

The test failures are unrelated, so we merged.

@oke-aditya oke-aditya deleted the add_videoswin branch November 17, 2022 09:56
facebook-github-bot pushed a commit that referenced this pull request Nov 19, 2022
Summary:
* Just start adding mere copy paste

* Replace d with t and D with T

* Align swin transformer video to image a bit

* Rename d -> t

* align with 2d impl

* align with 2d impl

* Add helpful comments and config for 3d

* add docs

* Add docs

* Add configurations

* Add docs

* Fix bugs

* Fix wrong edit

* Fix wrong edit

* Fix bugs

* Fix bugs

* Fix as per fx suggestions

* Update torchvision/models/video/swin_transformer.py

* Fix as per fx suggestions

* Fix expect files and code

* Update the expect files

* Modify video swin

* Add min size and min temporal size, num params

* Add flops and size

* Fix types

* Fix url recipe

Reviewed By: YosuaMichael

Differential Revision: D41376277

fbshipit-source-id: 00ec3c40b12dff7d6404af7c327e6fc209fc6618

Co-authored-by: Yosua Michael Maranatha <yosuamichael@fb.com>

Successfully merging this pull request may close these issues.

Port SwinTransformer3d from torchmultimodal
4 participants