# Custom Video Dataset for PyTorch

Notes about usage:

This code creates a custom video dataset for training deep learning models 
with PyTorch on consecutive frames extracted from videos. It expects the frames of each video in a separate folder; for example, video1's frames go in a folder named 'video1'. You can use OpenCV to extract the frames and save them into these folders.

In [7]:
from __future__ import print_function
import glob
from itertools import chain
import os
import cv2
import random
import zipfile
import os.path as osp
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from imgaug import augmenters as iaa
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from torch.utils.data.distributed import DistributedSampler
import sys

In [1]:
# Dataset roots. Each root holds one sub-directory per video, and each of
# those sub-directories holds that video's extracted frames:
#
#   .../real/video1/<frames>      .../fake/video1/<frames>
#   .../real/video2/<frames>      .../fake/video2/<frames>
#   ...                           ...
train_dir_real = 'path_to_data/Video_Dataset/real/'
train_dir_fake = 'path_to_data/Video_Dataset/fake/'

# One path per video directory for each class.
# NOTE: the '*' pattern matches any entry, so stray files in a root would be
# picked up as "videos" too.
train_list_real = glob.glob(os.path.join(train_dir_real, '*'))
train_list_fake = glob.glob(os.path.join(train_dir_fake, '*'))

# Pool both classes into a single list and shuffle it in place.
train_list = train_list_real + train_list_fake
random.shuffle(train_list)

In [2]:
# Sanity check: number of videos found per class and in total.
print("\n".join([
    f"Train Data Real: {len(train_list_real)}",
    f"Train Data Fake: {len(train_list_fake)}",
    f"Train Data: {len(train_list)}",
]))

In [3]:
# Put the videos in a canonical order first. train_list comes from glob
# (filesystem-dependent order) and was shuffled with an unseeded
# random.shuffle, so without sorting, the split below would differ between
# runs even though random_state is fixed.
train_list = sorted(train_list)

# Label each video by its class folder, i.e. the parent of the video
# directory: ".../real/video1" -> "real". os.path is used instead of
# splitting on "/" so Windows-style separators returned by glob also work.
labels = [os.path.basename(os.path.dirname(os.path.normpath(path))) for path in train_list]

# Hold out 20% of the videos for validation, stratified on the class label
# so both splits keep the same real/fake proportion.
train_list, valid_list = train_test_split(train_list,
                                          test_size=0.2,
                                          stratify=labels,
                                          random_state=173)

print(f"Train Data: {len(train_list)}")
print(f"Validation Data: {len(valid_list)}")

In [4]:
# Per-split preprocessing: resize every frame to 299x299, then convert it to
# a tensor. None of the splits applies random augmentation, so the three
# pipelines are currently identical. They are still built as three separate
# Compose objects so that one split (e.g. training) can later get extra steps
# without affecting the others.
train_transforms, val_transforms, test_transforms = (
    transforms.Compose([transforms.Resize((299, 299)), transforms.ToTensor()])
    for _split in ("train", "val", "test")
)


In [5]:
# Custom dataset class: one sample per video, made of a fixed number of
# frames (6 by default).
class DeepFakeSet(Dataset):
    """Video-level dataset yielding a fixed-length stack of frames per video.

    Each entry of ``file_list`` is a directory holding the ``.png`` frames of
    one video, laid out as ``<root>/<class>/<video>/<frame>.png``. The
    ``<class>`` folder name gives the label: ``"real"`` -> 1, anything
    else -> 0.

    Args:
        file_list: Paths of the per-video frame directories.
        transform: Callable applied to every frame (an RGB ``PIL.Image``).
            It is expected to return a tensor, e.g. a ``transforms.Compose``
            ending in ``ToTensor``. It must be set before indexing.
        num_frames: Number of frames per sample. Shorter videos are padded
            by repeating their first frame at the front; longer videos are
            truncated to their first ``num_frames`` frames (sorted by name).
    """

    def __init__(self, file_list, transform=None, num_frames=6):
        if num_frames < 1:
            raise ValueError(f"num_frames must be >= 1, got {num_frames}")
        self.file_list = file_list
        self.transform = transform
        self.num_frames = num_frames

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for the video at position ``idx``.

        ``frames`` stacks ``num_frames`` transformed frames along a new
        leading dimension. ``label`` is 1 when the class folder is "real",
        otherwise 0.

        Raises:
            FileNotFoundError: the video directory contains no ``.png`` frames.
            TypeError: no ``transform`` was supplied.
        """
        video_dir = self.file_list[idx]
        frame_paths = self._padded_frame_paths(video_dir)
        if self.transform is None:
            raise TypeError(
                "DeepFakeSet needs a transform that turns PIL images into tensors"
            )

        # Every (possibly repeated) path is loaded and transformed separately,
        # so any random augmentation is drawn independently for each frame.
        frames = torch.stack([self._load_frame(path) for path in frame_paths], dim=0)

        # The class folder is the parent of the video directory; normpath
        # strips trailing separators and normalises Windows-style paths.
        class_name = os.path.basename(os.path.dirname(os.path.normpath(video_dir)))
        label = 1 if class_name == "real" else 0
        return frames, label

    def _padded_frame_paths(self, video_dir):
        """Return exactly ``num_frames`` frame paths for one video directory."""
        # glob.escape keeps characters like '[' in directory names literal.
        # NOTE: the sort is lexicographic, so "frame10.png" sorts before
        # "frame2.png" unless frame numbers are zero-padded.
        pattern = os.path.join(glob.escape(video_dir), "*.png")
        paths = sorted(glob.glob(pattern))
        if not paths:
            raise FileNotFoundError(f"no .png frames found in {video_dir!r}")

        shortfall = self.num_frames - len(paths)
        if shortfall > 0:
            # Pad by repeating the first frame at the front,
            # e.g. 4 frames -> [f0, f0, f0, f1, f2, f3].
            paths = [paths[0]] * shortfall + paths
        return paths[: self.num_frames]

    def _load_frame(self, path):
        """Load one frame as RGB and apply ``self.transform`` to it."""
        # The ``with`` block closes the file handle that Image.open keeps.
        # convert("RGB") reads the pixels into memory and drops any alpha or
        # palette mode, so every frame yields the same number of channels.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transform(rgb)


In [None]:
# Wrap both splits in DataLoaders. Each sample is a (frames, label) pair, so
# with the default collate function a batch is a stacked frames tensor (batch
# dimension first) plus a tensor of labels.
train_data = DeepFakeSet(train_list, transform=train_transforms)
valid_data = DeepFakeSet(valid_list, transform=val_transforms)
batch_size = 24
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
# Shuffling only matters for training. Keeping the validation order fixed makes
# evaluation runs reproducible and easier to compare.
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)