In [1]:
import torch
import torchvision
import matplotlib.pyplot as plt
import os
import cv2
from tqdm import tqdm
import numpy as np
import random
from collections import defaultdict
from PIL import Image
import re

In [2]:
DATA_FOLDER = 'demo_data'

In [4]:
train_data = os.listdir(DATA_FOLDER)

In [4]:
%%time

def generator(files, n_datapoints, n_frames = 1):
    '''
    Generates random non-overlapping triplets from the videos
    as given by the files variable.
    
    n_frames -- number of frames used to interpolate (default 1)
    n_datapoints -- total nubmer of data points to generate from the files
    '''
    
    # determine
    
    frames_per_video = {}
    
    for file in tqdm(files):
        cap = cv2.VideoCapture(os.path.join(DATA_FOLDER, file))
        
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames_per_video[file] = n_frames
        
    return frames_per_video
        
gen = generator(train_data, 10)
    

100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 30.73it/s]


Wall time: 139 ms


In [5]:
def get_frames_by_indices(cap, indices):
    frames = []
    
    for index in indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, index)
        _, image = cap.read()
        frames.append(image)
        
    return np.stack(frames)


In [6]:
def gen_basic():

    to_sample = defaultdict(list)
    for file in train_data:
        cap = cv2.VideoCapture(os.path.join(DATA_FOLDER, file))
        n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        domain = [[i+j for j in range(3)] for i in range(n_frames-2)]

        while len(domain) > 0:
    #         print(len(domain))
            datapoint = random.choice(domain)
            domain = [x for x in domain if sum([x.count(y) for y in datapoint]) == 0]
            to_sample[file].append(datapoint)

        random.shuffle(to_sample[file])

        for datapoint in to_sample[file]:
            frames = get_frames_by_indices(cap, datapoint)

            yield frames
        
    
        
        
    
    
    

In [38]:
%%time
_ = next(gen_basic())

Wall time: 398 ms


In [191]:
sample = random.sample(to_sample, 5)

In [195]:
sample = sorted(sample)

In [39]:
result = np.split(_, indices_or_sections=3)

In [58]:
np.squeeze(frame).shape

(1080, 1920, 3)

In [63]:
%%time
for i, item in tqdm(enumerate(gen_basic())):
    frames = np.split(item, indices_or_sections=3)
    
    for j, frame in enumerate(frames):
        filepath_out = os.path.join('dataset', f'id_{i}_frame_{j}.png')
        frame = np.squeeze(frame)
        frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)
        im = Image.fromarray(frame)
        im.save(filepath_out)


0it [00:00, ?it/s]
1it [00:01,  1.69s/it]
2it [00:03,  1.69s/it]
3it [00:05,  1.69s/it]
4it [00:06,  1.64s/it]
5it [00:08,  1.63s/it]
6it [00:09,  1.63s/it]
7it [00:11,  1.65s/it]
8it [00:13,  1.64s/it]
9it [00:15,  1.74s/it]
10it [00:16,  1.76s/it]
11it [00:18,  1.72s/it]
12it [00:20,  1.68s/it]
13it [00:21,  1.64s/it]
14it [00:23,  1.65s/it]
15it [00:25,  1.66s/it]
16it [00:26,  1.61s/it]
17it [00:28,  1.60s/it]
18it [00:29,  1.64s/it]
19it [00:31,  1.62s/it]
20it [00:32,  1.59s/it]
21it [00:34,  1.61s/it]
22it [00:36,  1.65s/it]
23it [00:37,  1.62s/it]
24it [00:39,  1.62s/it]
25it [00:41,  1.67s/it]
26it [00:42,  1.68s/it]
27it [00:44,  1.69s/it]
28it [00:46,  1.67s/it]
29it [00:47,  1.67s/it]
30it [00:49,  1.64s/it]
31it [00:51,  1.61s/it]
32it [00:52,  1.57s/it]
33it [00:54,  1.56s/it]
34it [00:55,  1.60s/it]
35it [00:57,  1.64s/it]
36it [00:59,  1.59s/it]
37it [01:00,  1.61s/it]
38it [01:02,  1.63s/it]
39it [01:04,  1.64s/it]
40it [01:05,  1.65s/it]
41it [01:07,  1.68s/it]
42it 

Wall time: 10min 12s


In [50]:
files = os.listdir('output_data')

In [53]:
len(files)

1263

In [56]:
files = [[files[i+j] for i in range(3)] for j in range(len(files)-2)]

[['aajhbrxhzm.mp4_0_0.png',
  'aajhbrxhzm.mp4_0_1.png',
  'aajhbrxhzm.mp4_0_2.png'],
 ['aajhbrxhzm.mp4_0_1.png',
  'aajhbrxhzm.mp4_0_2.png',
  'aajhbrxhzm.mp4_10_0.png'],
 ['aajhbrxhzm.mp4_0_2.png',
  'aajhbrxhzm.mp4_10_0.png',
  'aajhbrxhzm.mp4_10_1.png'],
 ['aajhbrxhzm.mp4_10_0.png',
  'aajhbrxhzm.mp4_10_1.png',
  'aajhbrxhzm.mp4_10_2.png'],
 ['aajhbrxhzm.mp4_10_1.png',
  'aajhbrxhzm.mp4_10_2.png',
  'aajhbrxhzm.mp4_11_0.png'],
 ['aajhbrxhzm.mp4_10_2.png',
  'aajhbrxhzm.mp4_11_0.png',
  'aajhbrxhzm.mp4_11_1.png'],
 ['aajhbrxhzm.mp4_11_0.png',
  'aajhbrxhzm.mp4_11_1.png',
  'aajhbrxhzm.mp4_11_2.png'],
 ['aajhbrxhzm.mp4_11_1.png',
  'aajhbrxhzm.mp4_11_2.png',
  'aajhbrxhzm.mp4_12_0.png'],
 ['aajhbrxhzm.mp4_11_2.png',
  'aajhbrxhzm.mp4_12_0.png',
  'aajhbrxhzm.mp4_12_1.png'],
 ['aajhbrxhzm.mp4_12_0.png',
  'aajhbrxhzm.mp4_12_1.png',
  'aajhbrxhzm.mp4_12_2.png'],
 ['aajhbrxhzm.mp4_12_1.png',
  'aajhbrxhzm.mp4_12_2.png',
  'aajhbrxhzm.mp4_13_0.png'],
 ['aajhbrxhzm.mp4_12_2.png',
  'aajhbr

In [17]:
int(files[0].split('_')[-1][0])

0

In [70]:
x = [*range(123)]

In [77]:
round(3.2)

3

In [140]:
def generator(folder, batch_size = 4, k=None):
    
    raw_files = os.listdir(folder)

    # infer k from files
    if k == None:
        ks = [int(re.findall(r'.+\.mp4_\d+_(\d+)\.png',f)[0]) for f in raw_files]
        k = max(ks)-1

        assert len(raw_files) % (k+2) == 0, f'not all generated datapoints of equal length, {len(raw_files), k}'

    files = [
        [raw_files[i+j] for i in range(3)] for j in range(len(raw_files)-2)
    ]        
    
    files = files[:1]
    print(files)
    random.shuffle(files)

    n_batches = np.ceil(len(files)/batch_size).astype(int)

    for b in range(n_batches):

        batch = files[b*batch_size:(b+1)*batch_size]
        batch_result = []
#         print('batch', batch)
        for frameset in batch:

            imgs = []
#             print(frameset)
            for file in frameset:
#                 print(os.path.join(folder, file))
                img = np.array(Image.open(os.path.join(folder, file)))
                imgs.append(img)
            imgs = np.stack(imgs)
            batch_result.append(imgs)
#         yield batch_result
        yield np.stack(batch_result)

In [None]:
b.shape

In [145]:
b = generator('E:\scriptieAI\output_data')

In [146]:
%%time
b = next(b)

[['aajhbrxhzm.mp4_0_0.png', 'aajhbrxhzm.mp4_0_1.png', 'aajhbrxhzm.mp4_0_2.png']]
batch [['aajhbrxhzm.mp4_0_0.png', 'aajhbrxhzm.mp4_0_1.png', 'aajhbrxhzm.mp4_0_2.png']]
['aajhbrxhzm.mp4_0_0.png', 'aajhbrxhzm.mp4_0_1.png', 'aajhbrxhzm.mp4_0_2.png']
E:\scriptieAI\output_data\aajhbrxhzm.mp4_0_0.png
E:\scriptieAI\output_data\aajhbrxhzm.mp4_0_1.png
E:\scriptieAI\output_data\aajhbrxhzm.mp4_0_2.png
Wall time: 164 ms


In [143]:
b.shape

(1, 3, 1080, 1920, 3)

In [83]:
## generator from images
def gen_basic2(batch_size=4):
    
    for fileset in files:
        imgs = []
        for file in fileset:
            img = Image.open(os.path.join('dataset', file))
            imgs.append(img)
            
        yield np.stack(imgs)
            

gen = gen_basic2()

In [105]:
%%time
next(gen)

Wall time: 160 ms


array([[[[101, 106,  82],
         [ 95, 100,  76],
         [108, 112,  95],
         ...,
         [111, 122, 111],
         [121, 132, 121],
         [130, 141, 130]],

        [[102, 107,  83],
         [ 98, 103,  79],
         [109, 113,  96],
         ...,
         [114, 125, 114],
         [118, 129, 118],
         [125, 136, 125]],

        [[102, 112,  86],
         [ 98, 108,  82],
         [109, 118,  99],
         ...,
         [115, 126, 115],
         [117, 128, 117],
         [119, 130, 119]],

        ...,

        [[ 91,  81,  85],
         [ 92,  82,  86],
         [ 92,  82,  86],
         ...,
         [110,  99,  94],
         [105,  94,  89],
         [102,  91,  86]],

        [[ 92,  82,  86],
         [ 92,  82,  86],
         [ 93,  83,  87],
         ...,
         [115, 104,  99],
         [112, 101,  96],
         [111, 100,  95]],

        [[ 92,  82,  86],
         [ 92,  82,  86],
         [ 92,  82,  86],
         ...,
         [121, 110, 105],
        