In [1]:
import numpy as np
np.set_printoptions(threshold=np.inf)
import os
import json
import h5py
from tqdm import tqdm
from collections import Counter
import cv2
import nltk
import torch
import torch.utils.data as tud

In [2]:
# Project paths. The project root can be overridden with the IMAGE_CAPTION_ROOT
# environment variable instead of editing absolute paths in the notebook.
# Both folders keep a trailing separator because later cells build paths by
# plain string concatenation (e.g. data_folder + 'Flickr8k_Dataset/...').
PROJECT_ROOT = os.environ.get('IMAGE_CAPTION_ROOT', '/home/anna/pycharm_proj/image_caption')
data_folder = os.path.join(PROJECT_ROOT, 'Flickr8k', '')
input_files_folder = os.path.join(PROJECT_ROOT, 'input_files', '')

###### 获取全部数据集的图像路径与描述

In [3]:
# Parse the Flickr8k caption file. Every non-blank line looks like
# "<image>.jpg#<k>\t<caption>". Captions are grouped by image name rather than
# by line position, so blank lines (e.g. a trailing newline) or captions of one
# image that are not contiguous cannot shift captions onto the wrong image.
with open(os.path.join(data_folder, 'Flickr8k_text/Flickr8k.lemma.token.txt'), 'r', encoding='utf-8') as fp:
    lines = [line.rstrip('\n') for line in fp if line.strip()]
image_paths = [line.split('\t')[0].split('#')[0] for line in lines]
image_captions = [nltk.word_tokenize(line.split('\t')[-1]) for line in lines]

# image file name -> list of tokenized captions, in file order.
image_path2caption = {}
for img_path, tokens in zip(image_paths, image_captions):
    image_path2caption.setdefault(img_path, []).append(tokens)

# The HDF5/JSON export later asserts exactly 5 captions per image; fail here
# with a readable message instead of producing misaligned captions.
wrong_counts = {p: len(c) for p, c in image_path2caption.items() if len(c) != 5}
assert not wrong_counts, 'expected 5 captions per image, found: %r' % (list(wrong_counts.items())[:5],)

###### 分别将训练集，验证集，测试集读入内存

In [4]:
# Read the train / dev / test image lists and look up each image's captions.
# Results are bound to explicit names rather than written through locals(),
# which only works by accident at notebook top level and hides where the
# names are defined.
def _read_split(split):
    """Return the image file names listed in Flickr_8k.<split>.txt.

    Blank lines (including a trailing newline, or none at all) are skipped and
    surrounding whitespace is stripped, so a missing final newline or CRLF line
    endings no longer break the lookup below.
    """
    split_file = os.path.join(data_folder, 'Flickr8k_text/Flickr_8k.' + split + '.txt')
    with open(split_file, 'r', encoding='utf-8') as fp:
        return [line.strip() for line in fp if line.strip()]


trainImages_paths = _read_split('trainImages')
devImages_paths = _read_split('devImages')
testImages_paths = _read_split('testImages')

# A KeyError here means a split file lists an image that has no captions.
trainImages_captions = [image_path2caption[p] for p in trainImages_paths]
devImages_captions = [image_path2caption[p] for p in devImages_paths]
testImages_captions = [image_path2caption[p] for p in testImages_paths]

###### 生成word_map并写入 .json

In [5]:
# Build the vocabulary (word -> integer id) and save it as WORDMAP.json.
# Id 0 is reserved for <pad>; regular words get ids 1..V, followed by
# <unk>, <start> and <end>.
max_len = 100        # maximum caption length in tokens, excluding <start>/<end>
min_word_freq = 5    # a word is kept only if it occurs MORE than this many times

# NOTE(review): word counts include dev and test captions, so the vocabulary
# is shaped by held-out data. Consider counting training captions only.
freq = Counter()
for caps in [trainImages_captions, devImages_captions, testImages_captions]:
    for cap in caps:
        for sentence in cap:
            # A caption of exactly max_len tokens still fits the padded
            # layout (max_len + 2 ids), so only strictly longer ones warn.
            if len(sentence) > max_len:
                print("警告！caption的长度大于100 ！")
            freq.update(sentence)

words = [w for w in freq.keys() if freq[w] > min_word_freq]
word_map = {w: n + 1 for n, w in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

# input_files_folder is taken from the configuration cell.
with open(os.path.join(input_files_folder, 'WORDMAP.json'), 'w', encoding='utf-8') as j:
    json.dump(word_map, j)

###### 生成输入数据文件

In [8]:
# Write, per split, an HDF5 file of images plus JSON files of encoded captions
# and caption lengths.
CAPTIONS_PER_IMAGE = 5
IMAGE_SIZE = 256


def _load_image_chw(image_file):
    """Read an image file as a uint8 RGB array of shape (3, IMAGE_SIZE, IMAGE_SIZE).

    Raises FileNotFoundError when OpenCV cannot read the file, because
    cv2.imread signals failure by returning None rather than raising.
    """
    img = cv2.imread(image_file)
    if img is None:
        raise FileNotFoundError('Could not read image: %s' % image_file)
    # OpenCV decodes to BGR; convert to RGB, the channel order expected by
    # ImageNet-pretrained encoders. NOTE(review): confirm the training code
    # does not flip channels itself.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    img = img.transpose(2, 0, 1)  # HWC -> CHW
    assert img.shape == (3, IMAGE_SIZE, IMAGE_SIZE)
    return img


def _encode_caption(tokens, word_map, max_len):
    """Encode a tokenized caption as [<start>, w1..wk, <end>, <pad>...].

    Tokens past max_len are dropped, so every encoded caption has exactly
    max_len + 2 ids. Returns (encoded_ids, length), where length counts
    <start> and <end> but not padding.
    """
    tokens = tokens[:max_len]
    unk = word_map['<unk>']
    encoded = ([word_map['<start>']]
               + [word_map.get(word, unk) for word in tokens]
               + [word_map['<end>']]
               + [word_map['<pad>']] * (max_len - len(tokens)))
    return encoded, len(tokens) + 2


for impaths, imcaps, split in [(trainImages_paths, trainImages_captions, 'TRAIN'),
                               (devImages_paths, devImages_captions, 'VAL'),
                               (testImages_paths, testImages_captions, 'TEST')]:

    # 'w' truncates any file left by a previous run; with 'a', create_dataset
    # fails with "name already exists" when the cell is executed again.
    with h5py.File(os.path.join(input_files_folder, split + '_IMAGES' + '.hdf5'), 'w') as h:
        h.attrs['captions_per_image'] = CAPTIONS_PER_IMAGE

        # One image per row, stored channel-first: (N, 3, H, W), uint8.
        images = h.create_dataset('images', (len(impaths), 3, IMAGE_SIZE, IMAGE_SIZE), dtype='uint8')

        print("\nReading %s images and captions, storing to file...\n" % split)

        enc_captions = []
        caplens = []

        for i, path in enumerate(tqdm(impaths)):
            # Directory name spelling must match the extracted dataset on disk.
            image_file = os.path.join(data_folder, 'Flickr8k_Dataset', 'Flicker8k_Dataset', path)
            images[i] = _load_image_chw(image_file)

            for c in imcaps[i]:
                enc_c, c_len = _encode_caption(c, word_map, max_len)
                enc_captions.append(enc_c)
                caplens.append(c_len)

        # Sanity check: every image contributed exactly CAPTIONS_PER_IMAGE captions.
        assert images.shape[0] * CAPTIONS_PER_IMAGE == len(enc_captions) == len(caplens)

        # Save encoded captions and their lengths to JSON files.
        with open(os.path.join(input_files_folder, split + '_CAPTIONS' + '.json'), 'w') as j:
            json.dump(enc_captions, j)

        with open(os.path.join(input_files_folder, split + '_CAPLENS' + '.json'), 'w') as j:
            json.dump(caplens, j)

  0%|          | 26/6000 [00:00<00:23, 257.94it/s]


Reading TRAIN images and captions, storing to file...



100%|██████████| 6000/6000 [00:19<00:00, 308.94it/s]
  3%|▎         | 29/1000 [00:00<00:03, 285.02it/s]


Reading VAL images and captions, storing to file...



100%|██████████| 1000/1000 [00:03<00:00, 298.30it/s]
  3%|▎         | 32/1000 [00:00<00:03, 314.65it/s]


Reading TEST images and captions, storing to file...



100%|██████████| 1000/1000 [00:03<00:00, 314.90it/s]
