In [1]:
import nltk
# word_tokenize (used in build_vocab) needs the Punkt sentence-splitter models.
# Older NLTK loads them from the pickled 'punkt' resource; recent releases
# (3.8.2/3.9+) load the pickle-free 'punkt_tab' resource instead. Fetch both so
# the notebook runs on a fresh install of either. On older NLTK the
# 'punkt_tab' download just reports that the package is unknown.
nltk.download('punkt')
nltk.download('punkt_tab')


class Vocabulary:
    """Two-way mapping between tokens and consecutive integer ids.

    Ids are assigned in insertion order starting from 0.

    Attributes:
        word2idx: dict mapping token -> id.
        idx2word: dict mapping id -> token.
        idx: the id that the next new token will receive.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Register ``word`` under the next free id; do nothing if it is known."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of ``word``, or the id of ``'<unk>'`` if it is unknown.

        Raises:
            KeyError: if ``word`` is unknown and ``'<unk>'`` was never added.
        """
        lookup_key = word if word in self.word2idx else '<unk>'
        return self.word2idx[lookup_key]

    def __len__(self):
        """Number of distinct tokens registered so far."""
        return len(self.word2idx)


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Zoomi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
from collections import Counter
import json

def build_vocab(json_data_path, threshold, splits=('train', 'validation')):
    """Build a caption ``Vocabulary`` from a dataset-info JSON file.

    Args:
        json_data_path: Path to a JSON file whose top-level object maps split
            names to lists of items. Each item is expected to be a dict whose
            first value is itself a dict with a ``'caption'`` string.
        threshold: Minimum number of occurrences, summed over all selected
            captions, that a lowercased token needs in order to be kept.
        splits: Collection of top-level keys whose captions are used. The
            default keeps the original behaviour of using only the
            ``'train'`` and ``'validation'`` splits.

    Returns:
        A ``Vocabulary`` holding the special tokens ``<pad>``, ``<start>``,
        ``<end>`` and ``<unk>`` (ids 0-3). These are followed by every kept
        token, in the order it was first seen.
    """
    # JSON text is UTF-8 (RFC 8259). Pass the encoding explicitly so reading
    # does not depend on the platform default (e.g. cp1252 on Windows).
    # The context manager closes the file handle deterministically.
    with open(json_data_path, encoding='utf-8') as json_file:
        data = json.load(json_file)

    captions = []
    for split_name, items in data.items():
        if split_name not in splits:
            continue
        for item in items:
            # Only the first value of each item is inspected for its caption.
            caption = list(item.values())[0]['caption']
            captions.append(caption)

    counter = Counter()
    for i, caption in enumerate(captions, start=1):
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        counter.update(tokens)
        if i % 500 == 0:
            print("[{}/{}] Tokenized the captions.".format(i, len(captions)))

    # If the word frequency is less than 'threshold', then the word is discarded.
    words = [word for word, cnt in counter.items() if cnt >= threshold]

    # Create a vocab wrapper. Add the special tokens first so they always
    # receive ids 0-3.
    vocab = Vocabulary()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Add the words to the vocabulary.
    for word in words:
        vocab.add_word(word)
    return vocab

In [5]:
import pickle

def main(json_data, threshold, vocab_path='vocab.pkl'):
    """Build the caption vocabulary and pickle it to ``vocab_path``.

    Args:
        json_data: Path to the dataset-info JSON, forwarded to ``build_vocab``.
        threshold: Minimum token count, forwarded to ``build_vocab``.
        vocab_path: Destination file for the pickled vocabulary.
    """
    vocab = build_vocab(json_data, threshold)

    # NOTE: pickle stores classes by reference (module + name). Vocabulary is
    # defined in this notebook, so whoever unpickles the file needs a
    # compatible ``Vocabulary`` reachable under the same module path.
    with open(vocab_path, 'wb') as vocab_file:
        pickle.dump(vocab, vocab_file)

    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


# Tokens seen fewer than this many times across the train/validation captions
# are left out of the vocabulary; at lookup time they map to '<unk>'.
threshold_word_filtering = 3
main(
    threshold=threshold_word_filtering,
    json_data='selected_dataset/selected_dataset_info.json',
)

[500/2500] Tokenized the captions.
[1000/2500] Tokenized the captions.
[1500/2500] Tokenized the captions.
[2000/2500] Tokenized the captions.
[2500/2500] Tokenized the captions.
Total vocabulary size: 2236
Saved the vocabulary wrapper to 'vocab.pkl'
