# In [1]:
from __future__ import print_function
from __future__ import division
from __future__ import absolute_import

import cv2
import os
import pysrt
import sys
import argparse
import numpy as np

import tensorflow as tf

from scipy import misc

from inference_util import Inference
from model import Model_S2VT
import configuration
import inception_base

from data_generator import Data_Generator

# ---- Input / output locations ------------------------------------------------
# Video to caption. The subtitles are written next to it: same base name,
# ".srt" extension.
video_path = "/home/ozym4nd145/Downloads/[9anime.to] Boku no Hero Academia 2nd Season - 09 - 720p.mp4"
srt_path = "".join([os.path.splitext(video_path)[0], ".srt"])

# Weights restored into the "InceptionV4" variables that embed each frame.
checkpoint_path = "../../dataset/inception_v4.ckpt"

# ---- Sampling / decoding knobs ------------------------------------------------
num_frames_per_sec = 10    # frames sampled per second of video
num_frames_per_clip = 100  # sampled frames that make up one captioned clip
max_len = 20               # forwarded as max_len to generate_caption_batch
batch_size = 8             # clips captioned together per batch

# Open the video and read its basic properties.
video = cv2.VideoCapture(video_path)
if not video.isOpened():
    # Fail with a clear message instead of a ZeroDivisionError on fps below.
    raise IOError("Could not open video: %s" % video_path)

length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = video.get(cv2.CAP_PROP_FPS)
if fps <= 0:
    video.release()
    raise IOError("Could not determine the frame rate of: %s" % video_path)

# Duration of the video in seconds (true division via the __future__ import).
time_length = length / fps

# Sample about num_frames_per_sec frames per second, rounded down to a whole
# number of clips. Never request more frames than the video contains: otherwise
# np.linspace below has to repeat indices, the set silently drops the repeats,
# and the sampled count is no longer a multiple of num_frames_per_clip.
num_frames_to_read = min(int(time_length) * num_frames_per_sec, length)
num_frames_to_read = (num_frames_to_read // num_frames_per_clip) * num_frames_per_clip

# Indices of the frames to encode, spread evenly from the first to the last frame.
frames_to_read = set(np.linspace(0, length - 1, num=num_frames_to_read, dtype=np.int32))

## Building model
# Inception-V4 frame encoder fed with 299x299 RGB images.
image_feed = tf.placeholder(dtype=tf.float32, shape=[None, 299, 299, 3], name="image_feed")
inception_output = inception_base.get_base_model(image_feed)
# Fetched before the captioning model is built; only variables under the
# "InceptionV4" scope are collected here.
inception_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="InceptionV4")

# Data generator: supplies the vocabulary used for decoding and is handed to
# ModelConfig below.
data_config = configuration.DataConfig().config
data_gen = Data_Generator(
    processed_video_dir=data_config["processed_video_dir"],
    caption_file=data_config["caption_file"],
    unique_freq_cutoff=data_config["unique_frequency_cutoff"],
    max_caption_len=data_config["max_caption_length"],
)

caption_data_dir = data_config["caption_data_dir"]
data_gen.load_vocabulary(caption_data_dir)
data_gen.load_dataset(caption_data_dir)

# S2VT captioning model in inference mode. Every hyper-parameter except `mode`
# is read from ModelConfig under the keyword name the constructor expects.
model_config = configuration.ModelConfig(data_gen).config
model = Model_S2VT(mode="inference",
                   **{key: model_config[key] for key in (
                       "num_frames",
                       "image_width",
                       "image_height",
                       "image_channels",
                       "num_caption_unroll",
                       "num_last_layer_units",
                       "image_embedding_size",
                       "word_embedding_size",
                       "hidden_size_lstm1",
                       "hidden_size_lstm2",
                       "vocab_size",
                       "initializer_scale",
                       "learning_rate",
                       "rnn1_input_keep_prob",
                       "rnn1_output_keep_prob",
                       "rnn2_input_keep_prob",
                       "rnn2_output_keep_prob",
                   )})
model.build()

# Caption decoding helper bound to the model and the vocabulary mappings.
infer_util = Inference(model, data_gen.word_to_idx, data_gen.idx_to_word)

sess = tf.Session()

# Pre-trained Inception-V4 weights for the frame encoder.
inception_saver = tf.train.Saver(var_list=inception_variables)
inception_saver.restore(sess, checkpoint_path)

# Trained S2VT captioning weights. The Inception variables were loaded just
# above, so this saver covers only the remaining variables. A Saver over every
# variable raises NotFoundError when the captioning checkpoint has no
# InceptionV4 entries, which is presumably the case here (it was trained on
# precomputed frame features).
model_checkpoint_path = "../E5/models/train/model-10680"
inception_variable_names = set(v.name for v in inception_variables)
saver = tf.train.Saver(var_list=[v for v in tf.global_variables()
                                 if v.name not in inception_variable_names])
saver.restore(sess, model_checkpoint_path)

# ---- Frame sampling, feature extraction and captioning ------------------------
captions = []    # one dict per clip: {"start_time", "end_time", "caption"}
frame_list = []  # Inception outputs of sampled frames not yet captioned
start_time = 0   # start (seconds) of the next clip to be captioned

# Seconds of video covered by one clip. The sampled frames are spread evenly over
# the whole video (see frames_to_read), so a clip spans
# time_length * num_frames_per_clip / <number of sampled frames> seconds. This is
# generally NOT num_frames_per_clip / num_frames_per_sec, because the frame count
# is rounded down to whole clips. Assuming the latter made the subtitles drift
# further out of sync the longer the video ran.
clip_duration = (time_length * num_frames_per_clip / len(frames_to_read)
                 if frames_to_read else 0.0)


def _caption_frames(features, clip_start):
    """Caption consecutive clips of frame features and append them to `captions`.

    features: list of per-frame Inception outputs; its length must be a multiple
        of num_frames_per_clip (required by the reshape below).
    clip_start: start time, in seconds, of the first clip in `features`.
    Returns the end time of the last clip, i.e. the start of the next one.
    """
    embedded_frames = np.reshape(np.array(features, dtype=np.float32),
                                 [-1, num_frames_per_clip, inception_base.num_end_units_v4])
    # NOTE(review): assumes generate_caption_batch accepts a batch smaller than
    # batch_size (used for the leftover clips) -- confirm against Inference.
    caption_batch = infer_util.generate_caption_batch(sess, embedded_frames, max_len=max_len)
    for cap in caption_batch:
        clip_end = clip_start + clip_duration
        captions.append({"start_time": clip_start, "end_time": clip_end, "caption": cap})
        clip_start = clip_end
    return clip_start


try:
    for i in range(length):
        ret, frame = video.read()
        if not ret:
            break
        if i not in frames_to_read:
            continue
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # scipy.misc.imresize was removed in SciPy 1.3. cv2.resize with its default
        # bilinear interpolation replaces it (dsize is (width, height)).
        frame = cv2.resize(frame, (299, 299))
        # Scale pixel values from [0, 255] to [-1, 1].
        frame = 2 * (frame.astype(np.float32) / 255) - 1
        frame = sess.run(inception_output, feed_dict={image_feed: [frame]})
        frame_list.append(frame)
        if len(frame_list) % 100 == 0:
            print(len(frame_list))
        if len(frame_list) == num_frames_per_clip * batch_size:
            print("Processing batch")
            start_time = _caption_frames(frame_list, start_time)
            frame_list = []
finally:
    video.release()

# The sampled frame count is a multiple of the clip size but usually not of a
# full batch. Caption the leftover whole clips instead of silently dropping the
# end of the video; an incomplete trailing clip (if any) is discarded.
num_leftover = len(frame_list) - len(frame_list) % num_frames_per_clip
if num_leftover:
    print("Processing batch")
    start_time = _caption_frames(frame_list[:num_leftover], start_time)
frame_list = []

# Write one SubRip entry per captioned clip, numbered from 1, next to the video.
subtitles = pysrt.srtfile.SubRipFile()

for index, caption in enumerate(captions, 1):
    # SubRip times are whole milliseconds. Round explicitly: assigning fractional
    # seconds through the `seconds` attribute truncates after the float multiply
    # (e.g. 20.74 s can come out as 20739 ms).
    start_ms = int(round(caption["start_time"] * 1000))
    end_ms = int(round(caption["end_time"] * 1000))
    sub = pysrt.srtitem.SubRipItem(
        index=index,
        start=pysrt.srttime.SubRipTime.from_ordinal(start_ms),
        end=pysrt.srttime.SubRipTime.from_ordinal(end_ms),
        text=caption["caption"],
    )
    subtitles.append(sub)

subtitles.save(srt_path)