In [10]:
import os
import pathlib

import pytorchvideo.data
import torch
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

In [2]:
# Root of the UCF101 subset: <root>/<split>/<class_name>/<clip>.avi.
# Set the UCF101_ROOT_PATH environment variable to point elsewhere instead of
# editing this hard-coded default.
UCF101_ROOT_PATH = os.environ.get(
    "UCF101_ROOT_PATH", "/data/Users/xyq/developer/yuquan/data/UCF101_subset"
)

data_root_path = pathlib.Path(UCF101_ROOT_PATH)

In [3]:
# Count the .avi clips in each split (layout: <split>/<class_name>/<clip>.avi).
video_count_train = len(list(data_root_path.glob("train/*/*.avi")))
video_count_val = len(list(data_root_path.glob("val/*/*.avi")))
video_count_test = len(list(data_root_path.glob("test/*/*.avi")))
video_total = video_count_train + video_count_val + video_count_test
# Fail fast on a wrong root path: glob() yields nothing (no error) in that case,
# which would otherwise surface much later as empty labels/datasets.
assert video_total > 0, f"No .avi videos found under {data_root_path}"
video_total

405

In [4]:
# Every clip path, grouped by split (train, then val, then test). Sorting within
# each split makes the order, and therefore the preview below, independent of
# the filesystem's directory-listing order.
all_video_file_paths = [
    path
    for split in ("train", "val", "test")
    for path in sorted(data_root_path.glob(f"{split}/*/*.avi"))
]
all_video_file_paths[:3]

[PosixPath('/data/Users/xyq/developer/yuquan/data/UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g02_c03.avi'),
 PosixPath('/data/Users/xyq/developer/yuquan/data/UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g07_c04.avi'),
 PosixPath('/data/Users/xyq/developer/yuquan/data/UCF101_subset/train/ApplyEyeMakeup/v_ApplyEyeMakeup_g19_c02.avi')]

In [5]:
# Derive the set of labels in the dataset: a clip's label is the name of the
# directory that contains it. Path.parent.name avoids splitting on a hard-coded
# "/" separator, so this also works with non-POSIX paths.
class_labels = sorted({file_path.parent.name for file_path in all_video_file_paths})
class_labels

['ApplyEyeMakeup',
 'ApplyLipstick',
 'Archery',
 'BabyCrawling',
 'BalanceBeam',
 'BandMarching',
 'BaseballPitch',
 'Basketball',
 'BasketballDunk',
 'BenchPress']

In [6]:
# Bidirectional mapping between class names and integer ids; ids follow the
# position of each name in class_labels.
id2label = dict(enumerate(class_labels))
label2id = {name: idx for idx, name in id2label.items()}
label2id, id2label

({'ApplyEyeMakeup': 0,
  'ApplyLipstick': 1,
  'Archery': 2,
  'BabyCrawling': 3,
  'BalanceBeam': 4,
  'BandMarching': 5,
  'BaseballPitch': 6,
  'Basketball': 7,
  'BasketballDunk': 8,
  'BenchPress': 9},
 {0: 'ApplyEyeMakeup',
  1: 'ApplyLipstick',
  2: 'Archery',
  3: 'BabyCrawling',
  4: 'BalanceBeam',
  5: 'BandMarching',
  6: 'BaseballPitch',
  7: 'Basketball',
  8: 'BasketballDunk',
  9: 'BenchPress'})

In [12]:
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Preprocessing settings (resize/crop sizes, rescale factor, normalisation
# statistics) stored alongside the local VideoMAE checkpoint.
image_processor = VideoMAEImageProcessor.from_pretrained(
    "/data/Users/xyq/developer/yuquan/model_repo/videomae-base",
)
image_processor

VideoMAEImageProcessor {
  "_valid_processor_keys": [
    "videos",
    "do_resize",
    "size",
    "resample",
    "do_center_crop",
    "crop_size",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "crop_size": {
    "height": 224,
    "width": 224
  },
  "do_center_crop": true,
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "VideoMAEImageProcessor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "shortest_edge": 224
  }
}

In [13]:
# Load the VideoMAE backbone with a classification head sized to our label set.
# This checkpoint has no classifier weights, so the head is freshly initialised
# (as the load-time warning reports) and must be fine-tuned.
model = VideoMAEForVideoClassification.from_pretrained(
    "/data/Users/xyq/developer/yuquan/model_repo/videomae-base",
    id2label=id2label,
    label2id=label2id,
    # Tolerate a classifier-head shape mismatch, e.g. when starting from a
    # checkpoint that was already fine-tuned on a different label set.
    ignore_mismatched_sizes=True,
)

Some weights of VideoMAEForVideoClassification were not initialized from the model checkpoint at /data/Users/xyq/developer/yuquan/model_repo/videomae-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# Normalisation statistics and target spatial size, taken from the checkpoint's
# processor config so the transforms stay consistent with it.
mean, std = image_processor.image_mean, image_processor.image_std
height, width = (
    (image_processor.size["shortest_edge"],) * 2
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
resize_to = (height, width)

# clip_duration: seconds spanned by `num_frames_to_sample` frames taken every
# `sample_rate`-th frame at `fps` frames per second.
num_frames_to_sample = model.config.num_frames
sample_rate = 4
# NOTE(review): UCF101 is generally distributed at 25 fps; 30 only affects the
# clip length in seconds here — confirm the intended value.
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

In [18]:
# Number of frames per clip the model expects (model.config.num_frames).
num_frames_to_sample

16

In [15]:
# Training pipeline, applied only to the "video" entry of each sample dict:
# 1. subsample to the model's frame count;
# 2. divide by 255 (presumably uint8 pixels -> [0, 1]) and normalise with the
#    processor's mean/std;
# 3. random spatial augmentation: short-side rescale, crop to `resize_to`, and
#    horizontal flip.
train_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(resize_to),
                    RandomHorizontalFlip(p=0.5),
                ]
            ),
        ),
    ]
)

# Training clips use the "random" clip sampler with `clip_duration`-second
# clips. Audio is not decoded.
train_dataset = pytorchvideo.data.Ucf101(
    data_path=str(data_root_path / "train"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=train_transform,
)

In [16]:
# Evaluation pipeline: the same temporal subsampling and normalisation as
# training, but a deterministic resize to `resize_to` instead of random
# rescale/crop/flip augmentation.
val_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Normalize(mean, std),
                    Resize(resize_to),
                ]
            ),
        ),
    ]
)


def make_eval_dataset(split):
    """Build a UCF101 dataset for an evaluation split such as "val" or "test".

    Clips of ``clip_duration`` seconds are drawn with the "uniform" clip sampler
    and preprocessed with ``val_transform``. Audio is not decoded.

    Args:
        split: Name of the subdirectory of ``data_root_path`` to read.

    Returns:
        A ``pytorchvideo.data.Ucf101`` dataset over ``data_root_path / split``.
    """
    return pytorchvideo.data.Ucf101(
        data_path=str(data_root_path / split),
        clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
        decode_audio=False,
        transform=val_transform,
    )


val_dataset = make_eval_dataset("val")
test_dataset = make_eval_dataset("test")

# Each sample yielded by the datasets is a dict of the form:
# {
#     'video': <video_tensor>,
#     'label': <index_label>,
#     'video_label': <index_label>,
#     'video_index': <video_index>,
#     'clip_index': <clip_index>,
#     'aug_index': <aug_index>,
# }

In [17]:
# Number of videos found in each split, in the order train, val, test.
print(*(split_ds.num_videos for split_ds in (train_dataset, val_dataset, test_dataset)))

300 30 75


In [None]:
def collate_fn(examples):
    """Batch pytorchvideo sample dicts into VideoMAE model inputs.

    Args:
        examples: List of sample dicts, each holding a 4-D ``"video"`` tensor
            and an integer ``"label"``. The video is assumed to be laid out as
            (num_channels, num_frames, height, width) — TODO confirm against
            the transform pipeline's output.

    Returns:
        Dict with ``"pixel_values"``, the stacked videos with their first two
        axes swapped (batch, num_frames, num_channels, height, width under the
        assumption above), and ``"labels"``, a 1-D int64 tensor of class ids.
    """
    # Swap axes 0 and 1 of each video: (C, T, H, W) -> (T, C, H, W).
    pixel_values = torch.stack(
        [example["video"].permute(1, 0, 2, 3) for example in examples]
    )
    # Explicit int64: class-index targets (e.g. for cross-entropy) must be long,
    # whatever integer type the dataset happens to yield.
    labels = torch.tensor(
        [example["label"] for example in examples], dtype=torch.long
    )
    return {"pixel_values": pixel_values, "labels": labels}