In [1]:
from time import perf_counter
from typing import List, Dict, Optional, Union, Literal, Tuple

import cv2
import numpy as np
import torch
from tqdm import tqdm
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali import pipeline_def
from nvidia.dali.plugin.pytorch import LastBatchPolicy
from nvidia.dali.plugin.pytorch import DALIGenericIterator

# ImageNet per-channel normalization statistics used by torchvision's pretrained
# models (RGB order per the torchvision docs). Values are in [0, 1] units, so they are
# applied after video_pipe rescales pixels from [0, 255] to [0, 1].
# see https://pytorch.org/vision/stable/models.html
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

In [2]:
@pipeline_def
def video_pipe(
    filenames: Union[List[str], str],
    resize_dims: Optional[List[int]] = None,
    random_shuffle: bool = False,
    seed: int = 123456,
    sequence_length: int = 16,
    pad_sequences: bool = True,
    initial_fill: int = 16,
    normalization_mean: List[float] = _IMAGENET_MEAN,
    normalization_std: List[float] = _IMAGENET_STD,
    device: str = "gpu",
    name: str = "reader",
    step: int = 1,
    pad_last_batch: bool = False,
    # arguments consumed by decorator:
    # batch_size,
    # num_threads,
    # device_id
) -> tuple:
    """Generic video reader pipeline that loads, optionally resizes, and normalizes frames.

    Args:
        filenames: path, or list of paths, of video files to feed through the
            pipeline
        resize_dims: [height, width] to resize raw frames to; frames are not
            resized when this is None or empty
        random_shuffle: True to read sequences in random order; False to read
            them sequentially
        seed: random seed used when `random_shuffle` is True
        sequence_length: number of frames to load per sequence
        pad_sequences: allows creation of incomplete sequences if there is an
            insufficient number of frames at the very end of the video
        initial_fill: size of the buffer that is used for random shuffling
        normalization_mean: per-channel mean, in [0, 1] units, subtracted from
            each channel
        normalization_std: per-channel standard deviation, in [0, 1] units; each
            channel is DIVIDED by it after the mean is subtracted
        device: "cpu" | "gpu"
        name: name of the DALI reader operator; pass the same string as
            `reader_name` to DALIGenericIterator so it can query the reader
        step: number of frames between the starts of consecutive sequences
        pad_last_batch: forwarded to the DALI video reader; when True the reader
            pads the final partial batch (per the DALI docs, by repeating the
            last sample)

    Returns:
        Through the `pipeline_def` decorator (which also takes `batch_size`,
        `num_threads` and `device_id`), a DALI pipeline to be fed to
        DALIGenericIterator. Its single output is the normalized float frames in
        "FCHW" layout.

    """
    # normalized=False with dtype=FLOAT yields float pixels in [0, 255]
    video = fn.readers.video(
        device=device,
        filenames=filenames,
        random_shuffle=random_shuffle,
        seed=seed,
        sequence_length=sequence_length,
        step=step,
        pad_sequences=pad_sequences,
        initial_fill=initial_fill,
        normalized=False,
        name=name,
        dtype=types.DALIDataType.FLOAT,
        pad_last_batch=pad_last_batch,  # Important for context loaders
        file_list_include_preceding_frame=True,  # to get rid of dali warnings
    )
    if resize_dims:
        video = fn.resize(video, size=resize_dims)
    # video pixel range is [0, 255]; transform it to [0, 1].
    # happens naturally in the torchvision transform to tensor.
    video = video / 255.0
    # computes (x - mean) / std per channel and permutes the output to FCHW
    transform = fn.crop_mirror_normalize(
        video,
        output_layout="FCHW",
        mean=normalization_mean,
        std=normalization_std,
    )
    return transform


def count_frames(video_list: Union[List[str], str]) -> int:
    """Count the frames in a video, or the total number of frames in a list of videos.

    The per-video count is OpenCV's ``CAP_PROP_FRAME_COUNT`` property, which the
    backend reads from the container and which may be approximate for some formats.

    Args:
        video_list: path to a single video file, or a list of paths.

    Returns:
        total number of frames over all videos (0 for an empty list).

    Raises:
        OSError: if OpenCV cannot open one of the videos; a missing or unreadable
            file would otherwise silently count as 0 frames.

    """
    paths = [video_list] if isinstance(video_list, str) else video_list
    num_frames = 0
    for video_file in paths:
        cap = cv2.VideoCapture(video_file)
        try:
            if not cap.isOpened():
                raise OSError(f"could not open video file: {video_file}")
            num_frames += int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # always free the capture handle, even when opening/reading fails
            cap.release()
    return num_frames


class LitDaliWrapper(DALIGenericIterator):
    """Wrapper around a DALI pipeline iterator that reports a fixed length for PyTorch Lightning."""

    def __init__(
        self,
        *args,
        num_iters: int = 1,
        **kwargs
    ) -> None:
        """Wrapper around DALIGenericIterator to get batches for pl.

        Args:
            *args: positional arguments forwarded to DALIGenericIterator
                (e.g. the DALI pipeline).
            num_iters: value reported by ``len(self)``, i.e. the number of batches
                per epoch; must be computed by the caller for now (should be fixed
                by lightning/dali teams). Iteration itself is still driven by
                DALIGenericIterator, so this is only what ``len`` reports.
            **kwargs: keyword arguments forwarded to DALIGenericIterator
                (e.g. output_map, reader_name, last_batch_policy).

        """
        # both attributes are set before the parent constructor runs
        self.num_iters = num_iters
        # hack to get around a DALI/Lightning issue: presumably Lightning inspects a
        # `batch_sampler` attribute on dataloaders -- TODO confirm which check needs it
        self.batch_sampler = 1
        super().__init__(*args, **kwargs)

    def __len__(self) -> int:
        """Return the externally supplied number of batches per epoch."""
        return self.num_iters

Provide a path to a video file for testing; here we use a test video from this repo.

In [3]:
# test video shipped with the repo; the path is relative to the notebook's working directory
video_file = "../toy_datasets/toymouseRunningData/unlabeled_videos/test_vid.mp4"

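## Temporal context loader
Here we define a pipeline that reads sequences of 5 consecutive frames at a time. Because `step=1`, consecutive sequences overlap by 4 frames, and each batch holds 16 sequences.
In practice we have a model that takes 5 frames as input and outputs a pose for the 3rd (center) frame, but we omit the model here.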

In [4]:

# args for context loader: overlapping 5-frame sequences (step=1), 16 sequences per batch
pipe_args = {
    "filenames": video_file,
    "resize_dims": [64, 64],
    "sequence_length": 5,
    "step": 1,
    "batch_size": 16,
    "num_threads": 4,
    "device_id": 0,
    "random_shuffle": False,
    "device": "gpu",
    "name": "reader",
    "pad_sequences": True,
    "pad_last_batch": True,
}

pipe = video_pipe(**pipe_args)

# set up parameters for pytorch iterator
frame_count = count_frames(video_file)
# num_iters logic adapted from
# https://github.com/danbider/lightning-pose/blob/b66fe34719ec89631f74c6c911a5e1a013bc7e34/lightning_pose/data/dali.py#L237
# Assumes that with pad_sequences=True the reader starts one sequence every `step`
# frames, i.e. ceil(frame_count / step) sequences, grouped into batches of
# `batch_size`. For step=1 this reduces to ceil(frame_count / batch_size).
# Integer ceiling division avoids float rounding.
num_sequences = (frame_count + pipe_args["step"] - 1) // pipe_args["step"]
num_iters = (num_sequences + pipe_args["batch_size"] - 1) // pipe_args["batch_size"]

iterator_args = {
    "num_iters": num_iters,
    "output_map": ["frames"],
    "last_batch_policy": LastBatchPolicy.FILL,
    "auto_reset": False,
    "reader_name": "reader"
}


In [5]:
# number of batches the context-loader iterator will report via len()
num_iters

63

In [6]:
# build iterator; len(iterator) returns the num_iters passed in iterator_args
iterator = LitDaliWrapper(pipe, **iterator_args)


In [7]:
# iterate over the whole video once and time it; batch shapes are read but not
# printed to avoid cluttering the notebook
n_batches = 0
start_time = perf_counter()
for batch in tqdm(iterator):
    shape = batch[0]["frames"].shape
    n_batches += 1
if torch.cuda.is_available():
    # GPU work is asynchronous; wait for it to finish before stopping the clock
    torch.cuda.synchronize()
end_time = perf_counter()
# report batches actually yielded: iteration is driven by DALIGenericIterator,
# not by LitDaliWrapper.__len__, so this can differ from num_iters
print(f"Time to iterate over {n_batches} batches: {end_time - start_time}")

100%|██████████| 63/63 [00:30<00:00,  2.06it/s]

Time to iterate over 63 batches: 30.634568452835083





In [8]:
# average seconds per batch for the context loader, using the expected batch count num_iters
total_time = end_time - start_time
time_per_iter = total_time / num_iters

In [9]:
# seconds per batch for the context loader
time_per_iter

0.4862629913148426

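## Base prediction
Now define the args for base prediction. We iterate over non-overlapping 64-frame sequences (`step` equals `sequence_length`), one sequence per batch, until we reach the end of the video. In practice, a standard ResNet consumes these frames and is evaluated in the usual way.
Note: the saved outputs below (469 iterations) cannot come from the roughly 1000-frame test video used above, which gave 63 batches of 16. They were produced with a longer video, such as the commented-out path in the next cell.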

In [34]:
# video_file = "/home/jovyan/datastores/mirror-mouse/videos_new/180607_004.mp4"

# frame_count = count_frames(video_file)


In [35]:
# args for base predict loader: non-overlapping 64-frame sequences
# (step == sequence_length), one sequence per batch
pipe_args = {
    "filenames": video_file,
    "resize_dims": [64, 64],
    "sequence_length": 64,
    "step": 64,
    "batch_size": 1,
    "num_threads": 4,
    "device_id": 0,
    "random_shuffle": False,
    "device": "gpu",
    "name": "reader",
    "pad_sequences": True,
}

# recount frames for *this* pipeline's input so num_iters cannot be computed from a
# stale frame_count left over from an earlier cell / different video
frame_count = count_frames(pipe_args["filenames"])

# assumes one sequence starts every `step` frames (pad_sequences=True pads the last
# one), grouped into batches of `batch_size`; with step == sequence_length and
# batch_size == 1 this is ceil(frame_count / sequence_length). Integer ceiling division.
num_sequences = (frame_count + pipe_args["step"] - 1) // pipe_args["step"]
num_iters = (num_sequences + pipe_args["batch_size"] - 1) // pipe_args["batch_size"]

# https://github.com/danbider/lightning-pose/blob/0d9c26c42cbddbd16a8f01937d714d221474225d/lightning_pose/data/dali.py#L386
iterator_args = {
    "num_iters": num_iters,
    "output_map": ["frames"],
    "last_batch_policy": LastBatchPolicy.FILL,
    "last_batch_padded": False,
    "auto_reset": False,
    "reader_name": "reader"
}

In [36]:
# number of batches the base-prediction iterator will report via len()
num_iters

469

In [37]:
# instantiate the base-prediction pipeline from the args above
pipe = video_pipe(**pipe_args)

In [38]:
# build iterator; len(iterator) returns the num_iters passed in iterator_args
iterator = LitDaliWrapper(pipe, **iterator_args)


In [39]:
# iterate over the whole video once and time it; batch shapes are read but not
# printed to avoid cluttering the notebook
n_batches = 0
start_time = perf_counter()
for batch in tqdm(iterator):
    shape = batch[0]["frames"].shape
    n_batches += 1
if torch.cuda.is_available():
    # GPU work is asynchronous; wait for it to finish before stopping the clock
    torch.cuda.synchronize()
end_time = perf_counter()
# report batches actually yielded: iteration is driven by DALIGenericIterator,
# not by LitDaliWrapper.__len__, so this can differ from num_iters
print(f"Time to iterate over {n_batches} batches: {end_time - start_time}")

100%|██████████| 469/469 [00:23<00:00, 19.94it/s]

Time to iterate over 469 batches: 23.523786544799805





In [46]:
# average seconds per batch for the base-prediction loader, using the expected batch count num_iters
total_time = end_time - start_time
time_per_iter = total_time / num_iters
print(f"Time per iteration: {time_per_iter}")

Time per iteration: 0.050110924726864424


In [21]:
# 994 // num_iters

In [20]:
# import numpy as np

In [6]:
# # test saving to .npy file, batch by batch
# batch_dims = (64, 2000, 16, 16)
# num_batches_to_test = 7
# test_filename = "/home/jovyan/test.npy"

# # create fake data and save on the fly to a single .npy file
# with open('test.npy', 'wb') as f:
#     for i in range(num_batches_to_test):
#         fake_data = np.ones(batch_dims, dtype=np.float16) * i
#         np.save(f, fake_data)

In [19]:
# # load the data back in
# with open('test.npy', 'rb') as f:
#     for i in range(num_batches_to_test):
#         loaded_data = np.load(f)
#         print(loaded_data.shape)
#         print(np.unique(loaded_data))

In [8]:
# loaded_data.shape

(64, 2000, 16, 16)