In [2]:
import glob
import json
import os
import argparse
import pytorch_lightning as pl
import numpy as np
from pathlib import Path
from typing import Collection, Dict, Optional, Tuple, Union
from torch.utils.data import ConcatDataset, DataLoader
from torchvision import transforms
from biometrics.src.utils.util_data import BaseDataset, split_dataset
import cv2

# Notebook-wide configuration constants.
BATCH_SIZE = 12  # NOTE(review): the earlier note here read "128" — presumably the full-size batch; confirm
NUM_WORKERS = 0  # presumably the DataLoader worker count (not used in the visible cells) — confirm
IMG_SIZE = 256  # edge length in pixels: default frame resize target and the per-frame reshape size when loading data
TRAIN_FRAC = 0.8  # presumably the train split fraction (not used in the visible cells) — confirm

In [3]:
def crop_center_square(frame):
    """Cut the largest centered square out of an image array.

    Args:
        frame: Array whose first two axes are (rows, columns); any trailing
            axes (for example color channels) are left untouched.

    Returns:
        The slice of ``frame`` with side ``min(rows, columns)`` taken around
        the center of the first two axes.
    """
    height, width = frame.shape[:2]
    side = min(height, width)
    # Floor division mirrors the half-size offsets on both axes.
    top = height // 2 - side // 2
    left = width // 2 - side // 2
    return frame[top:top + side, left:left + side]


def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    """Decode a video into a stack of square, resized grayscale frames.

    Every decoded frame is center-cropped to a square, resized to ``resize``
    and converted to a single-channel grayscale image.

    Args:
        path: Video source handed to ``cv2.VideoCapture``.
        max_frames: Stop after this many frames. ``0`` (the default) never
            matches the frame count, so the whole stream is read.
        resize: Target size passed to ``cv2.resize`` as ``(width, height)``.

    Returns:
        ``np.ndarray`` with the frames stacked along axis 0. If the video
        cannot be opened or has no frames, the array is empty.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # read() reports failure at end of stream (or on a decode error).
                print("video processed")
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            # OpenCV decodes to BGR, so convert straight to gray. Reordering the
            # channels to RGB first and then applying COLOR_BGR2GRAY would swap
            # the red and blue luminance weights.
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(frame)
            if len(frames) == max_frames:
                break
    finally:
        # Always free the capture handle, even if processing a frame fails.
        # No HighGUI window is opened here, so there is nothing to destroy.
        cap.release()
    return np.array(frames)

def load_x16_frames_from_video(filename):
    """Sample a fixed-length clip of 17 frames from a video.

    The video is decoded with ``load_video`` and sub-sampled with a stride of
    ``len(frames) // 16``. Despite the function name, the clip is then
    normalized to exactly 17 frames: extra frames are dropped from the end and
    a short clip is padded by repeating its last frame.

    Args:
        filename: Video source forwarded to ``load_video``.

    Returns:
        ``np.ndarray`` holding exactly 17 frames along axis 0.

    Raises:
        ValueError: If no frame could be decoded from ``filename``.
    """
    target_frames = 17
    frames = load_video(filename)
    total = len(frames)
    if total == 0:
        raise ValueError(f"no frames could be read from {filename!r}")
    # Videos shorter than 16 frames would give a zero stride; fall back to 1.
    step = max(1, total // 16)
    # The stride can yield anywhere from 16 to 31 frames; keep the first 17.
    frames = frames[::step][:target_frames]
    missing = target_frames - len(frames)
    if missing > 0:
        # Pad by repeating the final frame so every clip has the same length.
        padding = np.repeat(frames[-1:], missing, axis=0)
        frames = np.concatenate([frames, padding], axis=0)
    return frames

# Lookup table from a clip's file name (without extension) to its integer
# class index; the dataset loader reads it once per video.
labels = {
    "1": 0,
    "2": 0,
    "3": 1,
    "4": 1,
    "5": 2,
    "6": 2,
    "7": 3,
    "8": 3,
    "HR_1": 0,
    "HR_2": 1,
    "HR_3": 2,
    "HR_4": 3,
}

In [None]:
def load_dataset(path):
    """Load every ``.avi`` clip found under ``path`` as frames plus labels.

    Args:
        path: Root directory searched recursively for ``*.avi`` files, e.g.
            ``storage/datasets/CASIA_faceAntisp/train_release/`` for training
            or ``storage/datasets/CASIA_faceAntisp/test_release/`` for testing.

    Returns:
        ``(dataset_x, dataset_y)`` where ``dataset_x`` is an array of shape
        ``(total_frames, IMG_SIZE, IMG_SIZE)`` with the two image axes swapped
        relative to the decoded frames, and ``dataset_y`` is a list holding one
        label per frame, in the same order. When no video is found the array
        has zero frames and the list is empty.

    Raises:
        KeyError: If a file name (without extension) is not a key of ``labels``.
    """
    # Search below the directory the caller asked for rather than a fixed
    # machine-specific location; sorting makes the sample order reproducible.
    pattern = os.path.join(path, "**", "*.avi")
    clips = []
    dataset_y = []
    for video_file in sorted(glob.glob(pattern, recursive=True)):
        stem = os.path.splitext(os.path.basename(video_file))[0]
        label = labels[stem]
        frames = load_x16_frames_from_video(video_file)
        clip = np.asarray(frames).reshape(-1, IMG_SIZE, IMG_SIZE)
        clips.append(clip)
        # One label per frame actually loaded keeps x and y aligned.
        dataset_y.extend([label] * len(clip))
    if not clips:
        # NOTE(review): uint8 assumed to match decoded grayscale frames — confirm.
        return np.empty((0, IMG_SIZE, IMG_SIZE), dtype=np.uint8), dataset_y
    # Concatenate once instead of rebuilding the array on every iteration.
    dataset_x = np.concatenate(clips, axis=0).swapaxes(1, 2)
    return dataset_x, dataset_y


def get_data_loaders():
    """Build shuffled MNIST train and test loaders with a batch size of 64.

    Images are converted to tensors and normalized with mean ``0.1307`` and
    standard deviation ``0.3081``. Both datasets are created with
    ``download=True`` while a file lock at ``{data_dir}/data.lock`` is held.

    NOTE(review): ``FileLock``, ``torch``, ``datasets`` and ``data_dir`` are
    not defined in the visible cells (only ``DataLoader`` and ``transforms``
    are imported there) — confirm they are provided elsewhere, otherwise
    calling this function raises ``NameError``.
    NOTE(review): the batch size is hard-coded to 64 rather than taken from
    ``BATCH_SIZE``, and the test loader is shuffled as well — confirm both
    are intended.

    Returns:
        ``(train_loader, test_loader)``.
    """
    mnist_transforms = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    # We add FileLock here because multiple workers will want to
    # download data, and this may cause overwrites since
    # DataLoader is not threadsafe.
    with FileLock(os.path.expanduser(f"{data_dir}/data.lock")):
        train_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                f"{data_dir}", train=True, download=True, transform=mnist_transforms
            ),
            batch_size=64,
            shuffle=True,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                f"{data_dir}", train=False, download=True, transform=mnist_transforms
            ),
            batch_size=64,
            shuffle=True,
        )
    return train_loader, test_loader