
Add support of MViTv2 video variants #6373

Merged: 21 commits into pytorch:main on Aug 10, 2022
Conversation

@datumbox (Contributor) commented on Aug 5, 2022

This is the continuation of the work from #6198 to finalize the API of the MViT class. The PR extends TorchVision's existing MViT architecture to support the v2 variants.

The newly introduced mvit_v2_s variant is canonical, and its weights are ported from the paper. This is based on the work of @lyttonhao, @haooooooqi and @feichtenhofer on SlowFast.
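With this PR, the new variant and the ported weights can be used directly from TorchVision. Below is a minimal usage sketch (the input clip is illustrative, and the preprocessing call assumes the weights enum exposes a transforms() preset like the other video weights):

import torch
from torchvision.models.video import mvit_v2_s, MViT_V2_S_Weights

weights = MViT_V2_S_Weights.DEFAULT  # Kinetics-400 weights ported from the paper
model = mvit_v2_s(weights=weights).eval()

# Illustrative clip of 16 RGB frames in (T, C, H, W) layout; the preset transform
# resizes/crops to 224x224 and permutes the clip to (C, T, H, W).
clip = torch.rand(16, 3, 256, 256)
batch = weights.transforms()(clip).unsqueeze(0)  # -> (1, 3, 16, 224, 224)

with torch.no_grad():
    scores = model(batch).softmax(dim=1)
print(scores.argmax(dim=1))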

Verification process

Comparing outputs

To confirm that the implementation is compatible with the original from SlowFast, we create a weight converter, load the same weights into both implementations, and compare their outputs on the same input:

import collections
import tempfile
import torch
from torchvision.models.video import mvit_v2_s
from slowfast.config.defaults import assert_and_infer_cfg, get_cfg
from slowfast.models.video_model_builder import MViT
from slowfast.utils.parser import load_config


def mvit_v2_s_slowfast():
    config = """
    DATA:
      NUM_FRAMES: 16
      SAMPLING_RATE: 4
      TRAIN_CROP_SIZE: 224
      TEST_CROP_SIZE: 224
      INPUT_CHANNEL_NUM: [3]
    MVIT:
      ZERO_DECAY_POS_CLS: False
      USE_ABS_POS: False
      REL_POS_SPATIAL: True
      REL_POS_TEMPORAL: True
      DEPTH: 16
      NUM_HEADS: 1
      EMBED_DIM: 96
      PATCH_KERNEL: (3, 7, 7)
      PATCH_STRIDE: (2, 4, 4)
      PATCH_PADDING: (1, 3, 3)
      MLP_RATIO: 4.0
      QKV_BIAS: True
      DROPPATH_RATE: 0.2
      NORM: "layernorm"
      MODE: "conv"
      CLS_EMBED_ON: True
      DIM_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
      HEAD_MUL: [[1, 2.0], [3, 2.0], [14, 2.0]]
      POOL_KVQ_KERNEL: [3, 3, 3]
      POOL_KV_STRIDE_ADAPTIVE: [1, 8, 8]
      POOL_Q_STRIDE: [[0, 1, 1, 1], [1, 1, 2, 2], [2, 1, 1, 1], [3, 1, 2, 2], [4, 1, 1, 1], [5, 1, 1, 1], [6, 1, 1, 1], [7, 1, 1, 1], [8, 1, 1, 1], [9, 1, 1, 1], [10, 1, 1, 1], [11, 1, 1, 1], [12, 1, 1, 1], [13, 1, 1, 1], [14, 1, 2, 2], [15, 1, 1, 1]]
      DROPOUT_RATE: 0.0
      DIM_MUL_IN_ATT: True
      RESIDUAL_POOLING: True
    """
    temp = tempfile.NamedTemporaryFile(mode="w+", delete=False)
    try:
        temp.write(config)
    finally:
        temp.close()

    cfg = get_cfg()
    cfg.merge_from_file(temp.name)
    cfg = assert_and_infer_cfg(cfg)

    cfg.NUM_GPUS = 0
    model = MViT(cfg)
    model.head.act = torch.nn.Identity()  # bypass the head activation (softmax) so both models return raw logits
    return model


def slowfast_to_tv_weights(state_dict):
    d = dict(state_dict)

    # remapping keys
    mapping = collections.OrderedDict(
        [
            ("patch_embed.project.weight", "conv_proj.weight"),
            ("patch_embed.project.bias", "conv_proj.bias"),
            ("cls_token", "pos_encoding.class_token"),
            ("pos_embed_spatial", "pos_encoding.spatial_pos"),
            ("pos_embed_temporal", "pos_encoding.temporal_pos"),
            ("pos_embed_class", "pos_encoding.class_pos"),
            ("attn.proj.weight", "attn.project.0.weight"),
            ("attn.proj.bias", "attn.project.0.bias"),
            ("attn.pool_q.weight", "attn.pool_q.pool.weight"),
            ("attn.norm_q.weight", "attn.pool_q.norm_act.0.weight"),
            ("attn.norm_q.bias", "attn.pool_q.norm_act.0.bias"),
            ("attn.pool_k.weight", "attn.pool_k.pool.weight"),
            ("attn.norm_k.weight", "attn.pool_k.norm_act.0.weight"),
            ("attn.norm_k.bias", "attn.pool_k.norm_act.0.bias"),
            ("attn.pool_v.weight", "attn.pool_v.pool.weight"),
            ("attn.norm_v.weight", "attn.pool_v.norm_act.0.weight"),
            ("attn.norm_v.bias", "attn.pool_v.norm_act.0.bias"),
            ("mlp.fc1.weight", "mlp.0.weight"),
            ("mlp.fc1.bias", "mlp.0.bias"),
            ("mlp.fc2.weight", "mlp.3.weight"),
            ("mlp.fc2.bias", "mlp.3.bias"),
            ("head.projection.weight", "head.1.weight"),
            ("head.projection.bias", "head.1.bias"),
            ("patch_embed.proj.weight", "conv_proj.weight"),
            ("patch_embed.proj.bias", "conv_proj.bias"),
            ("norm.weight", "norm.weight"),
            ("norm.bias", "norm.bias"),
            ("proj.weight", "project.weight"),
            ("proj.bias", "project.bias"),
        ]
    )
    for k in list(d.keys()):
        for pattern, replacement in mapping.items():
            if pattern in k:
                new_key = k.replace(pattern, replacement)
                d[new_key] = d.pop(k)
                break

    # matching dimensions
    d["pos_encoding.class_token"] = d["pos_encoding.class_token"][0, 0, :]
    if "pos_encoding.spatial_pos" in d:
        d["pos_encoding.spatial_pos"] = d["pos_encoding.spatial_pos"][0, :]
        d["pos_encoding.temporal_pos"] = d["pos_encoding.temporal_pos"][0, :]
        d["pos_encoding.class_pos"] = d["pos_encoding.class_pos"][0, 0, :]

    return d


def compare_models(sf_model_fn, tv_model_fn, input_shape):
    print(tv_model_fn.__name__)
    x = torch.randn(input_shape)

    sf_m = sf_model_fn().eval()
    exp_result = sf_m([x]).sum()

    d = sf_m.state_dict()
    d = slowfast_to_tv_weights(d)

    tv_m = tv_model_fn().eval()
    tv_m.load_state_dict(d)
    result = tv_m(x).sum()

    torch.testing.assert_close(result, exp_result)
    print("OK")


compare_models(mvit_v2_s_slowfast, mvit_v2_s, (1, 3, 16, 224, 224))
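
For convenience, the converted state_dict can also be saved to disk and reloaded into the TorchVision model later, without rebuilding the SlowFast model (a sketch with an arbitrary file name; in practice the SlowFast model would first be loaded with the original checkpoint rather than random weights):

# Convert once and persist the TorchVision-compatible checkpoint.
sf_state = mvit_v2_s_slowfast().state_dict()
torch.save(slowfast_to_tv_weights(sf_state), "mvit_v2_s_tv.pth")

# Later, load it directly into the TorchVision model.
model = mvit_v2_s()
model.load_state_dict(torch.load("mvit_v2_s_tv.pth"))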

Benchmarks

To ensure that we don't introduce any speed regression, we test the speed as follows:

import time


def benchmark(model_fn, input_shape, device, put_in_list, n=5, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)
    if put_in_list:
        x = [x]

    s = []
    for i in range(n):
        start = time.time()
        m(x)
        t = time.time() - start
        if i > n * warmup:
            s.append(t)

    print(model_fn.__name__, torch.tensor(s).median())


device = "cuda"
batch_size = 4
n = 100

print(f"device={device}, batch_size={batch_size}, n={n}")
for name, fn, put_in_list in [("TorchVision", mvit_v2_s, False), ("SlowFast", mvit_v2_s_slowfast, True)]:
    print(name)
    benchmark(fn, (batch_size, 3, 16, 224, 224), device, put_in_list, n=n)

This was tested on an A100 and, as shown below, the TorchVision implementation is about 5% faster than the original:

device=cuda, batch_size=4, n=100
TorchVision
mvit_v2_s tensor(0.0492)
SlowFast
mvit_v2_s_slowfast tensor(0.0520)
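
Note that the timer above measures wall-clock time around the forward call; CUDA kernels run asynchronously, so the per-iteration numbers mainly reflect throughput across iterations. If per-call latency is needed, one can synchronize before stopping the clock (an illustrative variation, not how the figures above were produced):

def benchmark_sync(model_fn, input_shape, device, put_in_list, n=100, warmup=0.1):
    torch.manual_seed(42)
    m = model_fn().to(device).eval()
    x = torch.randn(input_shape).to(device)
    if put_in_list:
        x = [x]

    s = []
    with torch.no_grad():
        for i in range(n):
            torch.cuda.synchronize()
            start = time.time()
            m(x)
            torch.cuda.synchronize()  # wait for all kernels before stopping the clock
            t = time.time() - start
            if i > n * warmup:
                s.append(t)
    print(model_fn.__name__, torch.tensor(s).median())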

Accuracy

To verify the accuracy of the model, we run the following:

torchrun --nproc_per_node=8 train.py --data-path="/datasets/clean_kinetics_400/" \
  --batch-size=16 --test-only --cache-dataset \
  --clip-len 16 --frame-rate 8 --clips-per-video 5 \
  --model mvit_v2_s --weights="MViT_V2_S_Weights.DEFAULT"
 * Clip Acc@1 72.914 Clip Acc@5 89.507
 * Video Acc@1 80.757 Video Acc@5 94.665

Note that the reported Acc@1 is a bit lower than the one in the paper, but this is due to the version of the dataset that we use to assess the model (some corrupted videos are removed). To ensure that the accuracy of TorchVision's implementation is not lagging, we test the same data and weights using SlowFast's reference scripts:

INFO:slowfast.utils.logging:json_stats: {"split": "test_final", "top1_acc": "80.79", "top5_acc": "94.66"}

As we can see, the accuracies are practically the same, with minor differences caused by differences in the video clip sampling mechanism.

@datumbox marked this pull request as ready for review on August 9, 2022
@datumbox changed the title from "[WIP] Add support for MViTv2" to "Add support for MViTv2" on Aug 9, 2022
@datumbox requested a review from jdsgomes on August 9, 2022
@datumbox changed the title from "Add support for MViTv2" to "Add support of MViTv2 video variants" on Aug 10, 2022
@jdsgomes (Contributor) left a comment

LGTM, thanks

@datumbox merged commit 7e8186e into pytorch:main on Aug 10, 2022
Batteries Included - Phase 3 automation moved this from In progress to Done on Aug 10, 2022
@datumbox deleted the models/mvitv2 branch on August 10, 2022
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Aug 11, 2022
Addresses some [breakages](https://github.com/pytorch/pytorch/runs/7782559841?check_suite_focus=true) from #82560

Context: The tests are breaking because a new architecture was added in TorchVision (see pytorch/vision#6373) that requires a different input size. This PR addresses it by using the right size for the `mvit_v2_s` architecture.
Pull Request resolved: #83242
Approved by: https://github.com/ezyang
facebook-github-bot pushed a commit that referenced this pull request Aug 23, 2022
Summary:
* Extending to support MViTv2

* Fix docs, mypy and linter

* Refactor the relative positional code.

* Code refactoring.

* Rename vars.

* Update docs.

* Replace assert with exception.

* Update docs.

* Minor refactoring.

* Remove the square input limitation.

* Moving methods around.

* Modify the shortcut in the attention layer.

* Add ported weights.

* Introduce a `residual_cls` config on the attention layer.

* Make the patch_embed kernel/padding/stride configurable.

* Apply changes from code-review.

* Remove stale todo.

Reviewed By: datumbox

Differential Revision: D38824226

fbshipit-source-id: 2950997bb37e431d76a0480b5b938b15b1d5eeaf