In [60]:
import os
import functools
import operator
import os
import cv2
import time

import joblib
import numpy as np
from keras.layers import Input, LSTM, Dense
from keras.models import Model, load_model
# import extract_features

# Project root: the parent of the notebook's working directory.
main_path = os.path.normpath(os.sep.join((os.getcwd(), os.pardir)))

In [61]:
class VideoDescriptionRealTime(object):
    """Caption pre-extracted video feature arrays with a trained seq2seq LSTM model.

    ``load_inference_models`` restores the tokenizer, the inference encoder and the
    inference decoder from ``<main_path>/<model_dir>``. ``test`` then captions every
    video listed in ``<main_path>/testing_id.txt`` with either greedy decoding
    (``search_type == 'greedy'``) or the exhaustive top-k search in ``beam_search``.
    """

    # Maximum number of decoder steps taken by greedy_search.
    GREEDY_MAX_LEN = 15
    # Number of highest-probability tokens expanded at every beam_search step.
    BEAM_WIDTH = 2
    # beam_search keeps extending a path only while it holds at most this many
    # tokens, so returned paths contain at most BEAM_MAX_LEN + 1 tokens.
    BEAM_MAX_LEN = 10

    def __init__(self):
        """Initialize the parameters for the model."""
        self.latent_dim = 512
        self.num_encoder_tokens = 4096
        self.num_decoder_tokens = 1500
        self.time_steps_encoder = 80
        self.time_steps_decoder = None
        self.preload = True
        self.preload_data_path = 'preload_data'
        # best path probability seen by beam_search; -1 means "nothing yet"
        self.max_probability = -1
        self.search_type = 'greedy'

        # processed data
        self.encoder_input_data = []
        self.decoder_input_data = []
        self.decoder_target_data = []
        self.tokenizer = None

        # models
        self.encoder_model = None
        self.decoder_model = None
        self.inf_encoder_model = None
        self.inf_decoder_model = None
        self.save_model_path = 'model_final'
        # directory (relative to main_path) holding the pre-extracted .npy features
        self.test_path = 'testing_data'
        # directory (relative to main_path) holding tokenizer, models and test output
        self.model_dir = 'shreya_model'
        # word list chosen by the most recent beam search
        self.decode_seq = []

    def load_inference_models(self):
        """Load the tokenizer and the inference encoder/decoder from disk.

        Reads ``tokenizer1501``, ``encoder_model.h5`` and ``decoder_model_weights.h5``
        from ``<main_path>/<model_dir>``. Only the decoder's weights are stored, so its
        graph is rebuilt here. The layer definitions (and their creation order) must
        match the trained decoder.
        """
        model_dir = os.path.join(main_path, self.model_dir)

        # NOTE(security): joblib.load unpickles the file and can execute arbitrary
        # code; only load tokenizer files from a trusted source.
        with open(os.path.join(model_dir, 'tokenizer1501'), 'rb') as file:
            self.tokenizer = joblib.load(file)

        # inference encoder model
        self.inf_encoder_model = load_model(os.path.join(model_dir, 'encoder_model.h5'))

        # inference decoder model: (token sequence, h, c) -> (token distribution, h, c).
        # Layers are created in the same order as before so auto-generated names match.
        decoder_inputs = Input(shape=(None, self.num_decoder_tokens))

        decoder_dense = Dense(self.num_decoder_tokens, activation='softmax')

        decoder_lstm = LSTM(self.latent_dim, return_sequences=True, return_state=True)

        decoder_state_input_h = Input(shape=(self.latent_dim,))
        decoder_state_input_c = Input(shape=(self.latent_dim,))
        decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
        decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
        decoder_states = [state_h, state_c]

        decoder_outputs = decoder_dense(decoder_outputs)
        self.inf_decoder_model = Model(
            [decoder_inputs] + decoder_states_inputs,
            [decoder_outputs] + decoder_states)
        self.inf_decoder_model.load_weights(os.path.join(model_dir, 'decoder_model_weights.h5'))

    def greedy_search(self, f):
        """Decode one video's features by always taking the most probable next word.

        :param f: feature array for one video; it is reshaped to
                  ``(-1, time_steps_encoder, num_encoder_tokens)`` before encoding
        :return: the predicted words, each followed by a single space
                 (empty string if nothing was predicted)
        """
        inv_map = self.index_to_word()
        states_value = self.inf_encoder_model.predict(
            f.reshape(-1, self.time_steps_encoder, self.num_encoder_tokens))

        # one-hot "start" token seeds the decoder
        target_seq = np.zeros((1, 1, self.num_decoder_tokens))
        target_seq[0, 0, self.tokenizer.word_index['start']] = 1

        words = []
        for _ in range(self.GREEDY_MAX_LEN):
            output_tokens, h, c = self.inf_decoder_model.predict([target_seq] + states_value)
            states_value = [h, c]
            y_hat = int(np.argmax(output_tokens.reshape(self.num_decoder_tokens)))
            if y_hat == 0:
                # index 0 is not a word: keep the previous input token and the
                # updated state, and try again on the next step
                continue
            word = inv_map.get(y_hat)
            # stop on the end marker or on an index the tokenizer does not know
            if word is None or word == 'end':
                break
            words.append(word)
            target_seq = np.zeros((1, 1, self.num_decoder_tokens))
            target_seq[0, 0, y_hat] = 1
        return ''.join(word + ' ' for word in words)

    def decode_sequence2bs(self, input_seq):
        """Run ``beam_search`` for one encoded input and return the best word list.

        :param input_seq: array passed straight to the inference encoder; ``test``
                          supplies shape ``(-1, time_steps_encoder, num_encoder_tokens)``
        :return: list of words (may contain '' for index-0 predictions); see
                 ``decoded_sentence_tuning`` for cleanup
        """
        # Reset the running best so a previous call cannot leak its result.
        self.max_probability = -1
        self.decode_seq = []
        inv_map = self.index_to_word()
        states_value = self.inf_encoder_model.predict(input_seq)
        target_seq = np.zeros((1, 1, self.num_decoder_tokens))
        target_seq[0, 0, self.tokenizer.word_index['start']] = 1
        self.beam_search(target_seq, states_value, [], [], 0, inv_map=inv_map)
        return self.decode_seq

    def beam_search(self, target_seq, states_value, prob, path, lens, inv_map=None):
        """Depth-first search over the BEAM_WIDTH most likely tokens at each step.

        No pruning is applied: every expanded branch is explored until it hits
        'end', an unknown index, or the length limit. Its probability product is
        then compared with ``self.max_probability``, and the best path is stored in
        ``self.decode_seq``.

        :param target_seq: one-hot decoder input for the current step
        :param states_value: decoder state ``[h, c]`` for the current step
        :param prob: per-token probabilities along the current path
        :param path: words along the current path
        :param lens: number of tokens in ``path``
        :param inv_map: optional index -> word map (built from the tokenizer if None)
        """
        if inv_map is None:
            inv_map = self.index_to_word()
        output_tokens, h, c = self.inf_decoder_model.predict(
            [target_seq] + states_value)
        output_tokens = output_tokens.reshape(self.num_decoder_tokens)
        # indices of the BEAM_WIDTH most probable tokens, most probable first
        top_indices = output_tokens.argsort()[-self.BEAM_WIDTH:][::-1]
        states_value = [h, c]
        for token_index in top_indices:
            token_index = int(token_index)
            # index 0 is treated as an empty word; unknown indices map to None
            word = '' if token_index == 0 else inv_map.get(token_index)
            p = output_tokens[token_index]
            if word is not None and word != 'end' and lens <= self.BEAM_MAX_LEN:
                if word == '':
                    # empty tokens do not penalize the path probability
                    p = 1
                next_seq = np.zeros((1, 1, self.num_decoder_tokens))
                next_seq[0, 0, token_index] = 1.
                self.beam_search(next_seq, states_value, list(prob) + [p],
                                 list(path) + [word], lens + 1, inv_map=inv_map)
            else:
                # terminal: 'end', unknown index, or length limit reached; the
                # terminating token itself is not added to the path
                score = functools.reduce(operator.mul, list(prob) + [p], 1)
                if score > self.max_probability:
                    self.decode_seq = list(path)
                    self.max_probability = score

    def decoded_sentence_tuning(self, decoded_sentence):
        """Clean a beam-search word list for output.

        Drops empty tokens and 'start'/'end' markers, and collapses immediately
        repeated words.

        :param decoded_sentence: iterable of words as returned by ``decode_sequence2bs``
        :return: list of cleaned words
        """
        tuned = []
        for word in decoded_sentence:
            if not word or word in ('start', 'end'):
                continue
            if tuned and tuned[-1] == word:
                continue
            tuned.append(word)
        return tuned

    def index_to_word(self):
        """Return the tokenizer's vocabulary inverted to an index -> word dict."""
        return {value: key for key, value in self.tokenizer.word_index.items()}

    def get_test_data(self):
        """Load the feature array of every video listed in ``testing_id.txt``.

        Each non-blank line names a video. Its features are read from
        ``<main_path>/<test_path>/<line>.npy``.

        :return: ``(X_test, X_test_filename)``. ``X_test`` is ``np.array`` of the
                 loaded arrays, and ``X_test_filename`` holds the ids with their file
                 extension removed (presumably lines look like ``<id>.avi`` — verify
                 against the data).
        """
        X_test = []
        X_test_filename = []
        with open(os.path.join(main_path, 'testing_id.txt')) as testing_file:
            for line in testing_file:
                filename = line.strip()
                if not filename:
                    # tolerate blank / trailing lines instead of loading '.npy'
                    continue
                X_test.append(np.load(os.path.join(main_path, self.test_path, filename + '.npy')))
                X_test_filename.append(os.path.splitext(filename)[0])
        return np.array(X_test), X_test_filename

    def test(self):
        """Caption every test video and write ``<model_dir>/test_output.txt``.

        Each output line is ``<video id>,<caption words separated by spaces >,<seconds>``.
        Here ``<seconds>`` is the wall-clock decoding time with two decimals.
        """
        X_test, X_test_filename = self.get_test_data()
        output_path = os.path.join(main_path, self.model_dir, 'test_output.txt')
        with open(output_path, 'w') as file:
            for name, x in zip(X_test_filename, X_test):
                start = time.time()
                features = x.reshape(-1, self.time_steps_encoder, self.num_encoder_tokens)
                if self.search_type == 'greedy':
                    sentence = self.greedy_search(features)
                else:
                    words = self.decoded_sentence_tuning(self.decode_sequence2bs(features))
                    sentence = ''.join(word + ' ' for word in words)
                file.write('{},{},{:.2f}\n'.format(name, sentence, time.time() - start))

                # re-init max prob
                self.max_probability = -1
    
    
    

In [62]:
# Build the captioner with its default configuration.
video_to_text = VideoDescriptionRealTime()
# Must run before test(): it loads the tokenizer and both inference models.
video_to_text.load_inference_models()
# Caption every video in testing_id.txt and write test_output.txt.
video_to_text.test()









