In [2]:
import os
import nltk
import pickle
from collections import Counter

class Vocabulary(object):
    """Two-way lookup table between words and integer ids.

    Ids are assigned in insertion order, starting at 0. Calling an instance
    with a word returns that word's id; a word that was never added resolves
    to the id stored for the '<unk>' token instead.
    """

    def __init__(self):
        self.word2idx = {}  # word -> id
        self.idx2word = {}  # id -> word
        self.idx = 0        # id that the next new word will receive

    def add_word(self, word):
        """Assign the next free id to `word`; words already present are ignored."""
        if word in self.word2idx:
            return
        new_id = self.idx
        self.word2idx[word] = new_id
        self.idx2word[new_id] = word
        self.idx = new_id + 1

    def __call__(self, word):
        """Return the id of `word`, or the id of '<unk>' if `word` is unknown.

        Raises KeyError when `word` is unknown and '<unk>' was never added.
        """
        lookup_key = word if word in self.word2idx else '<unk>'
        return self.word2idx[lookup_key]

    def __len__(self):
        """Return the number of distinct words stored."""
        return len(self.word2idx)


def build_vocab(captions, threshold):
    """Build a Vocabulary from the frequent items of `captions`.

    `captions` is an iterable of hashable items (main() passes a flat list
    of word tokens). Each distinct item is counted, and only items occurring
    at least `threshold` times are kept. The special tokens '<pad>',
    '<start>', '<end>' and '<unk>' are always added first, in that order,
    followed by the kept items in first-seen order.

    Returns the populated Vocabulary.
    """
    frequencies = Counter(captions)

    vocab = Vocabulary()
    # Special tokens go in first so they receive the lowest ids.
    for special in ('<pad>', '<start>', '<end>', '<unk>'):
        vocab.add_word(special)

    # Items rarer than `threshold` are dropped.
    for word, count in frequencies.items():
        if count >= threshold:
            vocab.add_word(word)
    return vocab


def main(caption_path=r"C:\Users\shema\Downloads\cifar-100-python\cifar-100-python\train",
         vocab_path='./data/vocab.pkl',
         threshold=4,
         encodings=('utf-8', 'latin-1')):
    """Read a text file, build a word vocabulary from it, and pickle the result.

    Args:
        caption_path: Path of the text file to tokenize on whitespace.
        vocab_path: Destination of the pickled Vocabulary. Its parent
            directory is created if it does not exist.
        threshold: Minimum number of occurrences a word needs in order to
            be kept (passed through to build_vocab).
        encodings: Text encodings to try, in order, when reading
            `caption_path`; the first one that decodes the file wins.

    Raises:
        ValueError: If none of `encodings` can decode the file.
        OSError: If the input cannot be opened or the output cannot be written.

    All defaults reproduce the previously hard-coded values, so calling
    main() with no arguments behaves as before.
    """
    # NOTE(review): the default caption_path looks like the CIFAR-100 training
    # batch, which is presumably a pickled binary file rather than text --
    # confirm. 'latin-1' maps every byte to a character and therefore never
    # raises UnicodeDecodeError, so with the default encodings a binary input
    # is silently tokenized instead of being rejected.
    for encoding in encodings:
        try:
            with open(caption_path, 'r', encoding=encoding) as f:
                # Whitespace splitting of the whole text yields the same
                # tokens as stripping and splitting line by line.
                words = f.read().split()
        except UnicodeDecodeError:
            continue  # Try the next encoding if decoding fails
        break  # Decoding succeeded
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError(
            "Could not decode '{}' with any of the encodings: {}".format(
                caption_path, ', '.join(encodings)))

    # Build the vocabulary
    vocab = build_vocab(words, threshold=threshold)

    # Save the vocabulary wrapper. dirname is '' for a bare file name, in
    # which case there is nothing to create; exist_ok avoids the
    # check-then-create race.
    vocab_dir = os.path.dirname(vocab_path)
    if vocab_dir:
        os.makedirs(vocab_dir, exist_ok=True)

    with open(vocab_path, 'wb') as f:
        pickle.dump(vocab, f)

    print("Total vocabulary size: {}".format(len(vocab)))
    print("Saved the vocabulary wrapper to '{}'".format(vocab_path))


# Build and save the vocabulary only when this file is run as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()


Total vocabulary size: 21386
Saved the vocabulary wrapper to './data/vocab.pkl'
