In [1]:
import numpy as np
import tensorflow as tf
import layer_utils
from coco_utils import load_coco_data, sample_coco_minibatch, decode_captions
from image_utils import image_from_url
from lstm import LSTM

def show_image(batch_size):
    """Sample `batch_size` COCO examples and display each image titled with its caption.

    Uses the module-level `data` dict (from `load_coco_data`) both to sample the
    minibatch and to decode captions via its 'idx_to_word' mapping.

    NOTE(review): `plt` (matplotlib.pyplot) is never imported in this notebook;
    `import matplotlib.pyplot as plt` must be added to the imports cell before
    this function can run.
    """
    sampled_captions, _sampled_features, image_urls = sample_coco_minibatch(data, batch_size=batch_size)
    for word_ids, image_url in zip(sampled_captions, image_urls):
        plt.imshow(image_from_url(image_url))
        plt.axis('off')
        plt.title(decode_captions(word_ids, data['idx_to_word']))
        plt.show()
        
def sparse_to_one_hot(sparse_input, max_dim):
    """One-hot encode a sequence of integer class indices.

    Args:
        sparse_input: array-like of integer indices, one entry per output row
            (e.g. the word ids of one caption). If an entry is itself an array
            of indices, every one of those columns is set in that row (multi-hot),
            matching the original loop's behaviour.
        max_dim: number of classes, i.e. the width of each output row.

    Returns:
        float64 ``np.ndarray`` of shape ``(len(sparse_input), max_dim)`` with a 1
        at column ``sparse_input[r]`` of row ``r`` and 0 elsewhere.

    Raises:
        IndexError: if an index is outside ``[-max_dim, max_dim)`` or is not an integer.
    """
    indices = np.asarray(sparse_input)
    num_rows = indices.shape[0]
    one_hot = np.zeros((num_rows, max_dim))
    if num_rows == 0:
        # Nothing to set; also avoids fancy-indexing with an empty (float) array.
        return one_hot
    # Vectorised scatter: pair every row number with that row's column index(es).
    # reshape(num_rows, -1) makes 1-D input a column of single indices and keeps
    # multi-index rows intact, so both cases broadcast against the (num_rows, 1) rows.
    rows = np.arange(num_rows)[:, None]
    one_hot[rows, indices.reshape(num_rows, -1)] = 1
    return one_hot

def captions_to_one_hot(captions, vocab_dim):
    """One-hot encode every caption in `captions`.

    Returns a list holding ``sparse_to_one_hot(caption, vocab_dim)`` for each
    caption, in input order.
    """
    encoded = []
    for caption in captions:
        encoded.append(sparse_to_one_hot(caption, vocab_dim))
    return encoded

def verify_caption_train_target_offset(train_caption, target_caption):
    """Assert that `target_caption` is `train_caption` shifted left by one token.

    Checks ``train_caption[k] == target_caption[k - 1]`` for every ``k`` from 1
    to ``len(target_caption) - 1``; the last target token is not checked.

    Raises:
        AssertionError: at the first position where the shift does not hold.
    """
    for k in range(1, len(target_caption)):
        assert train_caption[k] == target_caption[k - 1]
        
def get_train_target_caption(train_captions_as_word_ids, null_representation):
    """Split a batch of captions into decoder inputs, next-word targets and a padding mask.

    The input columns drop the last token and the target columns drop the first,
    so ``target[:, t]`` is the token that follows ``input[:, t]``. For example, the
    training data:
        '<START> a variety of fruits and vegetables sitting on a kitchen counter'
    is paired with the target:
        'a variety of fruits and vegetables sitting on a kitchen counter <END>'

    Args:
        train_captions_as_word_ids: 2-D array-like of word ids, one caption per row.
        null_representation: padding id; target positions equal to it are masked out.

    Returns:
        ``(train_ids, target_ids, not_null_target_mask)``. The first two have one
        column fewer than the input. The mask has the target's shape and is True
        where the target token differs from `null_representation`.
    """
    captions = np.asarray(train_captions_as_word_ids)
    target_captions_as_word_ids = captions[:, 1:]
    train_captions_as_word_ids = captions[:, :-1]
    # Sanity-check the one-token shift on the first row only. Skip it for an
    # empty batch, where indexing row 0 would raise IndexError.
    if captions.shape[0] > 0:
        verify_caption_train_target_offset(train_captions_as_word_ids[0], target_captions_as_word_ids[0])
    not_null_target_mask = target_captions_as_word_ids != null_representation
    return train_captions_as_word_ids, target_captions_as_word_ids, not_null_target_mask

# Load Data (full image features, no PCA)
data = load_coco_data(pca_features=False)

## Vocabulary and feature sizes
vocab_dim = len(data['word_to_idx'])
image_feature_dim = data['val_features'].shape[1]

# True: embedding table is initialised from data['word_embedding'];
# False: a randomly initialised table of width 256 is used instead.
enable_preprocessed_embedding = True
word_embedding_dim = (data['word_embedding'].shape[1]
                      if enable_preprocessed_embedding
                      else 256)

## Special tokens and their vocabulary ids
START_TOKEN = '<START>'
END_TOKEN = '<END>'
NULL_TOKEN = '<NULL>'
NULL_ID = data['word_to_idx'][NULL_TOKEN]
START_ID = data['word_to_idx'][START_TOKEN]
END_ID = data['word_to_idx'][END_TOKEN]

print("Vocab Dim: {}\nImage Feature Dim: {}\nWord Embedding Dim: {}".format(
    vocab_dim, image_feature_dim, word_embedding_dim))



Vocab Dim: 1004
Image Feature Dim: 4096
Word Embedding Dim: 304


In [2]:
# Decoder hyperparameters. batch_size is also fixed into both placeholder shapes
# below, so every feed (training and evaluation) must contain exactly batch_size rows.
hidden_dim = 512
batch_size = 50
# Project-local LSTM wrapper. output_dim=vocab_dim presumably yields one score per
# vocabulary word at each step — TODO confirm against lstm.py.
lstm = LSTM(hidden_dim=hidden_dim,
            output_dim=vocab_dim,
            learning_rate=5e-3,
            batch_size=batch_size,
            num_layers=1)

# Word Input
# Word ids fed to the decoder; the time dimension is None so any caption length can be fed.
sy_caption_input = tf.placeholder(shape=[batch_size, None], name="caption_input", dtype=tf.int32)

# Embedding table: either initialised from the preprocessed data['word_embedding']
# matrix, or a [vocab_dim, word_embedding_dim] table from a random normal initializer.
# NOTE(review): re-running this cell without resetting the default graph likely fails,
# since tf.get_variable will not recreate an existing "embedding" variable — confirm.
if enable_preprocessed_embedding:
    embedding_init = tf.constant(data['word_embedding'], dtype=tf.float32)
    embedding = tf.get_variable("embedding", initializer=embedding_init)
else:
    embedding_init = tf.random_normal_initializer()
    embedding = tf.get_variable("embedding", [vocab_dim, word_embedding_dim], dtype=tf.float32, initializer=embedding_init)

# One embedding row per word id; the printed static shape is (batch_size, ?, word_embedding_dim).
word_embedding = tf.nn.embedding_lookup(embedding, sy_caption_input)
print(word_embedding)

# Image Input
sy_image_feat_input = tf.placeholder(shape=[batch_size, image_feature_dim], name="image_feat_input", dtype=tf.float32)
# The image features seed the LSTM's initial hidden state through the project-local
# layer_utils.affine_transform with hidden_dim outputs under the name 'image_proj'
# (presumably x @ W + b — TODO confirm).
initial_hidden_state = layer_utils.affine_transform(sy_image_feat_input, hidden_dim, 'image_proj')
# Cell state starts as the hidden state times 0, i.e. zeros of the same shape (for finite features).
initial_cell_state = initial_hidden_state * 0

lstm.build_model(word_embedding, initial_hidden_state, initial_cell_state)

Tensor("embedding_lookup:0", shape=(50, ?, 304), dtype=float32)


In [3]:
# Open an InteractiveSession (it installs itself as the default session) and
# initialise every global variable created by the graph cell.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [8]:
# Train on iter_num random training minibatches. Log cross-entropy and accuracy
# every LOG_EVERY iterations, and also on the final iteration so the last
# value is always visible.
iter_num = 200
LOG_EVERY = 10
for i in range(iter_num):
    mini_batch, features, _urls = sample_coco_minibatch(data, batch_size=batch_size, split='train')
    # Inputs drop each caption's last token, targets drop its first. target_mask is
    # False wherever the target token is NULL_ID padding.
    train_captions, target_captions, target_mask = get_train_target_caption(mini_batch, NULL_ID)
    feed_dict = {
        sy_caption_input: train_captions,
        sy_image_feat_input: features
    }
    cross_entropy, accuracy = lstm.train(sess, target_captions, target_mask, feed_dict)
    if i % LOG_EVERY == 0 or i == iter_num - 1:
        print("iter {}, cross-entropy: {}, accuracy: {}".format(i, cross_entropy, accuracy))

iter 0, cross-entropy: 35.00442123413086, accuracy: 0.38235294818878174
iter 10, cross-entropy: 32.651302337646484, accuracy: 0.4020797312259674
iter 20, cross-entropy: 34.308109283447266, accuracy: 0.3986014127731323
iter 30, cross-entropy: 31.362855911254883, accuracy: 0.3955223858356476
iter 40, cross-entropy: 32.74540710449219, accuracy: 0.4060283601284027
iter 50, cross-entropy: 33.60277557373047, accuracy: 0.3986254334449768
iter 60, cross-entropy: 31.882848739624023, accuracy: 0.4003496468067169
iter 70, cross-entropy: 31.202634811401367, accuracy: 0.37833037972450256
iter 80, cross-entropy: 31.555641174316406, accuracy: 0.3974591791629791
iter 90, cross-entropy: 32.09532165527344, accuracy: 0.3731343150138855
iter 100, cross-entropy: 31.986225128173828, accuracy: 0.3759259283542633
iter 110, cross-entropy: 29.608427047729492, accuracy: 0.41360294818878174
iter 120, cross-entropy: 30.509275436401367, accuracy: 0.40250447392463684
iter 130, cross-entropy: 30.927257537841797, accu

In [20]:
def decode_caption_with(word_id_sequence, key_name='idx_to_word'):
    """Turn a sequence of word ids into a caption string.

    `key_name` selects which id-to-word mapping in the module-level `data`
    dict is handed to `decode_captions`.
    """
    vocabulary = data[key_name]
    return decode_captions(word_id_sequence, vocabulary)

# Evaluate on one random validation minibatch.
mini_batch, features, _urls = sample_coco_minibatch(data, batch_size=batch_size, split='val')
GT_input, GT_captions, GT_mask = get_train_target_caption(mini_batch, NULL_ID)

# Seed every sequence with START_ID only. Build the array as int32 to match the
# int32 sy_caption_input placeholder; np.ones(...) * START_ID would be float64.
test_input = np.full((batch_size, 1), START_ID, dtype=np.int32)
feed_dict = {
    sy_caption_input: test_input,
    sy_image_feat_input: features
}
output, logits = lstm.test(sess, sy_caption_input, feed_dict)

# Same image features, but feed the ground-truth input tokens. Use a copy so the
# START-only feed above is left unchanged.
pseudo_feed_dict = dict(feed_dict)
pseudo_feed_dict[sy_caption_input] = GT_input
pseudo_output = lstm.pseudo_test(sess, GT_captions, pseudo_feed_dict)


print("GT:{}".format(decode_caption_with(GT_captions[0])))
print("Test:{}".format(decode_caption_with(output[0])))
print("Pseudo-Test:{}".format(decode_caption_with(pseudo_output[0])))

[  4  29  91 217   9   4  91   5   7  69   6  31   2   0   0   0]
[[ 18.11688995  -7.46315861   5.94842863 ...,  -6.25985098  -5.1451869
   -4.78702831]]
(50, 16)
GT:a large tall tower with a clock on top <END>
Test:a large clock tower with a clock on the side of it <END>
Pseudo-Test:a large clock tower with a clock on the of
