# Wikipedia Image-Caption Competition

### Import libraries

In [None]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

import base64
import io
import json
import numpy as np
import tensorflow as tf
import time
import IPython
import PIL

from kaggle_secrets import UserSecretsClient


# Try to attach to a TPU; if any step of TPU discovery or initialization
# fails, fall back to TensorFlow's current default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print("Device:", tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception:
    # `Exception` (not a bare `except`) so that KeyboardInterrupt and
    # SystemExit still propagate instead of silently selecting the fallback.
    strategy = tf.distribute.get_strategy()

AUTOTUNE = tf.data.experimental.AUTOTUNE
print("Number of replicas:", strategy.num_replicas_in_sync)
print(tf.__version__)

In [None]:
# Obtain the Google Cloud credential through the Kaggle secrets client and
# register it with TensorFlow.
# NOTE(review): presumably required so TensorFlow can read from cloud storage
# when running on TPU - confirm against the cells that build the dataset.
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

### Import data
The data consists of the first 10 joined dataset files (00000-00009) from the archive at: https://analytics.wikimedia.org/published/datasets/one-off/caption_competition/training/joined/

In [None]:
DS_PATH = "../input/wikipedia-train-0"

# Keys read from each JSON-lines record.
DESC_COLUMN = "caption_title_and_reference_description"
IMG_COLUMN = "b64_bytes"
FEAT_COLUMN = "wit_features"
URL_COLUMN = "image_url"


def _parse_record(line):
    """Parse one JSON line into a trimmed record, or return None to skip it.

    The returned dict holds the image URL, the list of non-empty descriptions
    collected from every entry under ``FEAT_COLUMN``, and the image payload
    stored under ``IMG_COLUMN``. A record is skipped (None) when the line is
    blank or when the URL, the image payload, or the description list is
    missing or empty.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # File iteration yields lines with their trailing newline, so a plain
    # truthiness test never filters blank lines; strip before checking.
    if not line.strip():
        return None

    obj = json.loads(line)
    descriptions = [
        element[DESC_COLUMN]
        for element in obj.get(FEAT_COLUMN) or []
        if element.get(DESC_COLUMN)
    ]
    content = {
        URL_COLUMN: obj.get(URL_COLUMN),
        DESC_COLUMN: descriptions,
        IMG_COLUMN: obj.get(IMG_COLUMN),
    }

    # Only keep content if URL, image and description are all present.
    # Truthiness also rejects None/null values, not just empty strings.
    if content[URL_COLUMN] and content[IMG_COLUMN] and descriptions:
        return content
    return None


filenames = sorted(os.listdir(DS_PATH))
json_content = []

start_time = time.time()
step_time = start_time
for file in filenames:
    filename = os.path.join(DS_PATH, file)
    with open(filename, "rb") as fr:
        for line in fr:
            content = _parse_record(line)
            if content is not None:
                json_content.append(content)
    # Report how long this single file took to import.
    cur_time = time.time()
    print(f"Import took {(cur_time - step_time):.3f} seconds from file: {file}")
    step_time = cur_time

end_time = time.time()
print(f"Time to read {len(filenames)} files: {(end_time - start_time):.3f} seconds")

In [None]:
# Show which keys a loaded record carries, then the number of records kept.
first_record, total_items = json_content[0], len(json_content)
print(first_record.keys())
print(f"Total items: {total_items}")

### Visualize data

In [None]:
def display_from_json(content):
    """Print a record's URL and descriptions, then render its image inline.

    Args:
        content: dict with the keys ``URL_COLUMN``, ``DESC_COLUMN`` and
            ``IMG_COLUMN``; the ``IMG_COLUMN`` value is base64-decoded and
            opened as an image.
    """
    # `import PIL` alone does not load the `Image` submodule; import it
    # explicitly so `PIL.Image.open` does not depend on another library
    # having imported it first.
    import PIL.Image

    decoded = base64.b64decode(content[IMG_COLUMN])
    image = PIL.Image.open(io.BytesIO(decoded)).convert("RGB")
    print(f"\nImage URL: {content[URL_COLUMN]}")
    print(f"Description: {content[DESC_COLUMN]}\n")
    IPython.display.display(image)

    
# Display a few records picked uniformly at random from the loaded data.
sample_count = 3
for _ in range(sample_count):
    rand_index = np.random.randint(len(json_content))
    display_from_json(json_content[rand_index])