In [63]:
import os
from os import listdir, path
import numpy as np
import pickle, argparse
from glob import glob
import cv2
import tensorflow as tf

# Expose only CUDA device index 1 to this process, then report the TF version.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})
print(tf.__version__)

2.1.0


In [64]:
from easydict import EasyDict

# Run configuration: dataset location, batching, optimiser step size, image
# side length, and where logs / the cached file list are stored.
args = EasyDict({
    'data_root': 'LipGAN_dataset_local',
    'batch_size': 200,
    'lr': 1e-3,
    'img_size': 96,
    'logdir': 'logs_ipynb',
    'all_images': 'filenames.pkl',
})
print(args)

{'data_root': 'LipGAN_dataset_local', 'batch_size': 200, 'lr': 0.001, 'img_size': 96, 'logdir': 'logs_ipynb', 'all_images': 'filenames.pkl'}


### Dataset

In [78]:
import itertools 

# Number of spectrogram columns sliced out for one sample by get_audio_segment().
mel_step_size = 27
# Number of video frames the audio window starts before the centre frame.
half_window_size = 4
    
def frame_id(fname):
    """Return the integer frame index encoded in a frame's file name.

    The index is whatever precedes the first '.' of the base name,
    e.g. 'clip/0007.jpg' -> 7.
    """
    stem, _, _ = os.path.basename(fname).partition('.')
    return int(stem)

def choose_ip_frame(frames, gt_frame):
    """Randomly pick a frame whose index is at least 6 away from gt_frame's.

    If no frame in `frames` is that far away, the pick is made from all of
    `frames` instead.
    """
    def is_distant(candidate):
        return np.abs(frame_id(gt_frame) - frame_id(candidate)) >= 6

    pool = list(filter(is_distant, frames)) or frames
    return np.random.choice(pool)

def get_audio_segment(center_frame, spec):
    """Slice the spectrogram window that belongs to a video frame.

    The window starts `half_window_size` frames before `center_frame` and is
    `mel_step_size` columns wide.

    Args:
        center_frame: frame file name; its index is read with frame_id().
        spec: 2-D array indexed as spec[:, column].

    Returns:
        spec[:, start:start + mel_step_size], or None when the window would
        begin before column 0 or end past the last column.
    """
    center_frame_id = frame_id(center_frame)
    start_frame_id = center_frame_id - half_window_size

    # 25 is fps of LRS2; 81 is presumably spectrogram columns per second -- TODO confirm
    start_idx = int((81./25.) * start_frame_id)
    end_idx = start_idx + mel_step_size

    # A negative start would be interpreted by NumPy as counting from the end
    # and silently return a wrong (usually empty) window, so reject it.
    if start_idx < 0 or end_idx > spec.shape[1]:
        return None

    return spec[:, start_idx:end_idx]

def bgr2rgb(x):
    """Swap channels 0 and 2 of a 3-D image array in place and return it."""
    # The fancy-indexed right-hand side is a copy, so the swap is safe.
    x[:, :, [0, 2]] = x[:, :, [2, 0]]
    return x

# Make sure the log directory exists (including any missing parents); it also
# holds the cached list of training image paths.
os.makedirs(args.logdir, exist_ok=True)

filelist_cache = path.join(args.logdir, args.all_images)
if path.exists(filelist_cache):
    # NOTE: unpickling can execute arbitrary code -- only load a cache file
    # that this notebook wrote itself.
    with open(filelist_cache, 'rb') as cache_file:
        all_images = pickle.load(cache_file)
else:
    # Globbing the whole training tree is slow, so the result is cached.
    all_images = glob("{}/train/*/*/*.jpg".format(args.data_root))
    with open(filelist_cache, 'wb') as cache_file:
        pickle.dump(all_images, cache_file, protocol=pickle.HIGHEST_PROTOCOL)

print ("Will be training on {} images".format(len(all_images)))

# Shuffle in place; `batches` is the pool fetch_data() samples from.
np.random.shuffle(all_images)
batches = all_images

def _load_face(fname):
    """Read an image as RGB, scale it by 1/255 and resize to args.img_size.

    Returns None when OpenCV cannot read the file.
    """
    img = cv2.imread(fname)
    if img is None:  # cv2.imread reports unreadable/missing files with None
        return None
    img = bgr2rgb(img) / 255.0
    return cv2.resize(img, (args.img_size, args.img_size))


def fetch_data():
    """Draw one random training sample, retrying until a usable one is found.

    A candidate is rejected (and another one drawn) when its directory holds
    fewer than 12 frames, its audio window is unavailable or not
    `mel_step_size` columns wide, the window contains NaNs, or an image
    cannot be read.

    Returns:
        Tuple (img_gt, img_gt_masked, img_ip, mel):
        img_gt         -- the ground-truth frame,
        img_gt_masked  -- copy of img_gt with rows from img_size // 2 on zeroed,
        img_ip         -- another frame from the same directory,
        mel            -- spectrogram window for the ground-truth frame.
    """
    while True:
        index = np.random.randint(0, len(batches))

        # Pick a ground-truth frame and list the frames next to it.
        img_name = batches[index]
        gt_fname = os.path.basename(img_name)
        # dirname instead of str.replace: replace would also strip the file
        # name if it happened to occur earlier in the path.
        dir_name = os.path.dirname(img_name)
        frames = glob(path.join(dir_name, '*.jpg'))

        if len(frames) < 12:
            continue

        # Load the spectrogram stored with the frames and cut out the window.
        mel_fname = path.join(dir_name, 'mels.npz')
        with np.load(mel_fname) as mel_archive:  # closes the .npz file handle
            mel = mel_archive['spec']
        mel = get_audio_segment(gt_fname, mel)

        if mel is None or mel.shape[1] != mel_step_size:
            continue

        if np.isnan(mel).any():
            continue

        # Ground-truth image and its masked copy.
        img_gt = _load_face(img_name)
        if img_gt is None:
            continue

        img_gt_masked = img_gt.copy()
        img_gt_masked[args.img_size//2:] = 0

        # Second frame from the same directory.
        ip_fname = choose_ip_frame(frames, gt_fname)
        img_ip = _load_face(ip_fname)
        if img_ip is None:
            continue

        return (img_gt, img_gt_masked, img_ip, mel)
 
def gen():
    """Endlessly yield batches of args.batch_size samples from fetch_data().

    Each yielded item is a 4-tuple of arrays stacked along a new leading
    axis: (ground-truth images, masked ground-truth images, input images,
    mel windows).
    """
    while True:
        samples = [fetch_data() for _ in range(args.batch_size)]

        # Regroup the per-sample tuples field by field and stack each field.
        img_gt_batch, img_gt_masked_batch, img_ip_batch, mel_batch = (
            np.stack([sample[field] for sample in samples])
            for field in range(4)
        )

        yield (img_gt_batch, img_gt_masked_batch, img_ip_batch, mel_batch)
        
# Wrap the batch generator as a tf.data pipeline; each of the four arrays in
# a batch is delivered as float32.
dataset = tf.data.Dataset.from_generator(
    gen,
    output_types=(tf.float32,) * 4,
)

Will be training on 1678240 images


In [81]:
# Pull a single batch to sanity-check the shape of the ground-truth images.
for step, data in enumerate(dataset.take(1)):
    print(data[0].shape)

(200, 96, 96, 3)
