In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import codecs
import tqdm
from PIL import Image
import logging
import sys
import threading
from datetime import datetime
import pickle
import numpy as np
import os 

In [41]:
class WordSequence(object):
    def __init__(self,train_captions,max_len):
        self.dict = {'<pad>': 0, '<start>': 1, '<end>': 2,'<unknow>':3}
        self.corpus = train_captions
        self._word_count()
        self.max_len =max_len
        
    def _word_count(self):
        word_dict = {}
        for word in self.corpus:
            if word not in word_dict:
                word_dict[word] = 0
            word_dict[word] = word_dict[word]+1
        print(len(word_dict))
        word_dict = {k:v for k,v in word_dict.items() if v>=2}
        print(len(word_dict))
        vocabulary = sorted(word_dict.items(), key=lambda x:-x[1])
        for i in vocabulary:
            self.dict[i[0]] = len(self.dict)
        self.id2word_dict = {value:key for key,value in self.dict.items()}
        
        
    def word2id(self,word_list):
        word_list.append('<end>')
        ids = []
        for word in word_list:
            if word in self.dict:
                ids.append(self.dict[word])
            else:
                ids.append(3)
        return ids
    
    def id2word(self,ids):
        words = []
        for i in ids:
            words.append(self.id2word_dict[i])
        print(words)
        return words

In [3]:
def get_ids_length(string):
    length = len(string.split(' '))
    return length 

In [4]:
def _process_image(filename, hps):
    image = hps.data_dir+filename
    img_buffer = None
    with open(image,'rb') as f:
        img_buffer = f.read()
    return img_buffer

In [5]:
def _int64_feature(value):
    """Wrapper for inserting int64 features into Example proto.
    
    """
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

In [6]:
hps = tf.contrib.training.HParams(
        num_shards=20,
        num_threads=5,
        dataset_name='satellite',
        data_dir='data/crop_imgs/',
        output_directory='./tfrecords'
    )

In [7]:
def _bytes_feature(value):
    """Wrapper for inserting bytes features into Example proto."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [8]:
def _convert_to_example(filename, image_buffer,caption,caption_len):
    """Build an Example proto for an example.
    Args:
      filename: string, path to an image file, e.g., '36979.jpg'
      image_buffer: string, JPEG encoding of RGB image
      text: string,image captions
      height: integer, image height in pixels
      width: integer, image width in pixels
    Returns:
      Example proto
    """
    height=224
    width=224

    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/caption': _bytes_feature(caption.encode()),
        'image/caption_len':_int64_feature(caption_len),
        'image/filename': _bytes_feature(filename.encode()),
        'image/img': _bytes_feature(image_buffer)}))
    return example

In [9]:
captions = {}
with codecs.open ('./data/results_20130124.token','r','utf-8') as f:
    lines = f.readlines()
    bar = tqdm.tqdm(lines)
    for line in bar:
        line = line.strip()
        words = line.split("\t")
        if words[0].split("#")[0] not in captions:
            captions[words[0].split("#")[0]]=[]
        else:
            caption = words[1]
            caption = caption.strip().lower()
            caption = caption.replace(',', '').replace("'", '').replace('"', '')
            caption = caption.replace('&', 'and').replace('(', '').replace(')', '').split()
            caption = [w for w in caption if len(w) > 0]
            captions[words[0].split("#")[0]].append(caption)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 158915/158915 [00:01<00:00, 142193.10it/s]


In [38]:
captions_words = []
for k,values in captions.items():
    for value in values:
        captions_words.extend(value)

In [39]:
len(captions_words)

1498287

In [42]:
ws = WordSequence(captions_words,20)

17119
10652


In [49]:
with open('./ws.pkl','wb') as f:
    pickle.dump(obj=ws,file=f)

In [43]:
captions_ids = {}
# ws = None
# with open('./ws.pkl','rb') as f:
#     ws = pickle.load(f)
for key in list(captions.keys()):
    ids = []
    for integer in ws.word2id(captions[key][1]):
        ids.append(str(integer))
    ids = ' '.join(ids)
    captions_ids[key]=ids

In [None]:
list(captions_ids.items())[0][1]

In [None]:
img_caption = captions['36979.jpg']
print(' '.join(img_caption[1]))
img_caption = ' '.join(img_caption[1])

In [None]:
get_ids_length(list(captions_ids.items())[0][1])

In [44]:
def _process_image_files_batch(thread_index,ranges,filenames,num_shards):
    """Processes and saves list of images as TFRecord in 1 thread.
    Args:
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
      thread_index: integer, unique batch to run index is within [0, len(ranges)).
      ranges: list of pairs of integers specifying ranges of each batches to
        analyze in parallel.
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      num_shards: integer number of shards for this data set.
    """
    # Each thread produces N shards where N = int(num_shards / num_threads).
    # For instance, if num_shards = 128, and the num_threads = 2, then the first
    # thread would produce shards [0, 64).
    num_threads = len(ranges) #线程的数目，
    assert not num_shards % num_threads #确认，num_shards，也就是想要分片的数目是线程数目的整数倍。
    num_shards_per_batch = int(num_shards / num_threads) #num_shards_pre_batch,每个线程需要生成多个shard
    shard_ranges = np.linspace(ranges[thread_index][0],
                               ranges[thread_index][1],
                               num_shards_per_batch + 1).astype(int)#这个线程需要处理的图像索引的范围。
    num_files_in_thread = ranges[thread_index][1] - ranges[thread_index][0] #这个线程需要处理多少个图像。
    counter = 0
    
    for s in range(num_shards_per_batch):
        # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
        shard = thread_index * num_shards_per_batch + s #可以标识一个shard分片
        output_filename = '%s_%.5d-of-%.5d.tfrecord' % ('1', shard, num_shards)
        output_file = os.path.join(hps.output_directory, output_filename)
        writer = tf.python_io.TFRecordWriter(output_file)
        shard_counter = 0
        files_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
        for i in files_in_shard:
            image_file_name = filenames[i]
            image_caption = captions_ids[image_file_name]
            caption_len = get_ids_length(image_caption)
            image_buffer = _process_image(image_file_name, hps) #读取img的图像、高度和宽度
            example = _convert_to_example(image_file_name, image_buffer,image_caption,caption_len)#将想要保存的数据封装成tf.train.Example
            writer.write(example.SerializeToString()) # 将Example数据写入到tfrecord 文件中。
            shard_counter += 1
            counter += 1

            if not counter % 1000:
                logging.info('%s [thread %d]: Processed %d of %d images in thread batch.' %
                             (datetime.now(), thread_index, counter, num_files_in_thread))
                sys.stdout.flush()

        writer.close()
        logging.info('%s [thread %d]: Wrote %d images to %s' %
                     (datetime.now(), thread_index, shard_counter, output_file))
        sys.stdout.flush()
        shard_counter = 0
    logging.info('%s [thread %d]: Wrote %d images to %d shards.' %
                 (datetime.now(), thread_index, counter, num_files_in_thread))
    sys.stdout.flush()

In [45]:
def _process_image_files(name, filenames,num_shards,hps):
    """Process and save list of images as TFRecord of Example protos.
    Args:
      name: string, unique identifier specifying the data set
      filenames: list of strings; each string is a path to an image file
      texts: list of strings; each string is human readable, e.g. 'dog'
      labels: list of integer; each integer identifies the ground truth
      num_shards: integer number of shards for this data set.
    """
    # Break all images into batches with a [ranges[i][0], ranges[i][1]].
    '''
    spacing 是文件序号的空间，比如文件数目是20，有2个线程来处理这些文件，那么spacing分成3个数，
    0,10,20
    则两个线程分别需要处理的序号是（0,10），（10,20），
    通过np.arrange(0,10)就会生成[0,1,2,...,9]这些序号列表，
    通过np.arrange(10,20)就会生成[10,11,12,...,19]这些序号列表。
    '''
    spacing = np.linspace(0, len(filenames), hps.num_threads + 1).astype(np.int)
    ranges = []
    for i in range(len(spacing) - 1):
        ranges.append([spacing[i], spacing[i + 1]])
    # Launch a thread for each batch.
    logging.info('Launching %d threads for spacings: %s' % (hps.num_threads, ranges))
    sys.stdout.flush()

    # Create a mechanism for monitoring when all threads are finished.
    coord = tf.train.Coordinator()
    # Create a generic TensorFlow-based utility for converting all image codings.
    threads = []
    for thread_index in range(len(ranges)):
        args = (thread_index, ranges,filenames,num_shards)
        t = threading.Thread(target=_process_image_files_batch, args=args)
        t.start()
        threads.append(t)

    # Wait for all the threads to terminate.
    coord.join(threads)
    logging.info('%s: Finished writing all %d images in data set.' %
                 (datetime.now(), len(filenames)))
    sys.stdout.flush()

In [46]:
import os
import numpy as np

In [47]:
file_names = list(captions.keys())

In [48]:
_process_image_files('name', file_names,hps.num_shards,hps)

In [None]:
filename_queue = tf.train.string_input_producer(['./tfrecords/1.tfrecord'])#生成一个queue队列
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'image/height':tf.FixedLenFeature([], tf.int64),
                                        'image/width':tf.FixedLenFeature([], tf.int64),
                                        'image/caption':tf.FixedLenFeature([], tf.string),
                                        'image/filename':tf.FixedLenFeature([], tf.string),
                                        'image/img':tf.FixedLenFeature([], tf.string)
                                   })#将image数据和label取出来

img = features['image/img']
img = tf.image.decode_jpeg(img)
img_caption = features['image/caption']
height=tf.cast(features['image/height'],tf.int32)
width=tf.cast(features['image/width'],tf.int32)
coord = tf.train.Coordinator()


sess = tf.Session()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
result_height,result_width,result_img,result_img_caption = sess.run([height,width,img,img_caption])
plt.imshow(result_img)
plt.title(' '.join([str(result_height),str(result_width)]))
plt.axis("off")
print(result_img_caption.decode())

In [None]:
f = open('ws.pkl','rb')
ws = pickle.load(f)

In [None]:
caption1 = captions[file_names[0]]

In [None]:
caption1[1]

In [None]:
with open('./caption_file/captions.txt','w') as g:
    bar = tqdm.tqdm(list(captions.keys()))
    for i in bar:
        ids = []
        for j in ws.word2id(captions[i][1]):
            ids.append(str(j))
        id_string = ' '.join(ids)
        g.write(id_string+'\n')


In [None]:
' '.join(ids)

In [None]:
import numpy as np
spacing = np.linspace(0, 101, 20 + 1).astype(np.int)
print(spacing)
ranges = []
for i in range(len(spacing) - 1):
    ranges.append([spacing[i], spacing[i + 1]])
print(ranges)

In [None]:
def _process_dataset(name, hps,filenames,texts):
    """Process a complete data set and save it as a TFRecord.
    Args:
      name: string, unique identifier specifying the data set.
      num_shards: integer number of shards for this data set.
    """
    _process_image_files(name, filenames, texts,hps.train_shards)

In [None]:
img = Image.open('./data/crop_imgs/3025093.jpg')
#     print(type(img))
#     print(dir(img))
height,width = img.size
img_buffer = img.tobytes()
_bytes_feature(img_buffer)

In [None]:
_process_image('3025093.jpg',hps)