In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import pickle
import tensorflow.contrib.slim as slim
from tensorflow.contrib.seq2seq import BahdanauAttention
from tensorflow.contrib.seq2seq import AttentionMechanism
from tensorflow.contrib.seq2seq import AttentionWrapper
from tensorflow.contrib import seq2seq
from tensorflow import layers
from tensorflow.layers import dense
from tensorflow.nn.rnn_cell import LSTMCell
from tensorflow.nn.rnn_cell import DropoutWrapper
from tensorflow.nn.rnn_cell import ResidualWrapper
from tensorflow.nn.rnn_cell import MultiRNNCell
from tqdm import tqdm
from tensorflow import gfile
from tensorflow import logging
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple
from tensorflow.contrib.seq2seq import AttentionWrapperState

![Show, Attend and Tell 模型结构图](./imgs/图像标题生成模型结构图.png)

下面定义 WordSequence 类：它统计训练集标注语句中的词频并据此构建词表，负责在单词与 id 之间进行编码和解码。

In [None]:
class WordSequence(object):
    """Caption vocabulary: word-frequency counting plus word <-> id conversion.

    Ids 0-3 are reserved for '<pad>', '<start>', '<end>' and '<unknow>'. Every
    other word seen at least ``min_count`` times gets the next free id. More
    frequent words get smaller ids; ties keep first-seen order.
    """

    def __init__(self, train_captions, max_len, min_count=2):
        """
        :param train_captions: iterable of word tokens; each element is counted
            as one word.
        :param max_len: maximum caption length; stored as ``self.max_len``.
        :param min_count: minimum frequency for a word to enter the vocabulary
            (default 2, the previously hard-coded threshold).
        """
        self.dict = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unknow>': 3}
        self.corpus = train_captions
        # _word_count used to be nested inside __init__ (wrong indentation), so
        # this call raised AttributeError; it is now a real method.
        self._word_count(min_count)
        self.max_len = max_len

    def _word_count(self, min_count=2):
        """Extend ``self.dict`` with frequent words and build ``self.id2word_dict``."""
        word_dict = {}
        for word in self.corpus:
            word_dict[word] = word_dict.get(word, 0) + 1
        print(len(word_dict))  # distinct words before filtering
        word_dict = {k: v for k, v in word_dict.items() if v >= min_count}
        print(len(word_dict))  # distinct words kept
        # sorted() is stable, so equally frequent words keep first-seen order.
        vocabulary = sorted(word_dict.items(), key=lambda x: -x[1])
        for word, _count in vocabulary:
            # Never reassign a reserved token (e.g. '<end>' already present in
            # the captions); that would break the fixed ids 0-3.
            if word not in self.dict:
                self.dict[word] = len(self.dict)
        self.id2word_dict = {value: key for key, value in self.dict.items()}

    def word2id(self, word_list):
        """Encode a sequence of words as ids and append the '<end>' id.

        Out-of-vocabulary words map to the '<unknow>' id. ``word_list`` itself
        is not modified, so calling this twice on the same list is safe.
        """
        unknown_id = self.dict['<unknow>']
        return [self.dict.get(word, unknown_id) for word in list(word_list) + ['<end>']]

    def id2word(self, ids):
        """Decode a sequence of ids back to words (KeyError on an unknown id)."""
        return [self.id2word_dict[i] for i in ids]

In [None]:
def pares_tfrecords(record):
    """Parse one serialized tf.train.Example into model inputs.

    :param record: scalar string tensor holding one serialized Example.
    :return: (img, caption, filename, caption_len):
        img         -- float32 image tensor with static shape [224, 224, 3]
        caption     -- int32 vector of word ids, parsed from a space-separated
                       string of numbers
        filename    -- string scalar
        caption_len -- int32 scalar
    """
    features = tf.parse_single_example(
        record,
        features={
            'image/height': tf.FixedLenFeature([], tf.int64),
            'image/width': tf.FixedLenFeature([], tf.int64),
            'image/caption': tf.FixedLenFeature([], tf.string),
            'image/caption_len': tf.FixedLenFeature([], tf.int64),
            'image/filename': tf.FixedLenFeature([], tf.string),
            'image/img': tf.FixedLenFeature([], tf.string)
        })

    filename = features['image/filename']

    # The caption is stored as a space-separated string of numeric word ids.
    caption = tf.string_split([features['image/caption']]).values
    caption = tf.to_int32(tf.string_to_number(caption))
    caption_len = tf.to_int32(features['image/caption_len'])

    # channels=3 forces RGB output, so grayscale JPEGs decode to 3 channels
    # instead of failing the set_shape below.
    img = tf.image.decode_jpeg(features['image/img'], channels=3)
    img = tf.to_float(img)
    # NOTE(review): images are assumed to be stored already resized to
    # 224x224; the height/width features are parsed but not used here.
    img.set_shape(shape=[224, 224, 3])

    return img, caption, filename, caption_len

In [None]:
def make_target_input(img, caption, filename, caption_len):
    """Build teacher-forcing decoder inputs from a caption.

    The decoder input is the caption shifted right by one step: id 1 ('<start>'
    in WordSequence) is prepended and the last id is dropped. The label is the
    unshifted caption. img, filename and caption_len pass through unchanged.
    """
    label = caption
    shifted_right = caption[:-1]
    input_caption = tf.concat([[1], shifted_right], axis=0)
    return img, input_caption, filename, caption_len, label

下面构建读取并解析 TFRecord 数据的 tf.data 输入流水线。

In [None]:
def generate_train_data(batch_size, epochs, tfrecords_dir='./tfrecords',
                        shuffle_buffer=5000):
    """Build the tf.data input pipeline over every file in ``tfrecords_dir``.

    :param batch_size: examples per batch. The variable-length fields (decoder
        input caption and label) are zero-padded to the longest in the batch.
    :param epochs: number of passes over the data.
    :param tfrecords_dir: directory whose files are all read as TFRecords.
    :param shuffle_buffer: size of the shuffle buffer.
    :return: (img, caption, filename, caption_len, label, iterator). The first
        five are the next-batch tensors of a one-shot iterator; ``caption`` is
        the shifted decoder input produced by make_target_input.
    :raises ValueError: if ``tfrecords_dir`` contains no files.
    """
    files = tf.gfile.Glob(tfrecords_dir + '/*')
    if not files:
        # Otherwise the dataset is silently empty and the first sess.run fails
        # with an OutOfRangeError far from the real cause.
        raise ValueError('no TFRecord files found in %s' % tfrecords_dir)

    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(pares_tfrecords)
    dataset = dataset.map(make_target_input)
    padded_shapes = (
        tf.TensorShape([224, 224, 3]),  # img
        tf.TensorShape([None]),         # decoder input caption (padded)
        tf.TensorShape([]),             # filename
        tf.TensorShape([]),             # caption_len
        tf.TensorShape([None]),         # label (padded)
    )
    # Shuffle before repeat: each epoch then sees every example exactly once
    # (reshuffled per epoch) instead of mixing examples across epoch boundaries.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer).repeat(epochs)
    dataset = dataset.padded_batch(batch_size, padded_shapes)

    iterator = dataset.make_one_shot_iterator()
    img, caption, filename, caption_len, label = iterator.get_next()
    return img, caption, filename, caption_len, label, iterator

In [None]:
# Load the pickled WordSequence vocabulary.
# NOTE(review): pickle.load can execute arbitrary code; this file is assumed to
# be trusted and produced locally.
ws = None
with open('./ws.pkl', mode='rb') as vocab_file:
    ws = pickle.load(vocab_file)

In [None]:
# Sanity check of the input pipeline: decode one batch of 10 examples and print
# the decoder inputs and labels, both as ids and as words.
# This runs in a throwaway graph, so the extra pipeline is not added to the
# default graph that the training cells below build and run.
with tf.Graph().as_default():
    img, caption, filename, caption_len, label, _iterator = generate_train_data(10, 10)
    init = tf.global_variables_initializer()
    local_init = tf.local_variables_initializer()
    with tf.Session() as sess:
        sess.run([init, local_init])
        (result_img, result_caption, result_filename,
         result_caption_len, result_label) = sess.run(
            [img, caption, filename, caption_len, label])
        print(result_caption)
        for ids in result_caption:
            print(ws.id2word(ids))
        print(result_label)
        for ids in result_label:
            print(ws.id2word(ids))

下面使用 `tensorflow.contrib.slim` 模块定义 VGG-19 网络，用于提取图像特征。

In [None]:
def vgg_19(inputs,
           num_classes=1000,
           is_training=True,
           dropout_keep_prob=0.5,
           spatial_squeeze=True,
           scope='vgg_19',
           fc_conv_padding='VALID',
           global_pool=False):
    """VGG-19 built with tf.contrib.slim, fully connected layers as convolutions.

    :param inputs: image batch; in this notebook a mean-centred
        (batch, 224, 224, 3) tensor.
    :param num_classes: number of outputs of the final 'fc8' layer. If 0 or
        None, the dropout7/fc8 head is skipped and the fc7 (or global-pool)
        output is returned.
    :param is_training: whether the dropout layers are active.
    :param dropout_keep_prob: keep probability of the dropout layers.
    :param spatial_squeeze: squeeze spatial dims [1, 2] of the fc8 output.
    :param scope: variable scope; callers here expect 'vgg_19' (checkpoint
        restore and the 'vgg_19/conv5/conv5_3' end point rely on it).
    :param fc_conv_padding: padding of the 7x7 'fc6' convolution.
    :param global_pool: if True, average the fc7 output over the spatial dims
        (kept as size 1) and record it as end_points['global_pool'].
    :return: (net, end_points): the last layer's output and a dict of the
        conv / max-pool outputs keyed by scoped name (plus 'global_pool' and
        '<scope>/fc8' when those are produced).
    """
    with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
        end_points_collection = sc.original_name_scope + '_end_points'
        # Collect outputs for conv2d, fully_connected and max_pool2d.
        with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                            outputs_collections=end_points_collection):
            net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool1')
            net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
            net = slim.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
            net = slim.max_pool2d(net, [2, 2], scope='pool3')
            net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
            net = slim.max_pool2d(net, [2, 2], scope='pool4')
            net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
            net = slim.max_pool2d(net, [2, 2], scope='pool5')

            # Use conv2d instead of fully_connected layers.
            net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6')
            net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                               scope='dropout6')
            net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
            # Convert end_points_collection into an end_point dict.
            end_points = slim.utils.convert_collection_to_dict(end_points_collection)
            if global_pool:
                net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
                end_points['global_pool'] = net
            # The classifier head belongs entirely inside this branch; it was
            # previously mis-indented, so fc8 ran even when num_classes was falsy.
            if num_classes:
                net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                                   scope='dropout7')
                net = slim.conv2d(net, num_classes, [1, 1],
                                  activation_fn=None,
                                  normalizer_fn=None,
                                  scope='fc8')
                if spatial_squeeze:
                    net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
                end_points[sc.name + '/fc8'] = net
        return net, end_points

In [None]:
# Per-channel pixel means subtracted from the input image before VGG-19
# (presumably R, G, B order). Shaped (1, 1, 3) so it broadcasts over height
# and width.
MEAN_VALUES = np.array([[[123.68, 116.779, 103.939]]])

In [None]:
def get_save_variables():
    """Return the trainable variables under the 'vgg_19' variable scope.

    These are the variables restored from the pretrained VGG-19 checkpoint.
    """
    return tf.get_collection(key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_19')

In [None]:
# Build the training input pipeline and run VGG-19 on mean-centred images.
vgg_model = './vgg_model/vgg_19.ckpt'
batch_size = 60
epochs = 35
(img, input_caption, filename,
 caption_len, label, iterator) = generate_train_data(batch_size, epochs)
net, end_points = vgg_19(img - MEAN_VALUES)

# conv5_3 feature map, flattened to 196 spatial positions x 512 channels
# (196 = 14 * 14: a 224x224 input after the four 2x2 poolings before conv5).
img_features = end_points['vgg_19/conv5/conv5_3']
print(img_features)
img_features = tf.reshape(img_features, shape=[-1, 196, 512])
print(img_features)

In [None]:
def batch_norm(inputs, name, training):
    """Batch normalization with learned offset and scale, in variable scope ``name``.

    ``updates_collections=None`` makes the moving-average updates run in place
    with the forward pass instead of being collected into UPDATE_OPS.
    """
    bn_kwargs = dict(decay=0.95, center=True, scale=True,
                     is_training=training, updates_collections=None,
                     scope=name)
    return tf.contrib.layers.batch_norm(inputs, **bn_kwargs)

In [None]:
def build_single_cell(n_hidden, use_residual, keep_prob_placeholder):
    '''
    Build one LSTM cell with output dropout, optionally wrapped as residual.

    :param n_hidden: number of units in the LSTM cell
    :param use_residual: if True, wrap the cell in a ResidualWrapper (adds the
        cell input to its output, so input and output sizes must match)
    :param keep_prob_placeholder: keep probability of the output dropout
    :return: the wrapped RNN cell
    '''
    # Dropout on the cell output helps against overfitting.
    dropped = DropoutWrapper(
        LSTMCell(n_hidden),
        dtype=tf.float32,
        output_keep_prob=keep_prob_placeholder,
        seed=0,  # fixed seed for the dropout masks
    )
    # A residual connection eases gradient flow (vanishing/exploding gradients).
    return ResidualWrapper(dropped) if use_residual else dropped

In [None]:
def cell_input_fn(inputs, attention, units=1024):
    '''
    Combine the decoder input with the previous attention vector.

    Used as AttentionWrapper's ``cell_input_fn``. Both tensors are cast to
    float32 and concatenated along axis 1. The result is projected by a
    bias-free Dense layer with ``units`` outputs, batch-normalized (training
    mode) and passed through ReLU.

    :param inputs: decoder input for the current step; assumed (batch, input_dim)
    :param attention: previous attention vector; assumed (batch, attn_dim)
    :param units: projection width (default 1024, the previously hard-coded
        value). NOTE: with the residual LSTM cell used in this notebook it must
        equal the LSTM size, because ResidualWrapper adds this output to the
        cell output.
    :return: float32 tensor of shape (batch, units) for 2-D inputs
    '''
    # Calling a Dense layer applies a fully connected projection (its __call__
    # wraps call() with pre/post-processing).
    attn_projection = layers.Dense(units,
                                   dtype=tf.float32,
                                   use_bias=False,
                                   name='attention_cell_input_fn')
    combined = tf.concat([tf.cast(inputs, dtype=tf.float32),
                          tf.cast(attention, dtype=tf.float32)], 1)
    return tf.nn.relu(batch_norm(attn_projection(combined), 'cell_input_fn', True))

In [None]:
# Model sizes and the gradient-clipping threshold.
# NOTE(review): target_vocab_size presumably equals len(ws.dict); it must also
# match the shapes stored in any checkpoint that is restored.
target_vocab_size, hidden_size, max_gradient_norm = 10656, 1024, 5.0

In [None]:
# Uniform(-1, 1) initializer for the word-embedding matrix.
e_initializer = tf.random_uniform_initializer(minval=-1.0, maxval=1.0)

下面定义图像描述生成模型的解码器部分：以 VGG 特征为注意力记忆、基于注意力机制的 LSTM 解码器。

In [None]:
with tf.variable_scope('captions') as decoder_scope:
    # ---- Image features as attention memory --------------------------------
    context = batch_norm(img_features, 'context', True)
    context = tf.nn.relu(context)
    # Take the batch size from the data instead of the fixed `batch_size`, so
    # the graph also accepts a smaller batch (e.g. the final partial batch).
    dynamic_batch_size = tf.shape(context)[0]
    # All 196 memory positions are valid, so no memory_sequence_length mask is
    # needed. The former [196] * batch_size mask was all ones (a no-op) and tied
    # the graph to a fixed batch size.
    attention_mechanism = BahdanauAttention(
        num_units=512,
        memory=context
    )

    # ---- Decoder cell: one residual LSTM layer wrapped with attention -------
    cell = MultiRNNCell(
        [
            build_single_cell(hidden_size, use_residual=True, keep_prob_placeholder=0.8)
            for _ in range(1)
        ])
    cell = AttentionWrapper(
        cell=cell,
        attention_mechanism=attention_mechanism,
        attention_layer_size=hidden_size,
        alignment_history=True,       # keep per-step alignments in the state
        cell_input_fn=cell_input_fn,  # combines each input with the previous attention
        name='Attention_Wrapper'
    )

    # ---- Initial state: LSTM (c, h) predicted from the mean image feature ---
    decoder_initial_state = cell.zero_state(dynamic_batch_size, tf.float32)
    context_mean = tf.reduce_mean(context, 1)
    the_state = tf.nn.relu(batch_norm(dense(context_mean, hidden_size, name='initial_state'), 'initial_state', True))
    the_memory = tf.nn.relu(batch_norm(dense(context_mean, hidden_size, name='initial_memory'), 'initial_memory', True))
    # LSTMStateTuple is (c, h); the 1-tuple matches the one-layer MultiRNNCell.
    the_cell_state_for_init = (LSTMStateTuple(the_memory, the_state),)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=the_cell_state_for_init
    )

    # ---- Word embeddings and output projection ------------------------------
    words_embeddings = tf.get_variable(
        name='embeddings',
        shape=(target_vocab_size, 512),
        initializer=e_initializer,
        dtype=tf.float32
    )
    decoder_output_projection = layers.Dense(
        target_vocab_size,
        dtype=tf.float32,
        use_bias=False,
        name='decoder_output_projection'
    )

    # ---- Teacher-forced training decoder -----------------------------------
    captions_inputs_embedded = tf.nn.embedding_lookup(
        params=words_embeddings,
        ids=input_caption
    )
    training_helper = seq2seq.TrainingHelper(
        inputs=captions_inputs_embedded,  # decoder inputs (shifted captions), not the labels
        sequence_length=caption_len,      # number of decoding steps per example
        time_major=False,
        name='training_helper'
    )
    training_decoder = seq2seq.BasicDecoder(
        cell=cell,
        helper=training_helper,
        initial_state=decoder_initial_state
    )
    max_decoder_length = tf.reduce_max(caption_len)
    outputs, final_state, final_sequence_lengths = seq2seq.dynamic_decode(
        decoder=training_decoder,
        output_time_major=False,
        impute_finished=True,
        maximum_iterations=max_decoder_length,
        parallel_iterations=5,
        swap_memory=True,
        scope=decoder_scope
    )
    decoder_logits_train = decoder_output_projection(outputs.rnn_output)

    # ---- Masked loss --------------------------------------------------------
    # Steps at or beyond each example's caption_len are padding and get weight 0.
    masks = tf.sequence_mask(
        lengths=caption_len,
        maxlen=max_decoder_length,
        dtype=tf.float32,
        name='masks'
    )
    # NOTE(review): assumes the padded label length equals max(caption_len), so
    # logits and targets have the same time dimension; confirm against the data.
    loss = seq2seq.sequence_loss(
        logits=decoder_logits_train,
        targets=label,
        weights=masks,
        average_across_timesteps=True,
        average_across_batch=True
    )

    opt = tf.train.AdamOptimizer(
        learning_rate=0.001
    )

    # ---- Masked token accuracy ----------------------------------------------
    logits_flatted = tf.reshape(decoder_logits_train, (-1, target_vocab_size))
    prediction = tf.argmax(logits_flatted, 1, output_type=tf.int32)
    label_flatten = tf.reshape(label, [-1])
    mask_flatten = tf.reshape(masks, [-1])  # already float32
    correct_prediction = tf.equal(prediction, label_flatten)
    correct_prediction_with_mask = tf.multiply(
        tf.cast(correct_prediction, tf.float32),
        mask_flatten)
    mask_sum = tf.reduce_sum(mask_flatten)
    accuracy = tf.reduce_sum(correct_prediction_with_mask) / mask_sum
    
    
    
    


In [None]:
# Decoder variables: the trainable subset is optimized and all of them are
# checkpointed. Collected before the optimizer creates its slot variables.
trainable_params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='captions')
save_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='captions')
global_step = tf.Variable(0, trainable=False)

In [None]:
# Clip the decoder gradients by global norm and apply them with Adam. Only the
# 'captions' variables are updated; the VGG-19 weights get no updates.
gradients = tf.gradients(loss, trainable_params)
clip_gradients, _ = tf.clip_by_global_norm(gradients, max_gradient_norm)
grads_and_vars = list(zip(clip_gradients, trainable_params))
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops), tf.name_scope('train_op'):
    train_op = opt.apply_gradients(grads_and_vars, global_step=global_step)

In [None]:
# Initializers and savers: `saver` covers the pretrained VGG-19 weights,
# `saver1` covers the caption-decoder variables.
init = tf.global_variables_initializer()
local_init = tf.local_variables_initializer()
saver, saver1 = tf.train.Saver(get_save_variables()), tf.train.Saver(save_params)

In [None]:
# Train the decoder. The pretrained VGG-19 weights come from `vgg_model`; the
# decoder resumes from `resume_checkpoint` when that checkpoint exists.
steps_per_epoch = 529  # NOTE(review): presumably ceil(num_examples / batch_size); confirm
resume_checkpoint = './mysave_model/captions.ckpt-0'

with tf.Session() as sess:
    sess.run([init])
    saver.restore(sess=sess, save_path=vgg_model)
    if tf.train.checkpoint_exists(resume_checkpoint):
        saver1.restore(sess=sess, save_path=resume_checkpoint)
    else:
        print('no decoder checkpoint at %s, training from scratch' % resume_checkpoint)

    for epoch in range(epochs):
        bar = tqdm(range(steps_per_epoch))
        exhausted = False
        try:
            for _step in bar:
                _, loss_val, accuracy_val, global_step_val = sess.run(
                    [train_op, loss, accuracy, global_step])
                bar.set_description('%s:Epoch, Step: %5d, loss: %3.3f, accuracy: %3.3f'
                                    % (str(epoch), global_step_val, loss_val, accuracy_val))
        except tf.errors.OutOfRangeError:
            # The one-shot iterator already repeats for `epochs` passes; once it
            # is exhausted there is no more data. Other errors now propagate
            # instead of being printed and ignored.
            exhausted = True
        finally:
            bar.close()

        saver1.save(sess, save_path='./mysave_model/captions.ckpt-%s' % (str(epoch)))
        if exhausted:
            break

In [None]:
# Reload the pickled WordSequence vocabulary (same file as the earlier load).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
ws = None
with open('./ws.pkl', 'rb') as pickled_ws:
    ws = pickle.load(pickled_ws)

In [None]:
# Size of the loaded vocabulary. NOTE(review): presumably should equal target_vocab_size.
len(ws.dict.keys())

In [None]:
# Scratch check: ten copies of 196, the number of conv5_3 feature positions.
[196 for _ in range(10)]

In [None]:

# Extract conv5_3 features for one batch using the pretrained VGG-19 weights
# and show their shape.
with tf.Session() as sess:
    sess.run([init, local_init])
    saver.restore(sess=sess, save_path=vgg_model)
    total_result = sess.run([img_features])
    feature_batch = total_result[0]
    print(feature_batch.shape)