In [1]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

In [2]:
import tensorflow as tf
tf.enable_eager_execution()
import numpy as np

## VSPNet layers

In [3]:
from modules.lipreading import VSPNet

Using TensorFlow backend.


In [4]:
from modules.metrics import vspnet_loss as loss_function

## Prepare data

In [5]:
import numpy as np
from glob import glob
import h5py
import time
from os import path
from keras.preprocessing.text import Tokenizer 
from keras.preprocessing.sequence import pad_sequences
from modules.framestream import VisemeStream
from modules.generators import GeneratorInterface
from modules.utils import parse_config, get_sample_ids, load_tokenizer, Log

In [6]:
def build_digits_tokenizer():
    tokenizer = Tokenizer(oov_token="<unk>", filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ ')
    sp_tokens = ['<pad>','<start>', '<end>', '<unk>']
    digits = list('0123456789')
    tokenizer.fit_on_texts(sp_tokens + digits)
    return tokenizer

In [7]:
tokenizer = build_digits_tokenizer()
with open('../config/vsp-digits-tokenizer.json', 'w') as f:
    print(tokenizer.to_json(), file=f)

In [30]:
'''
    Creates batches from raw videos and transcripts on-the-fly
    VisemeStream is responsible for viseme extraction
'''
class OnlineDigitData(GeneratorInterface):

    def __init__(self, config, max_label_len=4):
        GeneratorInterface.__init__(self)
        self.config = config
        self.inputShape = (config['frame_length'], config['frame_height'], config['frame_width'], 3)
        self.seqLen = max_label_len
        '''
            videoDir:
            eg. video_list[0]: /home/sziraqui/vsp-dev/datasets/GRID/videos/s1/bbac9n.mpg
            Then videoDir: /home/sziraqui/vsp-dev/datasets/GRID/videos

            Likewise for textDir
            eg. text_list[0]: /home/sziraqui/vsp-dev/datasets/GRID/transcripts/s1/bbac9n.align
            Then textDir: /home/sziraqui/vsp-dev/datasets/GRID/transcripts
        '''
        self.videoDir = path.dirname(path.dirname(config['video_list'][0]))
        self.videoExt = path.basename(config['video_list'][0]).split('.')[-1] # eg 'mpg'
        
        self.cache = config['cache_dir']
        
        self.ids = get_sample_ids(config['video_list'])
        np.random.shuffle(self.ids)

        self.sampleIndex = 0 # current sample index in self.ids
        self.tokenizer = load_tokenizer(config['tokenizer'])
        self.vs = VisemeStream(params=config)
        

    def next_data(self):
        
        while True:
            try:
                print('Processing file:', self.ids[self.sampleIndex])
                visemes = self.load_from_cache(self.ids[self.sampleIndex])
                if visemes is None:
                    visemes = self.load_from_disk(self.ids[self.sampleIndex ])
                    self.save_to_cache(self.ids[self.sampleIndex], visemes)
                    
                word = self.load_transcript(self.ids[self.sampleIndex])
                seq = self.tokenizer.texts_to_sequences([word])[0]
                if visemes[:25].shape[0] == 25:
                    yield visemes[:25], seq
            except KeyboardInterrupt as e:
                Log.error(repr(e))
                sys.exit(1)
            except Exception as e:
                Log.error(repr(e) + ' Error processing sample ' + self.ids[self.sampleIndex])
                continue
            self.sampleIndex = (self.sampleIndex + 1)%len(self.ids)
        
    
    def next_batch(self, batchSize):
        while True:
            X = np.zeros((batchSize,) + (self.inputShape))
            sequences = [None]*batachSize
            
            i = 0 # current batch index
            for visemes, seq in self.next_data():
                X[i] = visemes
                sequences[i] = seq
                i+=1
                if i > batchSize:
                    break
            Y = pad_sequences(sequences, value=self.tokenizer.word_index['<pad>'], maxlen=self.seqLen, 
                            dtype='uint8', padding='post', truncating='post')
            X /= 255
            
            yield X, Y
            del X
            del Y

   
    def load_from_cache(self, id):
        try:
            with h5py.File(path.join(self.cache, id + '.h5')) as f:
                print('Found in cache')
                viseme = f["features"][:]
                if viseme.shape[1:] != self.inputShape[1:]:
                    print('Incorrect shape {} in cached data. Skipping'.format(viseme.shape))
                    return None
                else:
                    print('Using cached data')
                    return viseme
        except (OSError, KeyError):
            return None
    

    def load_from_disk(self, id):
        print('Building viseme')
        self.vs.set_source(path.join(self.videoDir, id + '.' + self.videoExt))
        visemes = []
        frame = self.vs.next_frame()
        while frame is not None:
            visemes.append(frame)
            frame = self.vs.next_frame()
        return np.array(visemes)

    
    def load_transcript(self, id):
        digit = id.split('_')[-1].split('.')[0]
        return ['<start>'] + [digit] + ['<end>']
    
    
    def save_to_cache(self, id, visemes):
        filename = path.join(self.cache, id + '.h5')
        os.makedirs(path.dirname(filename), exist_ok=True)
        
        with h5py.File(filename, 'w') as f:
            f.create_dataset("features", data=visemes, dtype='uint8', compression="gzip", compression_opts=4)
            return True
        return False

In [35]:
config = parse_config('../config/config-train-vspnet-digits.json')
tokenizer = build_digits_tokenizer()
gen = OnlineDigitData(config)

In [36]:
dataset = tf.data.Dataset.from_generator(gen.next_data, (tf.float64, tf.int64))

## Make model

In [37]:
BUFFER_SIZE = 2
BATCH_SIZE = config['batch_size']
N_BATCH = BUFFER_SIZE//BATCH_SIZE
embedding_dim = 32
units = 256
vocab_size = len(tokenizer.word_index)
max_length_targ= gen.seqLen-2

In [38]:
from modules.layers import Encoder, Decoder

In [39]:
encoder = Encoder(units, BATCH_SIZE)
decoder = Decoder(vocab_size, embedding_dim, units, BATCH_SIZE)
optimizer = tf.train.AdamOptimizer()

In [40]:
checkpoint_dir = '../weights'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [41]:
# adding this in a separate cell because if you run the training cell 
# many times, the loss_plot array will be reset
loss_plot = []
train_accuracy_results = []

## Main training loop

In [42]:
EPOCHS = config['epochs']
dataset = tf.data.Dataset.batch(dataset, batch_size=BATCH_SIZE)
for epoch in range(EPOCHS):
    start = time.time()
    hidden = encoder.initialize_hidden_state()
    total_loss = 0
    epoch_accuracy = tf.contrib.eager.metrics.Accuracy()#accuracy #Change here for without Nightly
    
    for (batch, (img_tensor, target)) in enumerate(dataset):
        loss = 0
        
        pred_list=[]#accuracy
        with tf.GradientTape() as tape:
            enc_output, enc_hidden = encoder(img_tensor, hidden)
            
            dec_hidden = enc_hidden
            
            dec_input = tf.expand_dims([tokenizer.word_index['<start>']] * BATCH_SIZE, 1)       
            
            # Teacher forcing - feeding the target as the next input
            for t in range(1, target.shape[1]):
                # passing enc_output to the decoder
                predictions, dec_hidden, _ = decoder(dec_input, dec_hidden, enc_output)
                pred_list.append(tf.argmax(predictions, axis=1, output_type=tf.int32))#accuracy
                
                loss += loss_function(target[:, t], predictions)
                
                # using teacher forcing
                dec_input = tf.expand_dims(target[:, t], 1)
        
        total_loss += (loss / int(target.shape[1]))
        
        variables = encoder.variables + decoder.variables
        
        gradients = tape.gradient(loss, variables) 
        
        optimizer.apply_gradients(zip(gradients, variables))#, tf.train.get_or_create_global_step()

        epoch_accuracy(np.asarray(pred_list).T, target[:,1:]) #accuracy
        
        if batch % 100 == 0:
            print ('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, 
                                                          batch, 
                                                          loss.numpy() / int(target.shape[1])))
    # storing the epoch end loss value to plot later
    loss_plot.append(total_loss / N_BATCH)
    train_accuracy_results.append(epoch_accuracy.result())
    
    # saving (checkpoint) the model every 2 epochs
    #if (epoch + 1) % 5 == 0:
      #checkpoint.save(file_prefix = checkpoint_prefix)
    
    print ('Epoch {} Loss {:.6f}, Accuracy: {:.3%}'.format(epoch + 1, 
                                                           total_loss/N_BATCH,
                                                           epoch_accuracy.result()))#accuracy
    print ('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Processing file: VSP_DIGITS/Shriya_3
Found in cache
Using cached data
Processing file: VSP_DIGITS/Shriya_0
Found in cache
Using cached data
Processing file: VSP_DIGITS/tarang_7
Found in cache
Incorrect shape (0,) in cached data. Skipping
Building viseme


InternalError: Could not find valid device for node.
Node: {{node MaxPool3D}} = MaxPool3D[T=DT_DOUBLE, data_format="NDHWC", ksize=[1, 1, 2, 2, 1], padding="VALID", strides=[1, 1, 2, 2, 1]](dummy_input)
All kernels registered for op MaxPool3D :
  device='XLA_CPU'; T in [DT_FLOAT, DT_HALF]
  device='CPU'; T in [DT_FLOAT]
  device='XLA_CPU_JIT'; T in [DT_FLOAT, DT_HALF]
 [Op:MaxPool3D]

In [34]:
batchData = tf.data.Dataset.batch(dataset, batch_size=BATCH_SIZE)
count = 0
for (batch, (img_tensor, target)) in enumerate(batchData):
    print('batch', batch)
    print('img_tensor', img_tensor.shape)
    print('target', target)
    break

Processing file: VSP_DIGITS/Shriya_5
Found in cache
Using cached data
Processing file: VSP_DIGITS/akarshan_9
Found in cache
Using cached data
Processing file: VSP_DIGITS/sarfaraz_9
Found in cache
Incorrect shape (0,) in cached data. Skipping
Building viseme
batch 0
img_tensor (2, 25, 50, 100, 3)
target tf.Tensor(
[[ 3 11  4]
 [ 3 15  4]], shape=(2, 3), dtype=int64)
