In [1]:
# Helper libraries
import matplotlib
from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import cv2 as cv
import tqdm
import IPython
from sklearn.metrics import confusion_matrix
from tabulate import tabulate
import os

import glob
import pandas as pd
import random
from colour.plotting import *

from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib import colors
import data_utils as utils
from skimage import io

In [2]:
# Frame numbers trimmed from the start of every video; the frames the same
# distance back from `max_fnum` (see clean_frames) are trimmed as well.
del_frames = [60, 120, 180, 240, 300]
def clean_frames(genre, remove_beg_end=True, min_side=200, black_mean_threshold=20):
    """Prune and crop, in place, the frame JPEGs stored under ``./<genre>/``.

    Frame files are expected to be named ``<video prefix><frame number>.jpg``
    (the prefix ends with ``_``). For every video found this function:

    1. optionally deletes the frames listed in ``del_frames`` and their
       mirror images counted back from ``max_fnum``;
    2. deletes frames whose mean pixel value is below ``black_mean_threshold``;
    3. passes the remaining frames through ``utils.crop_blacks`` and either
       deletes the result (height or width below ``min_side``) or writes it
       back over the original file.

    This is destructive: files are removed and overwritten on disk.

    Args:
        genre: Name of the sub-directory of the working directory to clean.
        remove_beg_end: If True, videos are discovered through their
            ``*_60.jpg`` frame and step 1 is performed. If False, videos are
            discovered through ``*_360.jpg`` and step 1 is skipped.
        min_side: Minimum height and width (in pixels) a cropped frame must
            have to be kept.
        black_mean_threshold: Frames whose mean pixel value is below this
            are treated as black and deleted.

    Returns:
        None.
    """
    marker = '60.jpg' if remove_beg_end else '360.jpg'
    vids = [f[:-len(marker)] for f in glob.glob('./{}/*_{}'.format(genre, marker))]

    for v in tqdm.tqdm(vids):
        if remove_beg_end:
            # NOTE(review): assumes frames are numbered in steps of 60 so that
            # count * 60 is the last frame number — confirm against the extractor.
            max_fnum = len(glob.glob(v + '*.jpg')) * 60
            doomed = set()
            for d in del_frames:
                doomed.add(v + str(d) + '.jpg')
                doomed.add(v + str(max_fnum - d) + '.jpg')
            # Short videos make the head and tail ranges overlap or run past the
            # available frames, so only remove what is actually on disk.
            for path in doomed:
                if os.path.exists(path):
                    os.remove(path)

        for f in glob.glob(v + '*.jpg'):
            rgb = io.imread(f)
            if rgb.mean() < black_mean_threshold:
                # Get rid of the black frames that occasionally occur in trailers.
                os.remove(f)
                continue
            try:
                # presumably trims black borders; crop_blacks lives in data_utils (not shown)
                rgb = utils.crop_blacks(rgb)
                # Compare height and width only: the channel axis (3) is always
                # "too small" and must not take part in the size check.
                if min(rgb.shape[:2]) < min_side:
                    os.remove(f)
                else:
                    plt.imsave(f, rgb)
            except Exception as err:
                # Best effort: keep going, but say which frame was left untouched.
                tqdm.tqdm.write('clean_frames: skipped {} ({})'.format(f, err))
                continue
def sample_frames(base_path, num_frames=20):
    """Randomly pick frames of one video and centre-crop them to a common size.

    The sampled images are cropped to the smallest height and the smallest
    width found among them, and the cropped versions are written back over
    the original files. Frames that already have the target size are left
    untouched on disk.

    Args:
        base_path: Path prefix of the video's frames; every file matching
            ``base_path + '*.jpg'`` is a candidate.
        num_frames: Number of frames to sample.

    Returns:
        The list of sampled file paths, or None when the video has fewer
        than ``num_frames`` frames.
    """
    frames = glob.glob(base_path + '*.jpg')
    if len(frames) < num_frames:
        return None
    sampled_frames = random.sample(frames, num_frames)
    if not sampled_frames:
        return sampled_frames

    # Read every image once; all of them are held in memory until the crop is done.
    images = [io.imread(f) for f in sampled_frames]
    minh = min(img.shape[0] for img in images)
    minw = min(img.shape[1] for img in images)

    for f, rgb in zip(sampled_frames, images):
        h, w = rgb.shape[:2]
        if h == minh and w == minw:
            # Nothing to crop; skip the rewrite so the JPEG is not re-encoded.
            continue
        # Integer offsets guarantee an exact (minh, minw) result; round()-based
        # offsets miss by one pixel for some differences (halves round to even).
        top = (h - minh) // 2
        left = (w - minw) // 2
        plt.imsave(f, rgb[top:top + minh, left:left + minw])
    return sampled_frames
        
def view_frames(base_path, num_frames=20):
    """Show a random selection of a video's frames in a 5-column grid.

    Args:
        base_path: Path prefix of the video's frames; every file matching
            ``base_path + '*.jpg'`` is a candidate.
        num_frames: Maximum number of frames to display. When fewer frames
            exist, all of them are shown.

    Returns:
        The list of displayed file paths (empty when there is nothing to show).
    """
    frames = glob.glob(base_path + '*.jpg')
    if not frames or num_frames <= 0:
        return []

    sampled_frames = random.sample(frames, min(num_frames, len(frames)))

    # Size the grid from the number of frames actually shown; with the default
    # of 20 this is the 4 x 5 layout on a 30 x 24 inch figure.
    cols = 5
    rows = (len(sampled_frames) + cols - 1) // cols
    fig = plt.figure(figsize=(30, 6 * rows))

    for idx, path in enumerate(sampled_frames):
        ax = fig.add_subplot(rows, cols, idx + 1)
        ax.imshow(io.imread(path))
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    return sampled_frames


In [37]:
# Destructive step: clean_frames deletes and overwrites JPEGs under ./Comedy/ in place.
clean_frames('Comedy', remove_beg_end=True)

100%|██████████| 323/323 [05:41<00:00,  1.06s/it]


In [38]:
genre = 'Comedy'
class_num = '8'
SPLIT_SEED = 0  # fixed so that re-running the cell reproduces the same split


def split_train_val(paths, train_frac=.85, seed=SPLIT_SEED):
    """Split ``paths`` into reproducible (train, val) lists.

    The paths are sorted before shuffling so the result does not depend on
    the order they were passed in (glob order is arbitrary), and a private
    ``random.Random(seed)`` is used so the global random state is untouched.

    Args:
        paths: Iterable of video path prefixes.
        train_frac: Fraction of the videos assigned to the training list.
        seed: Seed of the shuffle.

    Returns:
        A ``(train, val)`` tuple of lists.
    """
    ordered = sorted(paths)
    random.Random(seed).shuffle(ordered)
    cut = int(len(ordered) * train_frac)
    return ordered[:cut], ordered[cut:]


# One entry per video that still has its frame number 360; strip '360.jpg'
# to keep the per-video path prefix.
vids = glob.glob('./{}/*_360.jpg'.format(genre))
vids = [f[:-7] for f in vids]

# Create 85-15 train-val split

train_vids, val_vids = split_train_val(vids)
split = len(train_vids)
print(val_vids)

['./Comedy/Ut_DsYTMOrQ_', './Comedy/loTIzXAS7v4_', './Comedy/FZkIlAEbHi4_', './Comedy/Sq5CIH0duMk_', './Comedy/mdBm8vBpvAc_', './Comedy/ntxIBzJ0tU8_', './Comedy/AK6EbfdnTHg_', './Comedy/GuyNP-XyFHs_', './Comedy/YH9ju47g4x0_', './Comedy/IyaFEBI_L24_', './Comedy/uO12W35DpsQ_', './Comedy/LKFuXETZUsI_', './Comedy/1S8RTKL5cW8_', './Comedy/4RI0QvaGoiI_', './Comedy/VWyH_twcMl0_', './Comedy/3zuRkQSx5gs_', './Comedy/bKeW-MGu-qQ_', './Comedy/oDcaZ3StTfI_', './Comedy/Wfql_DoHRKc_', './Comedy/r73tMMDKjas_', './Comedy/u7__TG7swg0_', './Comedy/2A63Ly0Pvpk_', './Comedy/FAfR8omt-CY_', './Comedy/9xXBjtAGIzE_', './Comedy/8UiHFHF-Nqk_', './Comedy/KM1OouzGxPM_', './Comedy/crrYiYreaLk_', './Comedy/eyKOgnaf0BU_', './Comedy/dt__kig8PVU_', './Comedy/lcjN7zkgELM_', './Comedy/XPG0MqIcby8_', './Comedy/9vN6DHB6bJc_', './Comedy/iqyxxMFOFa0_', './Comedy/or_J4mMOH1w_', './Comedy/EI_3ywJLQio_', './Comedy/SApIKVq1iJQ_', './Comedy/Pl9JS8-gnWQ_', './Comedy/l_ngIv5kFNY_', './Comedy/RFuZhw5b1KA_', './Comedy/FNppLrmdyug_',

In [40]:
text_file = 'train.txt'

# Append one comma-separated row per training video:
# the sampled frame paths followed by the class label.
with open(text_file, 'a+') as file:
    for vid in tqdm.tqdm(train_vids):
        sampled_frames = sample_frames(vid)
        if sampled_frames is None:
            # The video had too few frames to sample; leave it out.
            continue
        file.write('{},{}\n'.format(','.join(sampled_frames), class_num))

100%|██████████| 246/246 [05:11<00:00,  1.27s/it]


In [None]:
# NOTE(review): in the visible cells `rgb` is only assigned inside functions, so
# this scratch cell depends on leftover kernel state — confirm where `rgb` comes from.
rgb.shape

In [None]:
# Scratch check: display the frame currently held in `rgb`
# (not assigned at notebook scope in any visible cell).
plt.imshow(rgb)

In [None]:
# Scratch check of the centre-crop arithmetic on a single frame: drop dh rows
# and dw columns, split between the two sides.
# NOTE(review): with Python 3's built-in round() halves go to the even integer,
# so the two offsets add up to dh/dw for 20 and 27 but not for every value
# (e.g. 1 or 2) — confirm before reusing this formula.
h, w, _ = rgb.shape
dh = 20
dw = 27
top, bottom = round(dh/2), round(dh/2 - .5)
left, right = round(dw/2), round(dw/2 - .5)
rgb_crop = rgb[top:h - bottom, left:w - right]
plt.imshow(rgb_crop)

In [None]:
# If the crop removed exactly dh rows and dw columns, this is (h - dh, w - dw, channels).
rgb_crop.shape