# Data preprocessing for the image-captioning task
## **Part 1**: 
Download the images whose URLs are listed in the training dataset.

In [1]:
import requests
import numpy as np
import pandas as pd
import os
import torch as t

# Load Data: tab-separated training file with no header row.
# Further down, column 0 is used as the caption text and column 1 as the image URL.
data_path = 'raw_data/Train_GCC-training.tsv'
data = pd.read_csv(data_path, header=None, delimiter='\t')

In [2]:
from urllib.request import urlretrieve
from urllib.error import HTTPError

# Download the images into imgs/ folder
def download(url, filename, i, timeout=20, chunk_size=1024):
    """Download one image from ``url`` into ``filename`` (best effort).

    Args:
        url: Image URL (the callers in this notebook pass column 1 of the TSV).
        filename: Destination path for the downloaded bytes.
        i: Row index; used only in the progress message.
        timeout: Seconds passed to ``requests.get``.
        chunk_size: Number of bytes per streamed chunk.

    Returns:
        True if the file is on disk afterwards (it already existed or the
        download succeeded), False if the download failed.  A failed download
        never leaves a partial file behind.
    """
    if os.path.exists(filename):
        print('file exists!')
        # Already fetched on an earlier run: report success so the caller
        # does not count this row as lost.
        return True
    try:
        # The context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=timeout) as r:
            # Reject 4xx/5xx responses instead of saving an error page as .jpg.
            r.raise_for_status()
            with open(filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
    except Exception:
        # Deliberate best effort over thousands of URLs: any network, HTTP or
        # disk error marks this row as failed.  Unlike a bare ``except:``,
        # KeyboardInterrupt still stops the loop.
        if os.path.exists(filename):
            os.remove(filename)
        return False
    print("Download imgs: {} Done".format(i))
    return True

def read_file(max_index=50000, img_dir='imgs'):
    """Download the images for rows 0..``max_index`` (inclusive) into ``img_dir``.

    Each image is saved as ``<row>.jpg``.  URLs come from column 1 of the
    module-level ``data`` DataFrame, and each one is fetched with ``download``.

    Args:
        max_index: Last row index to fetch (inclusive).
        img_dir: Output directory, created if missing.

    Returns:
        list[int]: Row indices whose download failed.
    """
    lost_labels = []
    os.makedirs(img_dir, exist_ok=True)

    # +1 because max_index is inclusive.
    urls = data.loc[:, 1].values[:max_index + 1]
    for i, url in enumerate(urls):
        filename = os.path.join(img_dir, '{}.jpg'.format(i))
        # Only an explicit False means the download failed; a file that is
        # already on disk must not be reported as lost.
        if download(url, filename, i) is False:
            lost_labels.append(i)

    return lost_labels

In [None]:
# Fetch the images (see read_file); lost_labels holds the row indices it reports as not downloaded.
lost_labels = read_file()

In [55]:
# Image validity test: collect the names of files in imgs/ that PIL cannot decode.
import warnings
warnings.filterwarnings('ignore')
from PIL import Image

img_data = os.listdir('imgs')
a = []  # invalid file names; the next cell deletes them
for name in img_data:
    path = os.path.join('imgs', name)
    # Only regular files are candidates: the next cell's os.remove() cannot
    # delete a directory.
    if not os.path.isfile(path):
        continue
    try:
        # Image.open() only parses the header; load() decodes the pixel data,
        # so truncated or corrupt downloads are caught here too.  The `with`
        # block closes the file handle right away.
        with Image.open(path) as img:
            img.load()
    except Exception:
        a.append(name)

In [None]:
# Remove every file flagged as invalid by the validity test, printing its position in `a`.
for idx, bad_name in enumerate(a):
    bad_path = 'imgs/' + bad_name
    if not os.path.exists(bad_path):
        continue
    os.remove(bad_path)
    print("delete: %d" % idx)

In [2]:
# Rebuild the image index from the files that are actually on disk.
img_data = os.listdir('imgs')
img_data = [name.split('.')[0] for name in img_data]
# Keep only numeric stems ('<row>.jpg').  Hidden entries such as '.foo' give an
# empty stem, and any other non-numeric name would make int() raise.
img_data = sorted(int(stem) for stem in img_data if stem.isdecimal())
# read_file() requests rows 0..50000 by default; rows missing from disk are lost.
lost_data = sorted(set(range(50001)).difference(img_data))
print(len(lost_data))

# image id dictionary
# 0.jpg -> 0
id2ix = {str(item) + '.jpg': ix for ix, item in enumerate(img_data)}
# 0 -> 0.jpg  (exact inverse of id2ix)
ix2id = {ix: item for item, ix in id2ix.items()}

3020


## **Part 2**: 
Process the caption text: tokenize the captions, build the vocabulary and the word/image index mappings, and save them for later use.

In [3]:
# Tokenized caption (column 0, split on whitespace) for every image on disk.
# The keys and their order (ascending row index) match id2ix, so position k in
# captions / captions_list corresponds to image index k.  Deriving this from
# img_data, rather than slicing the DataFrame, avoids picking up rows whose
# image was never downloaded.
caption_column = data.loc[:, 0].values
captions = {str(i) + '.jpg': caption_column[i].split() for i in img_data}


In [4]:
# word id dictionary
from collections import Counter

# Count every token across all captions.  Counter keeps first-occurrence order,
# and most_common() orders equal counts by first occurrence, so the result is
# identical to a stable sort by count (descending).
word_nums = Counter(word for caption in captions.values() for word in caption)
# From here on word_nums is a list of (word, count) pairs, most frequent first.
word_nums = word_nums.most_common()

In [5]:
# Vocabulary: the special tokens first, then every word seen at least twice and at
# most 12 characters long, in the descending-frequency order of word_nums.
special_tokens = ["<START>", "<EOS>", "<UNK>", "<PAD>"]
kept_words = [w for w, freq in word_nums if freq >= 2 and len(w) <= 12]
words = special_tokens + kept_words
word2ix = {}
for position, token in enumerate(words):
    word2ix[token] = position
ix2word = {position: token for token, position in word2ix.items()}

In [6]:
def lambda_(words):
    """Convert a list of caption tokens into vocabulary indices.

    Tokens that are not in ``word2ix`` map to the index of "<UNK>".
    """
    return [word2ix[token] if token in word2ix else word2ix["<UNK>"] for token in words]


captions_list = [lambda_(tokens) for tokens in captions.values()]

In [7]:
# Bundle the caption, vocabulary and image-index lookups and persist them with torch.save.
results = dict(
    captions=captions,
    captions_list=captions_list,
    word2ix=word2ix,
    ix2word=ix2word,
    id2ix=id2ix,
    ix2id=ix2id,
)
t.save(results, "caption.pth")
print('save file in caption.pth')

save file in caption.pth


In [3]:
# Sanity check: reload the saved lookups and inspect one sample image.
sample_key = '49998.jpg'
results = t.load('caption.pth')
print(results['captions'][sample_key])
results['id2ix'][sample_key]

['a', 'public', 'fountain', 'turned', 'into', 'a', 'pool', '.']


46979