# Welcome to torchvision's new video API

Here, we're going to examine the capabilities of the new video API, together with examples of how to build datasets on top of it.

### Table of contents
1. Introduction: building a new video object and examining the properties
2. Building a sample `read_video` function
3. Building an example randomly sampled dataset (can be applied to e.g. Kinetics400)

## 1. Introduction: building a new video object and examining the properties

First, we select a video to test the object on. For the sake of argument, we're using one from the Kinetics400 dataset. To create the video object, we need to define the path and the stream we want to use. See the inline comments for a description.

In [1]:
import torch, torchvision
"""
chosen video statistics:
WUzgd7C1pWA.mp4
  - source: kinetics-400
  - video: H-264 - MPEG-4 AVC (part 10) (avc1)
    - fps: 29.97
  - audio: MPEG AAC audio (mp4a)
    - sample rate: 48K Hz
"""
video_path = "../../test/assets/videos/WUzgd7C1pWA.mp4"

"""
Streams are defined in a similar fashion to torch devices. We encode them as strings of the
form `stream_type:stream_id`, where stream_type is a string and stream_id is a long int.

The constructor also accepts a stream_type alone, in which case the stream is auto-discovered.
"""
stream = "video"



video = torch.classes.torchvision.Video(video_path, stream)

First, let's get the metadata for our particular video:

In [3]:
video.get_metadata()

{'video': {'duration': [10.9109], 'fps': [29.97002997002997]},
 'audio': {'duration': [10.9], 'framerate': [48000.0]}}

Here we can see that the video has two streams: a video stream and an audio stream.
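
The metadata dictionary can also be used to enumerate stream specifiers. The following is a minimal sketch, assuming (as the output above suggests) that each property list holds one entry per stream of that type. `list_streams` is a hypothetical helper and not part of torchvision.

```python
# Metadata as printed above; normally this comes from video.get_metadata().
metadata = {
    "video": {"duration": [10.9109], "fps": [29.97002997002997]},
    "audio": {"duration": [10.9], "framerate": [48000.0]},
}


def list_streams(metadata):
    """Return specifiers such as "video:0" for every stream in the metadata.

    Assumption: each property list has one entry per stream of that type.
    """
    specs = []
    for stream_type, props in metadata.items():
        # the longest property list tells us how many streams of this type exist
        n_streams = max((len(values) for values in props.values()), default=0)
        specs.extend("{}:{}".format(stream_type, i) for i in range(n_streams))
    return specs


print(list_streams(metadata))  # ['video:0', 'audio:0']
```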

Let's read all the frames from the video stream.

In [7]:
# first we select the video stream 
video.set_current_stream("video:0")

frames = []  # we are going to save the frames here.
frame, pts = video.next()
# note that next will return an empty frame at the end of the video stream
while frame.numel() != 0:
    frames.append(frame)
    frame, pts = video.next()
    
print("Total number of frames: ", len(frames))
approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())

Total number of frames:  327
We can expect approx:  327.0
Tensor size:  torch.Size([3, 256, 340])


Note that selecting video stream zero is equivalent to selecting the video stream automatically, i.e. `video:0` and `video` give the same results in this case.
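
To make the `stream_type:stream_id` convention concrete, here is a small sketch of how such a string could be split apart. `parse_stream` is a hypothetical helper written for illustration; it is not the parser used inside torchvision.

```python
def parse_stream(spec):
    """Split "stream_type:stream_id" into (stream_type, stream_id).

    stream_id is None when only the type is given (auto-discovery).
    """
    stream_type, sep, stream_id = spec.partition(":")
    if not stream_type:
        raise ValueError("missing stream type in {!r}".format(spec))
    return stream_type, (int(stream_id) if sep else None)


print(parse_stream("video:0"))  # ('video', 0)
print(parse_stream("audio"))    # ('audio', None)
```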

Let's try this for the audio stream:

In [10]:
video.set_current_stream("audio")

frames = []  # we are going to save the frames here.
frame, pts = video.next()
# note that next will return an empty frame at the end of the audio stream
while frame.numel() != 0:
    frames.append(frame)
    frame, pts = video.next()
    
print("Total number of frames: ", len(frames))
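# note: this estimate is computed from the video stream's duration and fps,
# so it predicts video frames and does not match the number of audio frames
approx_nf = video.get_metadata()['video']['duration'][0] * video.get_metadata()['video']['fps'][0]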
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())

Total number of frames:  511
We can expect approx:  327.0
Tensor size:  torch.Size([1024, 1])
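
The estimate printed above is derived from the video stream, which is why it does not match the audio frame count. A rough audio-specific estimate can be sketched from the audio metadata instead, assuming every audio frame holds 1024 samples like the first one does (an assumption based only on the tensor size shown above):

```python
# Values taken from the metadata and tensor size printed above.
audio_duration = 10.9      # seconds
audio_framerate = 48000.0  # samples per second
samples_per_frame = 1024   # assumed constant; first dimension of the frame tensor

# total number of samples divided by the samples carried in each frame
approx_audio_frames = audio_duration * audio_framerate / samples_per_frame
print("We can expect approx: ", approx_audio_frames)
```

Under that assumption the estimate comes out at roughly 511, in line with the number of audio frames actually read.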


But what if we only want to read a certain time segment of the video?

That can be done easily by combining our `seek` function with the fact that each call to `next` returns the presentation timestamp of the returned frame in seconds.

For example, if we wanted to read the video from the second to the fifth second:

In [12]:
video.set_current_stream("video")

frames = []  # we are going to save the frames here.

# we seek to the second second of the video
# the following call to next returns the first frame after that point
video.seek(2)
frame, pts = video.next()
# note that we add an exit condition
while pts < 5 and frame.numel() != 0:
    frames.append(frame)
    frame, pts = video.next()
    
print("Total number of frames: ", len(frames))
approx_nf = (5-2) * video.get_metadata()['video']['fps'][0]
print("We can expect approx: ", approx_nf)
print("Tensor size: ", frames[0].size())

Total number of frames:  90
We can expect approx:  89.91008991008991
Tensor size:  torch.Size([3, 256, 340])
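
The seek-then-iterate pattern above can be wrapped in a reusable generator. The sketch below runs without a video file by using `FakeReader`, a hypothetical stand-in that only mimics the `seek`/`next` interface at a constant frame rate; both `FakeReader` and `read_segment` are illustrative names, not part of torchvision.

```python
import math


class FakeReader:
    """Stand-in for the video object: constant frame rate, frames are 1-element lists."""

    def __init__(self, duration, fps):
        self.fps = fps
        self.num_frames = int(duration * fps)
        self.index = 0

    def seek(self, ts):
        # position on the first frame whose timestamp is >= ts
        self.index = math.ceil(ts * self.fps)

    def next(self):
        if self.index >= self.num_frames:
            return [], 0.0  # an empty frame marks the end of the stream
        frame, pts = [self.index], self.index / self.fps
        self.index += 1
        return frame, pts


def read_segment(reader, start, end, is_empty):
    """Seek to start, then yield (frame, pts) pairs while pts < end."""
    reader.seek(start)
    frame, pts = reader.next()
    while not is_empty(frame) and pts < end:
        yield frame, pts
        frame, pts = reader.next()


segment = list(read_segment(FakeReader(duration=10, fps=30), 2, 5, lambda f: len(f) == 0))
print(len(segment))  # 90
```

With the real video object, the same generator should work by passing `is_empty=lambda f: f.numel() == 0`, mirroring the exit condition used in the cell above.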


## 2. Building a sample `read_video` function

We can use the methods above to build a read-video function that follows the same API as the existing `read_video` function.

In [42]:
def example_read_video(video_object, start=0, end=None, read_video=True, read_audio=True):

    if end is None:
        end = float("inf")
    if end < start:
        raise ValueError(
            "end time should be larger than start time, got "
            "start time={} and end time={}".format(start, end)
        )
    
    video_frames = torch.empty(0)
    video_pts = []
    if read_video:
        video_object.set_current_stream("video")
        video_object.seek(start)
        frames = []
        t, pts = video_object.next()
        while t.numel() > 0 and (pts >= start and pts <= end):
            frames.append(t)
            video_pts.append(pts)
            t, pts = video_object.next()
        if len(frames) > 0:
            video_frames = torch.stack(frames, 0)

    audio_frames = torch.empty(0)
    audio_pts = []
    if read_audio:
        video_object.set_current_stream("audio")
        video_object.seek(start)
        frames = []
        t, pts = video_object.next()
        while t.numel() > 0 and (pts >= start and pts <= end):
            frames.append(t)
            audio_pts.append(pts)
            t, pts = video_object.next()
        if len(frames) > 0:
            audio_frames = torch.cat(frames, 1)

    return video_frames, audio_frames, (video_pts, audio_pts), video_object.get_metadata()

In [43]:
vf, af, info, meta = example_read_video(video)
# total number of frames should be 327
vf.size()

torch.Size([327, 3, 256, 340])

In [44]:
# you can also get the sequence of audio frames
af.size()

torch.Size([1024, 511])

## 3. Building an example randomly sampled dataset (can be applied to the training dataset of Kinetics400)

Cool, so now we can use the same principle to build the sample dataset. We suggest trying out an iterable dataset for this purpose.

Here, we are going to build

a. an example dataset that reads a randomly selected clip of `clip_len` consecutive frames from a video (16 frames by default)

In [48]:
# first, housekeeping and utilities
import os
import random

import torch
from torchvision.datasets.folder import make_dataset
from torchvision import transforms as t

def _find_classes(dir):
    classes = [d.name for d in os.scandir(dir) if d.is_dir()]
    classes.sort()
    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx

def get_samples(root, extensions=(".mp4", ".avi")):
    _, class_to_idx = _find_classes(root)
    return make_dataset(root, class_to_idx, extensions=extensions)
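
To see what the class-to-index mapping looks like, here is a self-contained sketch that builds a throwaway folder tree and applies the same directory scan as `_find_classes`. The class names are hypothetical and exist only for this demonstration.

```python
import os
import tempfile


def find_classes(root):
    """Map each sub-directory name of root to an integer index (sorted by name)."""
    classes = sorted(d.name for d in os.scandir(root) if d.is_dir())
    return {cls_name: i for i, cls_name in enumerate(classes)}


with tempfile.TemporaryDirectory() as root:
    # hypothetical class folders, created only for this demonstration
    for name in ("juggling", "archery"):
        os.mkdir(os.path.join(root, name))
    class_to_idx = find_classes(root)

print(class_to_idx)  # {'archery': 0, 'juggling': 1}
```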

We are going to define the dataset and some basic arguments. We assume the structure of the `DatasetFolder`, and add the following parameters:

1. frame transform: with this API, we can choose to apply transforms on every frame of the video
2. video transform: equally, we can also apply a transform to a 4D tensor
3. length of the clip: do we want a single or multiple frames?

Note that we also add an `epoch_size` parameter, as using the `IterableDataset` class allows us to naturally oversample clips or images from each video if needed.
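
The oversampling idea can be illustrated without torch: an epoch is simply a fixed number of random draws, with replacement, from the list of samples. This is a minimal sketch under that reading; `sample_epoch` and the sample list are hypothetical.

```python
import random


def sample_epoch(samples, epoch_size=None, seed=None):
    """Yield epoch_size randomly chosen samples (drawn with replacement)."""
    rng = random.Random(seed)
    if epoch_size is None:
        # default: one draw per available sample
        epoch_size = len(samples)
    for _ in range(epoch_size):
        yield rng.choice(samples)


# hypothetical (path, target) pairs
samples = [("a.mp4", 0), ("b.mp4", 1), ("c.mp4", 1)]
epoch = list(sample_epoch(samples, epoch_size=10, seed=0))
print(len(epoch))  # 10
```

Because draws are made with replacement, an epoch can be longer than the number of videos, and a given video may appear several times or not at all.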

In [52]:
class RandomDataset(torch.utils.data.IterableDataset):
    def __init__(self, root, epoch_size=None, frame_transform=None, video_transform=None, clip_len=16):
        super().__init__()
        
        self.samples = get_samples(root)
         
        # allow for temporal jittering
        if epoch_size is None:
            epoch_size = len(self.samples)
        self.epoch_size = epoch_size
        
        self.clip_len = clip_len  # length of a clip in frames
        self.frame_transform = frame_transform  # transform for every frame individually
        self.video_transform = video_transform # transform on a video sequence

    def __iter__(self):
        for i in range(self.epoch_size):
            # get random sample
            path, target = random.choice(self.samples)
            # get video object
            vid = torch.classes.torchvision.Video(path, "video")
            metadata = vid.get_metadata()
            video_frames = [] # video frame buffer 
            # seek and return frames
            
            max_seek = metadata["video"]['duration'][0] - (self.clip_len / metadata["video"]['fps'][0])
            start = random.uniform(0., max_seek)
            vid.seek(start)
            while len(video_frames) < self.clip_len:
                frame, current_pts = vid.next()
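                # apply the per-frame transform only if one was provided
                if self.frame_transform is not None:
                    frame = self.frame_transform(frame)
                video_frames.append(frame)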
            # stack it into a tensor
            video = torch.stack(video_frames, 0)
            if self.video_transform:
                video = self.video_transform(video)
            output = {
                'path': path,
                'video': video,
                'target': target,
                'start': start,
                'end': current_pts}
            yield output
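
The start-time sampling inside `__iter__` can be isolated as a small helper. This sketch follows the same formula and, as an addition that is not in the class above, clamps the upper bound at zero so that videos shorter than one clip do not produce a negative seek range. `sample_clip_start` is a hypothetical name.

```python
import random


def sample_clip_start(duration, fps, clip_len, rng=random):
    """Pick a start time in seconds that leaves room for clip_len frames."""
    max_seek = duration - clip_len / fps
    # guard (an addition to the class above): clamp for videos shorter than one clip
    max_seek = max(max_seek, 0.0)
    return rng.uniform(0.0, max_seek)


rng = random.Random(0)
start = sample_clip_start(duration=10.9109, fps=29.97002997002997, clip_len=16, rng=rng)
print(0.0 <= start <= 10.9109 - 16 / 29.97002997002997)  # True
print(sample_clip_start(duration=0.2, fps=30.0, clip_len=16))  # 0.0
```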

Given a path to videos in a folder structure, i.e.:
```
dataset:
    - class 1:
        file 0
        file 1
        ...
    - class 2:
        file 0
        file 1
        ...
    - ...
```
we can generate a dataloader and test the dataset.
            

In [56]:
from torchvision import transforms as t
transforms = [t.ToPILImage(), t.Resize((112, 112), interpolation=2), t.ToTensor()]
frame_transform = t.Compose(transforms)

ds = RandomDataset("/home/bjuncek/work/video_reader_benchmark/dataset_files", epoch_size=None, frame_transform=frame_transform)

In [57]:
from torch.utils.data import DataLoader
loader = DataLoader(ds, batch_size=12)
d = {"video":[], 'start':[], 'end':[]}
for b in loader:
    for i in range(len(b['path'])):
        d['video'].append(b['path'][i])
        d['start'].append(b['start'][i].item())
        d['end'].append(b['end'][i].item())
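
Once the loop has filled `d`, one way to inspect the sampling is to count how many clips came from each video. The dictionary contents here are placeholder values chosen for illustration, not results from the dataset above.

```python
from collections import Counter

# placeholder contents; in practice `d` is filled by the loop above
d = {
    "video": ["class_a/v0.mp4", "class_b/v1.mp4", "class_a/v0.mp4"],
    "start": [1.2, 0.4, 5.0],
    "end": [1.7, 0.9, 5.5],
}

clips_per_video = Counter(d["video"])
print("clips:", len(d["video"]), "unique videos:", len(clips_per_video))
print(clips_per_video.most_common(1))  # [('class_a/v0.mp4', 2)]
```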