torchvisionのdatasetを使ってUCF101を読み込み，pytorchvideoのx3dモデルをスクラッチで学習してみる．

UCF101はあらかじめダウンロードして展開済みであるとする．

In [1]:
#
# torch.hub.load fails once torchvision has been imported
# (ImportError: cannot import name ***), see
# https://github.com/pytorch/hub/issues/46
# Workaround: call torch.hub.load right after `import torch` (while torchvision
# is still not imported) so that the result is left in the cache.
# After that the cache (~/.cache/torch/hub/checkpoints/) is used and the error
# no longer occurs.


# https://pytorch.org/hub/facebookresearch_pytorchvideo_x3d/
import torch

# Load each X3D variant once to populate the cache; `model` ends up bound to
# the last variant loaded.
for _variant in ('x3d_xs', 'x3d_s', 'x3d_m'):
    model = torch.hub.load('facebookresearch/pytorchvideo', _variant, pretrained=True)


Using cache found in /home/tamaki/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/tamaki/.cache/torch/hub/facebookresearch_pytorchvideo_master
Using cache found in /home/tamaki/.cache/torch/hub/facebookresearch_pytorchvideo_master


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from torch.utils.data import DistributedSampler, RandomSampler

from torchvision.models import resnet18
from torchvision import transforms
# from torchvision.datasets import UCF101

from pytorchvideo.models import x3d
from pytorchvideo.data import Ucf101, RandomClipSampler, UniformClipSampler


from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
)


import torchinfo

from tqdm.notebook import tqdm
import itertools
import os
import pickle

argparseを真似たパラメータ設定．
- rootで指定したディレクトリには，101クラスのサブディレクトリがあること
- annotation_pathにはtrainlist0{1,2,3}.txtなどがあること

In [3]:
class Args:
    """Argparse-like container holding the notebook's settings as attributes."""

    def __init__(self):
        # Model and data-loading settings.
        self.model = 'X3D'
        self.batch_size = 8
        self.num_workers = 24

        # Clip settings.
        self.frames_per_clip = self.step_between_clips = 16
        self.video_num_subsampled = 16  # take 16 frames out of each clip

        # Dataset locations.
        # `root` is expected to contain the 101 class sub-directories and
        # `annotation_path` the trainlist0{1,2,3}.txt files.
        # NOTE(review): the directory is spelled "UFC101" here; presumably it
        # matches the name on disk -- confirm before renaming.
        self.metadata_path = '/mnt/HDD10TB/dataset/UFC101/'
        self.root = '/mnt/HDD10TB/dataset/UFC101/video/'
        self.annotation_path = '/mnt/HDD10TB/dataset/UFC101/ucfTrainTestlist/'


args = Args()

In [4]:
def _video_transform(spatial_ops):
    """Build the transform applied to the "video" entry of a sample.

    The clip is temporally subsampled to `args.video_num_subsampled` frames,
    divided by 255, normalized, and then passed through `spatial_ops`
    (a list of transforms) in order.
    """
    return ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(args.video_num_subsampled),
            Lambda(lambda x: x / 255.),
            Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
            *spatial_ops,
        ]),
    )


# Training: random scale / crop / flip augmentation on the video, labels
# shifted down by one (the training list uses labels 1..101), audio dropped.
train_transform = Compose([
    _video_transform([
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(224),
        RandomHorizontalFlip(),
    ]),
    ApplyTransformToKey(
        key="label",
        transform=Lambda(lambda x: x - 1),
    ),
    RemoveKey("audio"),
])

# Validation: fixed short-side scale and center crop, audio dropped.
# NOTE(review): unlike train_transform, labels are not shifted here -- confirm
# this is intended for the label format of the validation list.
val_transform = Compose([
    _video_transform([
        ShortSideScale(256),
        CenterCrop(224),
    ]),
    RemoveKey("audio"),
])



In [28]:
# Keyword arguments shared by the training and validation datasets.
_ucf101_common = dict(
    video_path_prefix='/mnt/HDD10TB/dataset/UFC101/video',
    video_sampler=RandomSampler,
    decode_audio=False,
)

# Training split. Labels in the list file run from 1 to 101, which is why
# train_transform subtracts 1 from them.
# (UniformClipSampler(clip_duration=16/25) was tried as an alternative sampler.)
train_set = Ucf101(
    data_path='/mnt/HDD10TB/dataset/UFC101/ucfTrainTestlist/trainlist01.txt',
    clip_sampler=RandomClipSampler(clip_duration=16/25),  # 16 frames, assuming 25 FPS
    transform=train_transform,
    **_ucf101_common,
)

# Validation split, built with the same samplers as the training split.
val_set = Ucf101(
    data_path='/mnt/HDD10TB/dataset/UFC101/ucfTrainTestlist/testlist01.txt',
    clip_sampler=RandomClipSampler(clip_duration=16/25),  # 16 frames, assuming 25 FPS
    transform=val_transform,
    **_ucf101_common,
)

In [29]:
# Display the `num_videos` attribute of the training dataset.
train_set.num_videos

9537

In [30]:
# Display the `num_videos` attribute of the validation dataset.
val_set.num_videos

3783

In [31]:
# https://github.com/facebookresearch/pytorchvideo/blob/ef2d3a96bb939b12aa0f21fb467d2175b0f05e9f/tutorials/video_classification_example/train.py#L343
class LimitDataset(torch.utils.data.Dataset):
    """
    Map-style wrapper around an iterable video dataset.

    To ensure a constant number of samples are retrieved from the dataset we use this
    LimitDataset wrapper. This is necessary because several of the underlying videos
    may be corrupted while fetching or decoding, however, we always want the same
    number of steps per epoch.

    The wrapped dataset must be iterable and expose a ``num_videos`` attribute,
    which is reported as the length of this wrapper. Items are served in
    iteration order; the index passed to ``__getitem__`` is ignored.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = self._new_iter()

    def _new_iter(self):
        """Return an iterator chaining the wrapped dataset's iterator twice."""
        return itertools.chain.from_iterable(
            itertools.repeat(iter(self.dataset), 2)
        )

    def __getitem__(self, index):
        """Return the next sample from the wrapped dataset (``index`` is unused).

        When the current pass is exhausted a new one is started, so a
        StopIteration never escapes and silently cuts a DataLoader loop short.

        Raises:
            IndexError: if the wrapped dataset yields no samples at all.
        """
        try:
            return next(self.dataset_iter)
        except StopIteration:
            # Current pass is used up: start over from the beginning.
            self.dataset_iter = self._new_iter()
        try:
            return next(self.dataset_iter)
        except StopIteration:
            raise IndexError("wrapped dataset yielded no samples") from None

    def __len__(self):
        return self.dataset.num_videos

In [32]:
# Both loaders use the same batching options; incomplete final batches are dropped.
_loader_kwargs = dict(
    batch_size=args.batch_size,
    drop_last=True,
    num_workers=args.num_workers,
)
train_loader = DataLoader(LimitDataset(train_set), **_loader_kwargs)
val_loader = DataLoader(LimitDataset(val_set), **_loader_kwargs)


In [33]:
# Display the loader and dataset objects as a quick sanity check.
train_loader, train_set

(<torch.utils.data.dataloader.DataLoader at 0x7f6ae82d4a30>,
 <pytorchvideo.data.labeled_video_dataset.LabeledVideoDataset at 0x7f6b3877b610>)

In [34]:
# Compare the number of batches with num_videos / batch_size
# (the loader was built with drop_last=True, so a partial last batch is not counted).
len(train_loader), train_set.num_videos, train_set.num_videos / args.batch_size

(1192, 9537, 1192.125)

In [35]:
# Take one item directly from the dataset (not from the DataLoader), so despite
# its name `batch` holds a single sample dict here.
batch = next(iter(train_set))
print(batch.keys())
print(batch['video'].shape)
# Peek at a 5x5 corner of the first slice along the first two axes.
print(batch['video'][0, 0, :5, :5])

dict_keys(['video', 'video_name', 'video_index', 'clip_index', 'aug_index', 'label'])
torch.Size([3, 16, 224, 224])
tensor([[-2., -2., -2., -2., -2.],
        [-2., -2., -2., -2., -2.],
        [-2., -2., -2., -2., -2.],
        [-2., -2., -2., -2., -2.],
        [-2., -2., -2., -2., -2.]])


In [38]:
# Print the label array of each batch to inspect how labels are mixed;
# stop once the batch index exceeds 100.
for i, batch in enumerate(train_loader):
    print(batch['label'].cpu().numpy())
    if i > 100:
        break

[91 73  1 50 35 61 93 20]
[52 95 12 19 17 83 41 71]
[10 44  7 72 31 33 87  0]
[78 90 87 73 77 96 33 89]
[57 28 51 72 23 15 65 94]
[30 32 83  9 74 87 34 79]
[85 38 55 86 56 84  2 92]
[ 3 52 56 39 31 80 58 32]
[69  8 76 15 32 61  3 52]
[39 68 94 96 94 88  0 58]
[82 68 95 51 22 97 67 74]
[51  6  3 20 87 16 51 59]
[68 77 48 72 14 95 98 81]
[ 3 32 50 98 16  9 87 34]
[11 43 40 97 98 40 15 48]
[18 31 45 88 97 69 29 83]
[ 0 39 77 36 77 96 26 31]
[ 55  94  94  30 100  80  28  14]
[69 18 63 83  6 82 45 98]
[16 69 83 42 93 23 90 37]
[ 2 70 41 17 47  2  2 60]
[ 28 100  58   6  43  82  61  89]
[85 64 88 58 67 22  8 10]
[41 60 68 16 86 31 55  6]
[62 48 82 17 66 94 85 77]
[41  2 90 38 50 49  0 53]
[34 41 96 64 12 36 11 40]
[83 73 93  7 98 46 37 93]
[72 98 12 52  6 63 87 22]
[48 60 77 99 25 48 26 43]
[48 30 50  3 81 85 55 63]
[90 77 59 89 46 49  5 34]
[87 98 33 60 24 88 88 43]
[86 13 30 77 15 39 42 17]
[98  9  5 40 19 68 93 77]
[11 30 20 86 36  4 91 22]
[79 92 26 48 73 19 93 34]
[23 14 59  4 27 83 28 

transformの定義．
- UCF101を読み込むとuint8なので，255で割ってfloatにする．
- torchvisionのUCF101データセットは(T, H, W, C)の形式．しかしpytorchvideoのx3dの入力形式は(B, C, T, H, W)らしいので，それに合わせる．
- X3D-Mを想定して，画像を224x224にリサイズする．transforms.Resize()はまだ試していないが，この形式ができるかどうか不明（torchvisionのtransformは画像しか扱わないのでムリだと思う）

データセットはimage, audio, labelの三組を返すが，UCF101には音声がない動画もあり，そのまま使うとdataloaderがバッチにできないというエラーが出てしまう（audioの次元数がサンプルによって異なるため）．そこでcollateでaudioを取り除く．

メタデータの準備．UCF101の全動画をスキャンして，FPSなどの情報を取得するらしい．かなり時間がかかる．
それを保存して再利用（毎回計算し直すと時間の無駄）．
コードを見たところ，foldやtrainには無関係で，fpcとsbcにだけ依存するらしいので，それをファイル名にして保存する．

UCF101には3つのスプリットがあるので，foldでそれを指定（多分）

データローダーの作成．collateをここで指定．

In [14]:
# Use the first CUDA device when CUDA is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# device = "cpu"

pytorchvideoのx3dモデルを作成．
webマニュアルにはないが，コードをみると，クリップ長とサイズが指定できる．
X3Dは数種類あるが，ここではX3D-Mに合わせた数字を指定（コードのコメントに書いてある）

In [15]:
# # X3D-M
# # https://github.com/facebookresearch/pytorchvideo/blob/master/pytorchvideo/models/x3d.py#L601
# model = x3d.create_x3d(
#     input_clip_length=16,
#     input_crop_size=224,
#     depth_factor=2.2,
#     model_num_class=101
# ).to(device)

num_classes = 101

# Start from the pretrained X3D-M published on torch.hub.
model = torch.hub.load('facebookresearch/pytorchvideo', 'x3d_m', pretrained=True)

# Freeze every pretrained parameter.
model.requires_grad_(False)

# Swap the projection layer of blocks[5] (presumably the classification head --
# confirm) for a new linear layer with `num_classes` outputs. The new layer's
# parameters are created with requires_grad=True, so they remain trainable.
_head = model.blocks[5]
_head.proj = nn.Linear(_head.proj.in_features, num_classes)
model = model.to(device)



Using cache found in /home/tamaki/.cache/torch/hub/facebookresearch_pytorchvideo_master


ランダムなデータを流し込んで出力されるかを確認する

In [16]:
# Random dummy input of shape (2, 3, 16, 224, 224), moved to the target device.
# Presumably laid out as (B, C, T, H, W), as the notes in this notebook say x3d expects.
data = torch.randn(2, 3, 16, 224, 224).to(device)

In [17]:
# Run the dummy input through the model to confirm that it produces an output.
model(data)

tensor([[0.0102, 0.0096, 0.0089, 0.0099, 0.0111, 0.0104, 0.0098, 0.0106, 0.0101,
         0.0103, 0.0099, 0.0091, 0.0106, 0.0081, 0.0100, 0.0108, 0.0094, 0.0089,
         0.0109, 0.0097, 0.0096, 0.0108, 0.0096, 0.0099, 0.0095, 0.0100, 0.0096,
         0.0099, 0.0108, 0.0101, 0.0099, 0.0098, 0.0088, 0.0096, 0.0104, 0.0101,
         0.0101, 0.0092, 0.0091, 0.0106, 0.0107, 0.0099, 0.0106, 0.0111, 0.0093,
         0.0095, 0.0103, 0.0092, 0.0093, 0.0096, 0.0098, 0.0108, 0.0101, 0.0109,
         0.0109, 0.0100, 0.0099, 0.0106, 0.0109, 0.0101, 0.0097, 0.0090, 0.0094,
         0.0103, 0.0096, 0.0104, 0.0099, 0.0101, 0.0104, 0.0102, 0.0107, 0.0087,
         0.0101, 0.0103, 0.0088, 0.0110, 0.0096, 0.0095, 0.0109, 0.0085, 0.0091,
         0.0091, 0.0103, 0.0102, 0.0097, 0.0088, 0.0097, 0.0107, 0.0094, 0.0090,
         0.0096, 0.0101, 0.0103, 0.0103, 0.0095, 0.0102, 0.0096, 0.0103, 0.0090,
         0.0096, 0.0092],
        [0.0099, 0.0105, 0.0091, 0.0109, 0.0106, 0.0097, 0.0098, 0.0098, 0.0103,
  

summaryで中身を確認

In [18]:
# Print a layer-by-layer summary of the model for an input of size
# (4, 3, 16, 224, 224), showing input/output sizes and variable names.
torchinfo.summary(
    model,
    (4, 3, 16, 224, 224),
    depth=4,
    col_names=["input_size",
               "output_size"],
    row_settings=("var_names",)
)

Layer (type (var_name))                                      Input Shape               Output Shape
Net                                                          --                        --
├─ModuleList (blocks)                                        --                        --
│    └─ResNetBasicStem (0)                                   [4, 3, 16, 224, 224]      [4, 24, 16, 112, 112]
│    │    └─Conv2plus1d (conv)                               [4, 3, 16, 224, 224]      [4, 24, 16, 112, 112]
│    │    │    └─Conv3d (conv_t)                             [4, 3, 16, 224, 224]      [4, 24, 16, 112, 112]
│    │    │    └─Conv3d (conv_xy)                            [4, 24, 16, 112, 112]     [4, 24, 16, 112, 112]
│    │    └─BatchNorm3d (norm)                               [4, 24, 16, 112, 112]     [4, 24, 16, 112, 112]
│    │    └─ReLU (activation)                                [4, 24, 16, 112, 112]     [4, 24, 16, 112, 112]
│    └─ResStage (1)                                          [4, 2

便利関数を定義

In [19]:
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Imported from https://github.com/pytorch/examples/blob/master/imagenet/main.py#L247-L262
    https://github.com/machine-perception-robotics-group/attention_branch_network/blob/ced1d97303792ac6d56442571d71bb0572b3efd8/utils/misc.py#L59

    Attributes:
        val: the most recent value passed to ``update``.
        sum: weighted sum of all values seen so far.
        count: total weight (number of samples) seen so far.
        avg: ``sum / count`` (0 while ``count`` is 0).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples.

        Args:
            val: a Python number or a single-element tensor.
            n: number of samples that ``val`` summarizes.
        """
        # isinstance also covers Tensor subclasses (e.g. nn.Parameter), which
        # an exact type comparison would miss.
        if isinstance(val, torch.Tensor):
            val = val.item()
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against dividing by zero when nothing has been counted yet.
        self.avg = self.sum / self.count if self.count else 0

def top1(outputs, targets):
    """Return the fraction of rows in `outputs` whose arg-max along dim 1 equals `targets`."""
    predicted = outputs.argmax(dim=1)
    num_correct = (predicted == targets).sum().item()
    return num_correct / outputs.size(0)

In [20]:
# from numpy.random import randn

# train_loss = AverageMeter()
# for i in list(range(100)):
#     train_loss.update(randn())
#     print(train_loss.avg, 
#           train_loss.count, 
#           train_loss.sum,
#           train_loss.val)

torchvisionのvideo.pyで，ワーニングが多数出るのでそれを抑制．

In [21]:
# import warnings
# warnings.filterwarnings("ignore", category=UserWarning,
#                                    module='torchvision')

In [22]:
# SGD with momentum and weight decay over everything returned by
# model.parameters(), and cross-entropy as the classification loss.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

In [23]:
# Training loop: for each epoch, iterate over train_loader once, update the
# model, and show running-average / current-batch loss and top-1 accuracy on
# the progress bar.
num_epochs = 5

model = model.to(device)

with tqdm(range(num_epochs)) as pbar_epoch:
    for epoch in pbar_epoch:
        pbar_epoch.set_description("[Epoch %d]" % (epoch))

        with tqdm(enumerate(train_loader),
                  total=len(train_loader),
                  leave=True) as pbar_loss:
            # The description never changes, so set it once per epoch.
            pbar_loss.set_description("[train]")

            # Fresh meters for every epoch.
            train_loss = AverageMeter()
            train_acc = AverageMeter()
            model.train()

            for batch_idx, batch in pbar_loss:
                inputs, targets = batch['video'].to(device), batch['label'].to(device)
                bs = inputs.size(0)  # current batch size, used to weight the meters

                # Debug aid: show the labels in each batch to check how samples are mixed.
                print(targets.cpu().numpy())

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                train_loss.update(loss.item(), bs)
                train_acc.update(top1(outputs, targets), bs)

                # First pair: running average over the epoch; second pair: this batch.
                pbar_loss.set_postfix_str(
                    ' | avg: loss={:6.04f} , top1={:6.04f}'
                    ' | cur: loss={:6.04f} , top1={:6.04f}'
                    ''.format(
                    train_loss.avg, train_acc.avg,
                    train_loss.val, train_acc.val,
                ))



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1192.0), HTML(value='')))

[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 18 18 18]
[18 18 18 18 18 64 64 64]
[18 18 18 18 18 36 36 36]
[18 18 18 18 18 28 28 28]
[18 18 18 18 18 36 36 36]
[18 18 18 18 18  6  6  6]
[18 18 18 18 18  3  3  3]
[18 18 18 18 18 67 67 67]
[18 18 18 18 18  6  6  6]
[18 18 18 18 18 73 73 73]
[18 18 18 18 18 95 95 95]
[18 18 18 18 18 36 36 36]
[18 18 18 18 18  1  1  1]
[18 18 18 18 18 81 81 81]
[18 18 18 18 18 30 30 30]
[18 18 18 18

KeyboardInterrupt: 