In [33]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
import random
import warnings
import os
warnings.filterwarnings("ignore")


In [44]:
class UCF101(datasets.UCF101):
    def __init__(self, temporal_transform, spatial_transform, norm_method, detection_file_path,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        # detection data with YOLO v4
        # https://github.com/AlexeyAB/darknet
        self.detection_data = pd.read_csv(detection_file_path,
                                          sep=' ', 
                                          names=["object_class", "x_center", "y_center", "width", "height"])
        
        self.kwargs = kwargs
        self.args = args
        
        self.temporal_transform = temporal_transform
        self.spatial_transform = spatial_transform
        self.norm_method = norm_method
        
    def fetch_video_data(self, video_idx):
        video_path, label = self.samples[self.indices[video_idx]]
        video, audio, info = torchvision.io.read_video(video_path)
        clip_pts = list(range(len(video)))

        return video, label, clip_pts, video_path
    
    def fetch_clip_data(self, clip_idx):
        clip, audio, info, video_idx, clip_pts = self.video_clips.get_clip(
            clip_idx)  # clip_pts-'frame index'
        video_path, label = self.samples[self.indices[video_idx]]

        return clip, label, clip_pts, video_path
    
    def apply_spatial_transform(self, video, randomize=True, normalize=True):
        if randomize:
            self.spatial_transform.randomize_parameters(video)
        video = self.spatial_transform(video)
        if normalize:
            video = self.norm_method(video)

        return video

    def __len__(self):
        if self.train:
            return len(self.indices)
        else:
            return self.video_clips.num_clips()

    def __getitem__(self, idx):
        if self.train:
            video, label, clip_pts, video_path = self.fetch_video_data(idx)
        else:
            video, label, clip_pts, video_path = self.fetch_clip_data(idx)
        
        # spatial random crop
        if self.temporal_transform is not None:
            clip_pts = self.temporal_transform(clip_pts)
        # convert List -> Tensor
        clip_pts = torch.as_tensor(clip_pts)
        
        # retrieve a subVideoClip
        video = video[clip_pts]
        
        video = Resize3D(size=(144,144),
                         interpolation=Image.BILINEAR)(video)
        
        query = os.path.splitext(video_path[len(self.root.rstrip('/'))+1:])[0]
        query = [ os.path.join(query, "thumb{:04d}.txt".format(i)) for i in clip_pts ]
        
        detection_res = self.detection_data.loc[query]
        import ipdb
        ipdb.set_trace()
        
        if self.spatial_transform is not None:
            video = self.apply_spatial_transform(
                        video, randomize=True, normalize=True)
        
    
        return video, label

In [45]:
from utils.transforms import (
    Compose, RandomCrop3D, Resize3D, CenterCrop3D, RandomHorizontalFlip3D, ToTensor3D, Normalize3D,
    TemporalRandomCrop, LoopPadding)

In [46]:
spatial_transform = {
        "train": Compose(
            [
                RandomCrop3D(transform2D=transforms.RandomCrop(
                    size=(112, 112))
                ),
                RandomHorizontalFlip3D(),
                ToTensor3D()
            ]
        ),
        "test": Compose(
            [
                CenterCrop3D((112, 112)),
                ToTensor3D()
            ]
        )
}

temporal_transform = {
    "train": TemporalRandomCrop(size=16),
    "test": LoopPadding(16)
}

norm_method = Normalize3D(
        mean=[
    0.43216,
    0.394666,
    0.37645
  ],
        std=[
    0.22803,
    0.22145,
    0.216989
  ]
    )

In [47]:
ds = UCF101(root='/data/torch_data/UCF-101/video', 
            annotation_path='/data/torch_data/UCF-101/ucfTrainTestlist', 
            detection_file_path='/data/torch_data/UCF-101/detection_yolov4.txt',
            frames_per_clip=16, num_workers=16, train=True,
            temporal_transform=temporal_transform["train"],
            spatial_transform=spatial_transform["train"],
            norm_method=norm_method)

HBox(children=(IntProgress(value=0, max=833), HTML(value='')))




In [48]:
video = ds[0][0]

> [0;32m<ipython-input-44-54654d27f6f3>[0m(72)[0;36m__getitem__[0;34m()[0m
[0;32m     71 [0;31m[0;34m[0m[0m
[0m[0;32m---> 72 [0;31m        [0;32mif[0m [0mself[0m[0;34m.[0m[0mspatial_transform[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     73 [0;31m            video = self.apply_spatial_transform(
[0m
ipdb> detection_res
                                                    object_class  x_center  \
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...           0.0    0.5031   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...           0.0    0.5037   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...           NaN       NaN   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...           0.0    0.4989   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...          71.0    0.5664   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...           0.0    0.4994   
ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01/thumb00...    

BdbQuit: 

In [21]:
ds.detection_data

Unnamed: 0,object_class,x_center,y_center,width,height
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01/thumb0001.txt,62,0.1171,0.2951,0.2391,0.5967
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01/thumb0001.txt,0,0.6659,0.5017,0.6467,1.0919
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01/thumb0002.txt,62,0.1163,0.2936,0.2351,0.5979
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01/thumb0002.txt,0,0.6668,0.5039,0.6828,1.1075
ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c01/thumb0003.txt,62,0.1158,0.2963,0.2343,0.6037
...,...,...,...,...,...
YoYo/v_YoYo_g25_c05/thumb0192.txt,56,0.9642,0.6388,0.0700,0.3877
YoYo/v_YoYo_g25_c05/thumb0193.txt,56,0.9628,0.6228,0.0726,0.3638
YoYo/v_YoYo_g25_c05/thumb0195.txt,56,0.9625,0.6223,0.0734,0.3658
YoYo/v_YoYo_g25_c05/thumb0196.txt,56,0.9626,0.6228,0.0732,0.3659


# Load UCF & HMDB

In [None]:
ucf_ds = datasets.UCF101(root='/data/torch_data/UCF-101/video', 
                         annotation_path='/data/torch_data/UCF-101/ucfTrainTestlist', 
                         frames_per_clip=16, num_workers=16)

In [None]:
len(ucf_ds) / 80

In [None]:
13320 / 16

In [None]:
len(ucf_ds)

In [None]:
torch.__version__