In [11]:
%matplotlib inline

In [12]:
import math
import random
import skimage.io as io
import matplotlib.pyplot as plt
import keras
import pickle
import numpy as np
import cv2
import os
from os import listdir
from os.path import isfile, join
import sys
import dlib
# import skvideo.io
import json
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import glob
import re
from collections import defaultdict

import nltk
from nltk.corpus import cmudict
import tensorflow as tf

In [13]:
os.environ['CUDA_VISIBLE_DEVICES']='2'

# Prepare Data

## Phonemes

In [14]:
phoneme_list = [] 
phoneme_dict = {}

with open("/n/fs/scratch/jiaqis/cmudict-master/cmudict.phones", 'r') as fp:
    i = 0
    line = fp.readline()
    while line:
        phoneme = line.split()[0].strip()
        phoneme_property = line.split()[1].strip()
        phoneme_list.append((phoneme, phoneme_property))
        phoneme_dict[phoneme] = i
        line = fp.readline()
        i=i+1
        
phoneme_dict['START'] = 39
phoneme_dict["END"] = 40
phoneme_dict["BLANK"] = 41
print(phoneme_list, phoneme_dict)

([('AA', 'vowel'), ('AE', 'vowel'), ('AH', 'vowel'), ('AO', 'vowel'), ('AW', 'vowel'), ('AY', 'vowel'), ('B', 'stop'), ('CH', 'affricate'), ('D', 'stop'), ('DH', 'fricative'), ('EH', 'vowel'), ('ER', 'vowel'), ('EY', 'vowel'), ('F', 'fricative'), ('G', 'stop'), ('HH', 'aspirate'), ('IH', 'vowel'), ('IY', 'vowel'), ('JH', 'affricate'), ('K', 'stop'), ('L', 'liquid'), ('M', 'nasal'), ('N', 'nasal'), ('NG', 'nasal'), ('OW', 'vowel'), ('OY', 'vowel'), ('P', 'stop'), ('R', 'liquid'), ('S', 'fricative'), ('SH', 'fricative'), ('T', 'stop'), ('TH', 'fricative'), ('UH', 'vowel'), ('UW', 'vowel'), ('V', 'fricative'), ('W', 'semivowel'), ('Y', 'semivowel'), ('Z', 'fricative'), ('ZH', 'fricative')], {'IY': 17, 'START': 39, 'W': 35, 'DH': 9, 'Y': 36, 'HH': 15, 'CH': 7, 'JH': 18, 'ZH': 38, 'END': 40, 'EH': 10, 'NG': 23, 'TH': 31, 'BLANK': 41, 'AA': 0, 'B': 6, 'AE': 1, 'D': 8, 'G': 14, 'F': 13, 'AH': 2, 'K': 19, 'M': 21, 'L': 20, 'AO': 3, 'N': 22, 'IH': 16, 'S': 28, 'R': 27, 'EY': 12, 'T': 30, 'AW': 

In [16]:
pron_dict = cmudict.dict()

In [17]:
def clean_pron(pron):
    """Remove stress from pronunciations."""
    return re.sub(r"\d", "", pron)

def make_triphones(pron):
    """Output triphones from a word's pronunciation."""
    if len(pron) < 3:
        return []
    # Junk on end is to make word boundaries work
    return ([((pron[idx - 2], pron[idx - 1]), pron[idx])
             for idx in range(2, len(pron))] + [(('#', '#'), pron[0])] +
            [((pron[-2], pron[-1]), '#')])
                                                
def triphone_probs(prons):
    """Calculate triphone probabilities for pronunciations."""
    context_counts = defaultdict(lambda: defaultdict(int))
    for pron in prons:
        for (context, phoneme) in make_triphones(pron):
            context_counts[context][phoneme] += 1
            
    for (context, outcomes) in context_counts.items():
        total_outcomes = sum(outcomes.values())
        for outcome, count in outcomes.items():
            context_counts[context][outcome] = float(count) / total_outcomes
        
    return context_counts

## Video Volume and Facial Features

In [18]:
DATA_DIR = "/n/fs/scratch/jiaqis/LRS3-TED/"
SAVE_DIR = "/n/fs/scratch/jiaqis/LRS3-TED-Extracted/"

In [19]:
def get_dataset_list(dataDir, setName):
    # Images, facial/mouth features, text-> phonetic
    data_list = []
    for urlDir in glob.glob(os.path.join(dataDir, setName, "*/")):
        url = urlDir.split('/')[-2]
        for idFilename in glob.glob(os.path.join(urlDir, '*.txt')):
            index = idFilename.split('/')[-1].split('.')[0]
            filepath = os.path.join(dataDir, setName, url, index)
            
            text = open(filepath+".txt", 'r').readline()
            words = text[5:].lower().strip().split()
            flag = False
            for word in words:
                if word not in pron_dict:
                    flag=True
                    break
            if flag:
                continue
            imgfiles = sorted(glob.glob(filepath + "_*.jpg"))
            if len(imgfiles) > 100:
                continue
            
            ID = idFilename.split('/')[-1].split('.')[0]
            data_list.append((url, ID))
    return data_list

In [None]:
# test_ID_list = get_dataset_list(SAVE_DIR, "test")

In [None]:
# trainval_ID_list = get_dataset_list(SAVE_DIR, "trainval")

In [None]:
# json.dump(test_ID_list, open('test_ID_list.json', "w"))

In [None]:
# json.dump(trainval_ID_list, open('trainval_ID_list.json', "w"))

In [20]:
test_ID_list = json.load(open('test_ID_list.json', "r"))
trainval_ID_list = json.load(open('trainval_ID_list.json', "r"))

In [21]:
print(len(test_ID_list), len(trainval_ID_list))

(730, 3360)


# Data Loader

In [22]:
FPS = 25
FRAME_ROWS = 120
FRAME_COLS = 120
NFRAMES = 5 # size of input volume of frames
MARGIN = NFRAMES/2
COLORS = 1 # grayscale
CHANNELS = COLORS*NFRAMES
MAX_FRAMES_COUNT= 250 # corresponding to 10 seconds, 25Hz*10

EXAMPLE_FILEPATH = "/n/fs/scratch/jiaqis/LRS3-TED-Extracted/test/0Fi83BHQsMA/00002"

In [23]:
video_tensor_size = (100, 120, 120, 3) 
keypoint_img_size = (224, 224)
keypoint_size=20
label_seq_size=100
n_classes=39
num_tokens=n_classes+3

In [24]:
def prepare_data(filepath, img_size, keypoint_img_size, keypoint_size, label_seq_size):
    # images
    # frames x rows x cols x channels
    visual_cube = []
    # keypoint features
    feature_cube = []
    features = json.load(open(filepath + ".json", 'r'))
    # Target Text/phonemes
    labels = [phoneme_dict["START"]]
    text = open(filepath+".txt", 'r').readline()
    words = text[5:].lower().strip().split()
    for word in words:
        word_phonemes = pron_dict[word][0]
        word_indices = [phoneme_dict[clean_pron(phon)] for phon in word_phonemes]
        labels.extend(word_indices)
    labels.append(phoneme_dict["END"])        
    acc = 0
    for imgFilename in sorted(glob.glob(filepath + "_*_mouth.jpg")):
#         if 'mouth' in imgFilename:
#             continue
        x = image.img_to_array(
              image.load_img(imgFilename, target_size=img_size))/255.0
#         x = np.expand_dims(x, axis=0)
#         x = preprocess_input(x)
        visual_cube.append(x)
        
        mask = np.zeros((keypoint_img_size[0], keypoint_img_size[1], keypoint_size))
        framenum = str(int(imgFilename.split("_")[-2].split(".")[0]))
        f_feature = features[framenum]['mouthCoords']
        for ft_index in range(keypoint_size):
            # TODO: check range of outputs
            keypoint_x = f_feature[ft_index][0] - 1
            keypoint_y = f_feature[ft_index][1] - 1
            mask[keypoint_y, keypoint_x, ft_index] = 1.0
        feature_cube.append(mask)
        acc+=1
    return np.array(visual_cube), np.array(feature_cube), np.array(labels)

In [25]:
visual_cube, feature_cube, labels = prepare_data(EXAMPLE_FILEPATH, (120, 120, 3), (224, 224), 20, 100)

In [None]:
print(visual_cube[:, :, :, :].shape)
print(feature_cube.shape)
print(labels)

In [None]:
# ## TODO: Not working
# import cv2
# cv2.imshow( "Display window", visual_cube[0, :, :, :])

In [51]:
import numpy as np
import keras

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, data_dir, subset, list_IDs, prons, phonemes,
                       video_tensor_size=(200, 224, 224, 3), 
                       keypoint_img_size = (224, 224),
                       keypoint_size=20, label_seq_size=90, 
                       batch_size=32,
                       n_classes=39, num_tokens=42, shuffle=True):
        'Initialization'
        self.data_dir = data_dir
        self.subset = subset
        self.video_tensor_size = video_tensor_size
        self.img_size = (video_tensor_size[1], video_tensor_size[2], video_tensor_size[3])
        self.keypoint_img_size = keypoint_img_size
        self.keypoint_size =keypoint_size
        self.label_seq_size = label_seq_size
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.prons = prons
        self.phonemes = phonemes
        self.n_classes = n_classes
        self.num_tokens = num_tokens
        self.shuffle = shuffle
        self.on_epoch_end()
        
    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.ceil(float(len(self.list_IDs)) / float(self.batch_size)))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        list_IDs_temp = [self.list_IDs[k] for k in indexes]
        # Generate data
        X, Y = self.data_generation(list_IDs_temp)

        return X, Y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle == True:
            np.random.shuffle(self.indexes)

    def data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        v_ID = list_IDs_temp[0]

        v_url, v_index = v_ID
        filepath = os.path.join(self.data_dir, self.subset, v_url, v_index)
        v_V, v_F, v_T = prepare_data(filepath, self.img_size, self.keypoint_img_size,
                                       self.keypoint_size, self.label_seq_size)
        num_frames = v_V.shape[0]
        
        T_LEN = np.ones((1, 1))
        T_LEN[0, 0] = len(v_T)-1
        
#         decoder_T = np.zeros((1, len(v_T)-1, self.num_tokens))
#         for i in range(len(v_T)-1):
#             decoder_T[0, i, v_T[i]] = 1.0
        
        return [v_V[np.newaxis,], v_F[np.newaxis,], v_T[np.newaxis, :-1], v_T[np.newaxis,1:], T_LEN], \
                [np.zeros_like(T_LEN), v_T[np.newaxis, 1:, np.newaxis]]

In [52]:
train_generator =  DataGenerator(SAVE_DIR, 'trainval', trainval_ID_list, pron_dict, phoneme_dict,
                       video_tensor_size=video_tensor_size, 
                       keypoint_img_size=keypoint_img_size,
                       keypoint_size=keypoint_size, 
                       label_seq_size=label_seq_size, batch_size=1,
                       n_classes=n_classes, shuffle=True)

val_generator = DataGenerator(SAVE_DIR, 'test', test_ID_list, pron_dict, phoneme_dict,
                       video_tensor_size=video_tensor_size, 
                       keypoint_img_size=keypoint_img_size,
                       keypoint_size=keypoint_size, 
                       label_seq_size=label_seq_size, batch_size=1,
                       n_classes=n_classes, shuffle=True)

In [46]:
try_inputs, try_output = train_generator.data_generation(train_generator.list_IDs[:10])

In [None]:
try_V = try_inputs[0]
try_F = try_inputs[1]
try_dT = try_inputs[2]
try_T = try_inputs[3]
try_T_LEN = try_inputs[4]

In [None]:
print(try_V.shape, try_F.shape, try_dT.shape, try_T.shape, try_T_LEN.shape)

# Model

In [47]:
import keras
from keras.models import Model
from keras.layers import Input, Dense, Dropout, BatchNormalization,ZeroPadding2D, Embedding, LSTM, Bidirectional, Add, Multiply, Activation, Masking, Concatenate
from keras.layers import TimeDistributed, GlobalAveragePooling3D, Conv2D, Flatten, Permute, RepeatVector, Lambda, GlobalAveragePooling2D, MaxPooling2D
from keras.layers import Dot
import seq2seq
from seq2seq.models import AttentionSeq2Seq, Seq2Seq, SimpleSeq2Seq

In [48]:
def attention_block(inputs):
    # inputs.shape = (batch_size, time_steps, input_dim)
    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)
    a = Dense(video_tensor_size[0], activation='softmax')(a)
    a = Lambda(lambda x: keras.backend.mean(x, axis=1), name='dim_reduction')(a)
    a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = Multiply(name='attention_mul')([inputs, a_probs])
    return output_attention_mul

In [53]:
def visual_conv_net(inputs, maxpool=True):
    # 224 x 224 x 64
    conv1 = TimeDistributed(Conv2D(64, kernel_size=(3,3), padding='same', 
                                   activation="relu"))(inputs)
    # 112 x 112 x 64
    if maxpool:
        down1 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv1)
    else:
        down1 = TimeDistributed(Conv2D(64, kernel_size=(2,2), 
                                      strides=(2,2), 
                                      padding='same', 
                                      activation=None))(conv1)
    # 112 x 112 x 128
    conv2 = TimeDistributed(Conv2D(128, (3,3), padding='same', activation="relu"))(down1)
    # 56 x 56 x 128
    if maxpool:
        down2 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv2)
    else:
        down2 = TimeDistributed(Conv2D(128, kernel_size=(2,2), 
                                      strides=(2,2), 
                                      padding='same', 
                                      activation=None))(conv2)
    # 56 x 56 x 256
    conv3 = TimeDistributed(Conv2D(256, (3,3), padding='same', activation="relu"))(down2)
    # 28 x 28 x 256
    if maxpool:
        down3 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv3)
    else:
        down3 = TimeDistributed(Conv2D(256, kernel_size=(2, 2), 
                                      strides=(2,2), 
                                      padding='same', 
                                      activation=None))(conv3)
    # 28 x 28 x 256
    conv4 = TimeDistributed(Conv2D(256, (3,3), padding='same', activation="relu"))(down3)
    # 14 x 14 x 256
    if maxpool:
        down4 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv4)
    else:
        down4 = TimeDistributed(Conv2D(256, kernel_size=(2,2), 
                                      strides=(2,2), 
                                      padding='same', 
                                      activation=None))(conv4)
    # 14 x 14 x 256
    conv5 = TimeDistributed(Conv2D(256, (3,3), padding='same', activation="relu"))(down4)
    # 7 x 7 x 256
    if maxpool:
        down5 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(conv5)
    else:
        down5 = TimeDistributed(Conv2D(256, kernel_size=(2,2), 
                                      strides=(2,2), 
                                      padding='same', 
                                      activation=None))(conv5)
    return down5

In [54]:
# def ctc_lambda_func(args):
#     y_pred, labels, input_length, label_length = args
#     # the 2 is critical here since the first couple outputs of the RNN
#     # tend to be garbage:
#     y_pred = y_pred[:, 2:, :]
#     return keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)

def ctc_lambda_func(args):
    base_output, labels, label_length = args 
    base_output_shape = tf.shape(base_output)
    sequence_length = tf.fill([base_output_shape[0], 1], base_output_shape[1])
    print(labels)
    print(base_output)
    print(sequence_length)
    print(label_length)
    
    return keras.backend.ctc_batch_cost(labels, base_output, sequence_length, label_length)

In [55]:
def masked_crossentropy_func(args):  
    target, output = args
    print(target)
#     target_dense = keras.utils.to_categorical(
#                         target,
#                         num_classes=n_classes+1
#                     )
#     # Compute cross entropy for each frame.
#     cross_entropy = target_dense * tf.log(output)
#     cross_entropy = -tf.reduce_sum(cross_entropy, 2)
    cross_entropy = keras.backend.sparse_categorical_crossentropy(
                        target,
                        output,
                        from_logits=False,
                        axis=-1
                    )
    print(cross_entropy)
    mask = tf.cast(target < n_classes, dtype=tf.float32)
    print(mask)
    cross_entropy *= mask
    # Average over actual sequence lengths.
    cross_entropy = tf.reduce_sum(cross_entropy, 1)
    cross_entropy /= tf.reduce_sum(mask, 1)
    return cross_entropy

In [79]:
# Define an input sequence and process it.
input_V_tensor = Input(shape=(None, 
                              video_tensor_size[1], 
                              video_tensor_size[2], 
                              video_tensor_size[3]), name="V")
input_F_tensor = Input(shape=(None, 
                              keypoint_img_size[0], 
                              keypoint_img_size[1], 
                              keypoint_size), name="F")

labels = Input(shape=(None,), name="labels")

label_length = Input(shape=(1,), name="label_length")

# 224 x 224 x 23
# input_tensor = Concatenate(axis=-1)([input_V_tensor, input_F_tensor])
input_tensor = input_V_tensor

conv_output_tensor = visual_conv_net(input_tensor)

# fc_out = TimeDistributed(GlobalAveragePooling2D())(conv_output_tensor)
fc_in = TimeDistributed(Flatten())(conv_output_tensor)
fc_out = TimeDistributed(Dense(256, activation="relu"))(fc_in)

print(fc_out)

encoder, encoder_state_h, encoder_state_c = LSTM(256, return_sequences=True, 
                                                 return_state=True)(fc_out)
encoder_last = encoder[:,-1,:]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))

decode_input_embedding = Embedding(num_tokens+1, 256, mask_zero=True)
decoder = decode_input_embedding(decoder_inputs)
decoder_lstm = LSTM(256, return_sequences=True, return_state=True)
decoder, _, _ =decoder_lstm(decoder, initial_state=[encoder_state_h, encoder_state_c])
# Equation (7) with 'dot' score from Section 3.1 in the paper. 
# Note that we reuse Softmax-activation layer instead of writing tensor calculation
attention = Dot(axes=[2, 2])([decoder, encoder])
attention = Activation('softmax')(attention)

context = Dot(axes=[2,1])([attention, encoder])
decoder_combined_context = Concatenate()([context, decoder])

# Has another weight + tanh layer as described in equation (5) of the paper
decode_dense =  TimeDistributed(Dense(256, activation="tanh"))
decoded = decode_dense(decoder_combined_context)
print(decode_dense, decoded)
decode_output_dense =  Dense(num_tokens, activation='softmax', name='output_sequence')
decoder_outputs = decode_output_dense(decoded)

# decoder_outputs = TimeDistributed(Dense(n_classes+1, activation='softmax'), name='output_sequence')(decoded)

loss = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([decoder_outputs, labels, label_length])

ce_loss = Lambda(masked_crossentropy_func, output_shape=(1,), name='masked_ce')([labels, decoder_outputs])

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([input_V_tensor, input_F_tensor, decoder_inputs, labels, label_length], [ce_loss, decoder_outputs])
print(model.summary())

Tensor("time_distributed_139/Reshape_1:0", shape=(?, ?, 256), dtype=float32)
(<keras.layers.wrappers.TimeDistributed object at 0x7f2b3fa44a10>, <tf.Tensor 'time_distributed_140/Reshape_1:0' shape=(?, ?, 256) dtype=float32>)
Tensor("labels_10:0", shape=(?, ?), dtype=float32)
Tensor("output_sequence_9/truediv:0", shape=(?, ?, 42), dtype=float32)
Tensor("ctc_6/Fill:0", shape=(?, 1), dtype=int32)
Tensor("label_length_10:0", shape=(?, 1), dtype=float32)
Tensor("labels_10:0", shape=(?, ?), dtype=float32)
Tensor("masked_ce_6/Reshape_2:0", shape=(?, ?), dtype=float32)
Tensor("masked_ce_6/Cast_1:0", shape=(?, ?), dtype=float32)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
V (InputLayer)                  (None, None, 120, 12 0                                            
____________________________________________________________________________

In [57]:
# ##################
# # Baseline Model #
# ##################
# input_V_tensor = Input(shape=video_tensor_size, name="V")
# input_F_tensor = Input(shape=(video_tensor_size[0], 
#                               keypoint_img_size[0], 
#                               keypoint_img_size[1], 
#                               keypoint_size), name="F")

# labels = Input(shape=(label_seq_size,), name="labels")
# label_length = Input(shape=(1,), name="label_length")

# # 224 x 224 x 23
# # input_tensor = Concatenate(axis=-1)([input_V_tensor, input_F_tensor])
# input_tensor = input_V_tensor

# conv_output_tensor = visual_conv_net(input_tensor)

# # fc_out = TimeDistributed(GlobalAveragePooling2D())(conv_output_tensor)
# fc_in = TimeDistributed(Flatten())(conv_output_tensor)
# fc_out = TimeDistributed(Dense(256, activation="relu"))(fc_in)

# print(conv_output_tensor)
# print(fc_out)

# # att_seq2seq = AttentionSeq2Seq(input_dim=128, input_length=video_tensor_size[0], 
# #                          hidden_dim=128, 
# #                          output_length=label_seq_size, 
# #                          output_dim=n_classes+1,
# #                          depth=2)
# # decoded = att_seq2seq(fc_out)

# # # LSTM Encoder
# # encoder_lstm = LSTM(256, return_sequences=True, return_state=True)
# # encoder_outputs, state_h, state_c = encoder_lstm(fc_out)
# # encoder_states = [state_h, state_c]

# # # Sequence Placeholder
# # # decoder_inputs = Input(shape=(None, n_classes+2))

# # # LSTM Decoder
# # decoder_lstm = LSTM(256, return_sequences=True, return_state=True)
# # decoded, _, _ = decoder_lstm(encoder_outputs, initial_state=encoder_states)

# fc_out = Masking(mask_value=0.0)(fc_out)

# seq2seq = SimpleSeq2Seq(output_dim=n_classes+1, output_length=label_seq_size,
#             input_dim=256, input_length=video_tensor_size[0],
#             hidden_dim=256, depth=2, unroll=False,
#             stateful=False, dropout=0.3)

# decoded = seq2seq(fc_out)

# decoder_outputs = TimeDistributed(Dense(n_classes+1, activation='softmax'), name='output_sequence')(decoded)
# # decoder_outputs = Lambda(lambda x:x, name='output_sequence')(decoded)

# loss = Lambda(ctc_lambda_func, output_shape=(1,), name='ctc')([decoder_outputs, labels, label_length])

# ce_loss = Lambda(masked_crossentropy_func, output_shape=(1,), name='masked_ce')([labels, decoder_outputs])

# model = Model(inputs=[input_V_tensor, input_F_tensor, labels, label_length], outputs=[ce_loss, decoder_outputs])

# print(model.summary())

## Setup Checkpoint

In [58]:
path = "./sessions/altseq2srq-mouthtrue-ce"
checkpoints_path = os.path.join(path, 'checkpoints')
history_filename = 'history_' + path[path.rindex('/') + 1:] + '.csv'
early_stopping_patience = 10

if not os.path.exists("./sessions"):
    os.mkdir("./sessions")

if not os.path.exists(path):
    os.mkdir(path)

if not os.path.exists(checkpoints_path):
    os.mkdir(checkpoints_path)

In [80]:
model.load_weights('./sessions/altseq2srq-mouthtrue-ce/checkpoints/checkpoint.00052-1.939.hdf5')

In [None]:
model.load_weights('./sessions/lstm/checkpoints/checkpoint.00002-55.906.hdf5')

In [59]:
def loss_func(y_true, y_pred):
    return y_pred

def dummy_loss_func(y_true, y_pred):
    return tf.fill([tf.shape(y_true)[0], 1], 0.0)

In [None]:
# def ctc_loss(y_true, y_pred):
#     y_true_sparse = tf.argmax(y_true, axis=-1)
#     y_true_valid = tf.cast(y_true_sparse < n_classes, tf.float32)
#     print(y_true_valid)
#     print(y_pred)
#     label_length = tf.reduce_sum(y_true_valid, axis=-1, keepdims=True)
#     print(label_length)
#     input_length = label_seq_size * tf.ones((tf.shape(y_pred)[0], 1))
#     print(input_length)
#     loss_tensor = tf.keras.backend.ctc_batch_cost(
#                         y_true_sparse,
#                         y_pred,
#                         input_length,
#                         label_length
#                     )
#     return tf.reduce_mean(loss_tensor)

# def ctc_loss_2(y_true, y_pred):
#     sparse = tf.contrib.layers.dense_to_sparse(y_true)
#     input_length = label_seq_size * tf.ones((tf.shape(y_pred)[0], 1))
#     y_true_valid = tf.cast(y_true_sparse<n_classes, tf.float32)
#     label_length = tf.reduce_sum(y_true_valid, axis=-1, keepdims=True)
#     loss_tensor = tf.nn.ctc_loss(
#             sparse,
#             y_pred,
#             label_length,
#             preprocess_collapse_repeated=False,
#             ctc_merge_repeated=True,
#             ignore_longer_outputs_than_inputs=True,
#             time_major=False
#         )
#     return tf.reduce_mean(loss_tensor)

In [42]:
def test_edit_distance(truth, hyp):
    truth = tf.reshape(truth,(tf.shape(truth)[0], tf.shape(truth)[1]))
    truth = tf.cast(truth, dtype=tf.int64)
    truth_idx = tf.where(tf.not_equal(truth, num_tokens))
    # Use tf.shape(a_t, out_type=tf.int64) instead of a_t.get_shape() if tensor shape is dynamic
    truth_sparse = tf.SparseTensor(truth_idx, tf.gather_nd(truth, truth_idx), tf.shape(truth, out_type=tf.int64))
    
    hyp_dense = tf.argmax(hyp, axis=-1)
    hyp_idx = tf.where(tf.not_equal(hyp_dense, num_tokens))
    # Use tf.shape(a_t, out_type=tf.int64) instead of a_t.get_shape() if tensor shape is dynamic
    hyp_sparse = tf.SparseTensor(hyp_idx, tf.gather_nd(hyp_dense, hyp_idx), tf.shape(hyp_dense, out_type=tf.int64))
#     tf.contrib.layers.dense_to_sparse(
#     tensor,
#     eos_token=0,
#     outputs_collections=None,
#     scope=None
# )
    print(truth, hyp_dense)
    editDist = tf.edit_distance(hyp_sparse, truth_sparse, normalize=True)
    return editDist

In [43]:
def sparse_accuracy(y_true, y_pred):
    y_pred_sparse = tf.argmax(y_pred, axis=-1)
    return tf.reduce_mean(tf.cast(y_pred_sparse==y_true, tf.float32))

In [60]:
from keras.optimizers import Adam, SGD, RMSprop
# model.compile(optimizer=Adam(lr=0.001), 
#               loss='categorical_crossentropy', metrics=[ctc_loss])

model.compile(optimizer=RMSprop(lr=0.001, clipnorm=200), loss={'masked_ce': loss_func, 
                                                               'output_sequence':'sparse_categorical_crossentropy'},
                                          loss_weights={'masked_ce': 0.0, 'output_sequence':1.0},
                                          metrics={'output_sequence':sparse_accuracy, 'output_sequence':test_edit_distance})

(<tf.Tensor 'metrics/test_edit_distance/Cast:0' shape=(?, ?) dtype=int64>, <tf.Tensor 'metrics/test_edit_distance/ArgMax:0' shape=(?, ?) dtype=int64>)


In [61]:
def get_callbacks():
    return [
        keras.callbacks.ReduceLROnPlateau(patience=early_stopping_patience / 2,
                                              cooldown=early_stopping_patience / 4,
                                              verbose=1),
        keras.callbacks.EarlyStopping(patience=early_stopping_patience, verbose=1,
                                          monitor='val_loss'),
        keras.callbacks.ModelCheckpoint(os.path.join(checkpoints_path, 'checkpoint.{epoch:05d}-{val_loss:.3f}.hdf5')),
        keras.callbacks.CSVLogger(os.path.join(path, history_filename), append=True)
    ]

In [62]:
model.fit_generator(train_generator, 
          epochs=100, 
          verbose=1, 
          callbacks=get_callbacks(), 
          validation_data=val_generator, 
          shuffle=True, 
          initial_epoch=0, 
          steps_per_epoch=400, 
          validation_steps=100)

Epoch 1/100


  '. They will not be included '


Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100


Epoch 28/100
Epoch 29/100
Epoch 30/100

Epoch 00030: ReduceLROnPlateau reducing learning rate to 0.00010000000475.
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100

Epoch 00037: ReduceLROnPlateau reducing learning rate to 1.0000000475e-05.
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100

Epoch 00047: ReduceLROnPlateau reducing learning rate to 1.00000006569e-06.
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100


Epoch 00052: early stopping


<keras.callbacks.History at 0x7f2b212eda90>

## Inference

In [81]:
# Next: inference mode (sampling).
# Here's the drill:
# 1) encode input and retrieve initial decoder state
# 2) run one step of decoder with this initial state
# and a "start of sequence" token as target.
# Output will be the next target token
# 3) Repeat with the current target token and current states

# Define sampling models
latent_dim = 256

encoder_model = Model([input_V_tensor, input_F_tensor, decoder_inputs, labels, label_length], 
                      [encoder, encoder_state_h, encoder_state_c])

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]


embed_decoder_inputs = decode_input_embedding(decoder_inputs)
decoder_outputs, state_h, state_c = decoder_lstm(
    embed_decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]

attention = Dot(axes=[2, 2])([decoder_outputs, encoder])
attention = Activation('softmax')(attention)

context = Dot(axes=[2, 1])([attention, encoder])
decoder_combined_context = Concatenate()([context, decoder_outputs])

print(decoder_combined_context, context)
# Has another weight + tanh layer as described in equation (5) of the paper
decoded = decode_dense(decoder_combined_context) # equation (5) of the paper
decoder_outputs = decode_output_dense(decoded)

decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

# Reverse-lookup token index to decode sequences back to
# something readable.
reverse_input_char_index = dict(
    (index, key) for key, index in phoneme_dict.iteritems())
reverse_target_char_index = dict(
    (index, key) for key, index in phoneme_dict.iteritems())

def decode_sequence(input_seq):
    # Encode the input as state vectors.
    states_value = encoder_model.predict(input_seq)

    # Generate empty target sequence of length 1.
    target_seq = np.zeros((1, 1, num_tokens))
    # Populate the first character of target sequence with the start character.
    target_seq[0, 0, phoneme_dict['START']] = 1.

    # Sampling loop for a batch of sequences
    # (to simplify, here we assume a batch of size 1).
    stop_condition = False
    decoded_sentence = ''
    while not stop_condition:
        output_tokens, h, c = decoder_model.predict(
            [target_seq] + states_value)

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]
        decoded_sentence += sampled_char

        # Exit condition: either hit max length
        # or find stop character.
        if (sampled_char == 'END' or
           len(decoded_sentence) > label_seq_size):
            stop_condition = True

        # Update the target sequence (of length 1).
        target_seq = np.zeros((1, 1, num_tokens))
        target_seq[0, 0, sampled_token_index] = 1.

        # Update states
        states_value = [h, c]

    return decoded_sentence


for seq_index in range(10):
    # Take one sequence (part of the training set)
    # for trying out decoding.
    test_inputs, test_output = val_generator.data_generation(val_generator.list_IDs[seq_index:seq_index+1])
    decoded_sentence = decode_sequence(test_inputs)
    print('-')
    print('Input sentence:', [reverse_input_char_index[ph_id] for ph_id in list(test_inputs[3][0])])
    print('Decoded sentence:', decoded_sentence)

(<tf.Tensor 'concatenate_16/concat:0' shape=(?, ?, 512) dtype=float32>, <tf.Tensor 'dot_32/MatMul:0' shape=(?, ?, 256) dtype=float32>)


ValueError: Graph disconnected: cannot obtain value for tensor Tensor("V_10:0", shape=(?, ?, 120, 120, 3), dtype=float32) at layer "V". The following previous layers were accessed without issue: []

In [None]:
def generate(encoder_input):
    encoder_input = transform(input_encoding, [text.lower()], INPUT_LENGTH)
    decoder_input = np.zeros(shape=(len(encoder_input), OUTPUT_LENGTH))
    decoder_input[:,0] = START_CHAR_CODE
    for i in range(1, OUTPUT_LENGTH):
        output = model.predict([encoder_input, decoder_input]).argmax(axis=2)
        decoder_input[:,i] = output[:,i]
    return decoder_input[:,1:]

def decode(decoding, sequence):
    text = ''
    for i in sequence:
        if i == 0:
            break
        text += output_decoding[i]
    return text

def to_katakana(text):
    decoder_output = generate(text)
    return decode(output_decoding, decoder_output[0])

...
to_katakana('Banana')           # バナナ
to_katakana('Peter Parker')     # ピーター・パーカー
to_katakana('Jon Snow')         # ジョン・スノー