In [1]:
import tensorflow as tf

In [None]:
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import tensorflow as tf
import numpy as np
import scipy
import time
from tensorflow.python.client import device_lib

from model import Model_S2VT
from data_generator import Data_Generator
from inference_util import Inference

import inception_base
import configuration

# Flag registry; the values defined below are read later as FLAGS.<name>.
FLAGS = tf.flags.FLAGS

# --- Training schedule -------------------------------------------------------
tf.flags.DEFINE_integer(
    "batch_size", 64,
    "Batch size of train data input.")
tf.flags.DEFINE_integer(
    "num_epochs", 10,
    "Number of epochs to train the model.")

# --- Checkpoints (both default to None, i.e. "not supplied") -----------------
tf.flags.DEFINE_string(
    "checkpoint_model", None,
    "Model Checkpoint to use.")
tf.flags.DEFINE_string(
    "inception_checkpoint", None,
    "Inception Checkpoint to use")

# --- Logging / saving cadence ------------------------------------------------
tf.flags.DEFINE_integer(
    "summary_freq", 100,
    "Frequency of writing summary to tensorboard.")
tf.flags.DEFINE_integer(
    "save_freq", None,
    "Frequency of saving model.")

In [None]:
# Data-side settings (paths, caption limits) come from the project's configuration module.
data_config = configuration.DataConfig().config

# Build the batch generator from the configured video directory and caption file.
data_gen = Data_Generator(
    processed_video_dir=data_config["processed_video_dir"],
    caption_file=data_config["caption_file"],
    unique_freq_cutoff=data_config["unique_frequency_cutoff"],
    max_caption_len=data_config["max_caption_length"],
)

# Vocabulary and dataset are both loaded from the same caption data directory.
caption_data_dir = data_config["caption_data_dir"]
data_gen.load_vocabulary(caption_data_dir)
data_gen.load_dataset(caption_data_dir)


In [None]:
# Build the model settings; ModelConfig is given the data generator created earlier.
model_config = configuration.ModelConfig(data_gen).config


In [None]:
# Every constructor argument below is named exactly like its model_config key,
# so the keyword arguments are assembled from this list of names.
_model_config_keys = [
    "num_frames",
    "image_width",
    "image_height",
    "image_channels",
    "num_caption_unroll",
    "num_last_layer_units",
    "image_embedding_size",
    "word_embedding_size",
    "hidden_size_lstm1",
    "hidden_size_lstm2",
    "vocab_size",
    "initializer_scale",
    "learning_rate",
    "rnn1_input_keep_prob",
    "rnn1_output_keep_prob",
    "rnn2_input_keep_prob",
    "rnn2_output_keep_prob",
]
_model_kwargs = {key: model_config[key] for key in _model_config_keys}

# mode is the only argument not taken from model_config; this notebook uses "train".
model = Model_S2VT(mode="train", **_model_kwargs)


In [None]:
# Display TensorFlow's trainable variables; this cell runs before model.build().
tf.trainable_variables()

In [None]:
# Run the model's build step (implemented in the project's model module, not shown here).
model.build()


In [None]:
# Display the trainable variables again, this time after model.build() has run.
tf.trainable_variables()

In [None]:
# Merge the entries of model.summaries into a single op, kept as summary_op.
summary_op = tf.summary.merge(model.summaries)

In [None]:
# A single session is opened here and reused by the cells that follow.
sess = tf.Session()

# Summary writer pointed at the configured training log directory; it is also
# handed the session's graph.
train_writer = tf.summary.FileWriter(
    logdir=data_config["train_log_dir"],
    graph=sess.graph,
)

# Saver used for full-model checkpoints (restore happens in a later cell).
saver = tf.train.Saver(
    max_to_keep=200,
    keep_checkpoint_every_n_hours=0.5,
)

In [None]:
# Display the loaded data configuration for inspection.
data_config

In [None]:
# A checkpoint given through --checkpoint_model takes precedence; when that flag
# is unset (or empty) the latest checkpoint in the configured directory is used.
model_path = (
    FLAGS.checkpoint_model
    or tf.train.latest_checkpoint(data_config["checkpoint_dir"])
)


In [None]:
# Display which checkpoint path (if any) was selected.
model_path

In [None]:
# Use the pretrained Inception checkpoint from the data configuration only when
# --inception_checkpoint was not supplied. Assigning unconditionally would
# silently discard a value the user passed in; this mirrors how
# --checkpoint_model takes precedence over the configured checkpoint directory.
if not FLAGS.inception_checkpoint:
    FLAGS.inception_checkpoint = data_config["inception_pretrained_checkpoint"]

In [None]:
# Either resume from an existing model checkpoint, or start from freshly
# initialised variables with only the Inception weights restored.
if model_path is not None:
    print("Restoring weights from %s" %model_path)
    saver.restore(sess, model_path)
else:
    print("No checkpoint found. Intializing Variables from scratch and restoring from inception checkpoint")
    # Explicit check rather than `assert`: asserts are removed when Python runs
    # with -O, which would let a missing path reach saver2.restore() below.
    if not FLAGS.inception_checkpoint:
        raise ValueError("--Inception checkpoint must be given")
    # Initialise every variable first, then overwrite the Inception subset
    # with the pretrained values.
    sess.run(tf.global_variables_initializer())
    # Separate saver restricted to model.inception_variables, so only those
    # variables are looked up in the Inception checkpoint.
    saver2 = tf.train.Saver(model.inception_variables)
    saver2.restore(sess, FLAGS.inception_checkpoint)


In [None]:
# Initialise batching on the "train" split using the batch size from FLAGS (cast to int).
data_gen.init_batch(int(FLAGS.batch_size),"train")

In [None]:
# Number of iterations between model saves: --save_freq when it is set,
# otherwise a quarter of the "train" iterations per epoch.
# Both branches produce a plain Python int (previously one was np.int32), and
# the fallback is clamped to at least 1 so that an epoch with fewer than four
# iterations cannot yield an interval of 0.
if FLAGS.save_freq:
    iter_to_save = int(FLAGS.save_freq)
else:
    iter_to_save = max(1, int(data_gen.iter_per_epoch["train"] / 4))


In [None]:
# Display the computed save interval.
iter_to_save

In [None]:

# Initialise the epoch and i counters to zero.
epoch, i = 0, 0

In [None]:
# Fetch one "train" batch and record how long the generator took to produce it.
start_time = time.time()
dataset = data_gen.get_next_batch("train")
data_gen_time = time.time() - start_time

# Map the model's input placeholders to the batch contents. The video mask is
# all ones, sized by the first two dimensions of the video array.
video_batch = dataset["video"]
feed_dict = {
    model.video_mask: np.ones(
        [video_batch.shape[0], video_batch.shape[1]], dtype=np.int32
    ),
    model.caption_input: dataset["indexed_caption"],
    model.caption_mask: dataset["caption_mask"],
    model.rnn_input: video_batch,
}



In [None]:
# Display the shape of the video mask that will be fed to the model.
feed_dict[model.video_mask].shape

In [None]:
# Display the shape of the video array that will be fed as model.rnn_input.
feed_dict[model.rnn_input].shape

In [None]:
# Run 100 training steps, all with the same feed_dict (it is not rebuilt
# inside the loop), printing the batch loss after each step.
fetches = [model.batch_loss, model.global_step, model.train_step]
for i in range(100):
    loss, global_step, _ = sess.run(fetches, feed_dict=feed_dict)
    print(loss)

In [None]:
# Display the loss value left over from the last sess.run call.
loss

In [None]:
# Close the summary writer first so any buffered events are flushed to the log
# directory, then release the session. Previously only the session was closed
# and the writer opened alongside it was left open.
train_writer.close()
sess.close()