In [None]:
import nltk
import pickle
import os.path
from pycocotools.coco import COCO
from collections import Counter

class Vocabulary(object):
    """Word <-> integer-index vocabulary built from COCO caption annotations.

    The vocabulary is either loaded from a pickle file or built from
    scratch by tokenizing every caption in ``annotations_file`` and keeping
    the tokens that occur at least ``vocab_threshold`` times. A freshly
    built vocabulary is pickled to ``vocab_file``.

    Calling an instance with a word returns its index, falling back to the
    index of ``unk_word`` for unknown words; ``len()`` gives the number of
    words.
    """

    def __init__(self, vocab_threshold, vocab_file='./vocab.pkl', start_word="<start>", end_word="<end>", unk_word="<unk>", annotations_file='../cocoapi/annotations/captions_train2014.json',vocab_from_file=False):
        """Store the configuration and immediately load or build the vocabulary.

        Args:
            vocab_threshold: Minimum number of occurrences a token needs to
                be added to the vocabulary.
            vocab_file: Path of the pickle file to load from / save to.
            start_word: Special token added first when building.
            end_word: Special token added second when building.
            unk_word: Special token whose index is returned for words that
                are not in the vocabulary.
            annotations_file: Path handed to ``COCO`` when building.
            vocab_from_file: If true and ``vocab_file`` exists, load the
                vocabulary from it instead of building from scratch.
        """
        self.vocab_threshold = vocab_threshold
        self.start_word = start_word
        self.end_word = end_word
        self.unk_word = unk_word
        self.annotations_file = annotations_file
        self.vocab_file = vocab_file
        self.vocab_from_file = vocab_from_file
        self.get_vocab()

    def get_vocab(self):
        """Load the vocabulary from ``vocab_file`` or build and save it."""
        if self.vocab_from_file and os.path.exists(self.vocab_file):
            # NOTE(security): unpickling executes arbitrary code; only load
            # vocab files that come from a trusted source.
            with open(self.vocab_file, 'rb') as f:
                vocab = pickle.load(f)
            self.word2idx = vocab.word2idx
            self.idx2word = vocab.idx2word
            # Restore the next free index so add_word() keeps working on a
            # loaded vocabulary.
            self.idx = getattr(vocab, 'idx', len(self.word2idx))
            print("vocab successfully loaded from vocab.pkl")
        else:
            self.build_vocab()
            with open(self.vocab_file, 'wb') as f:
                pickle.dump(self, f)

    def build_vocab(self):
        """Reset the mappings, add the special tokens, then the caption tokens."""
        self.init_vocab()
        self.add_word(self.start_word)
        self.add_word(self.end_word)
        self.add_word(self.unk_word)
        self.add_captions()

    def init_vocab(self):
        """Create empty word<->index mappings and reset the next free index."""
        self.word2idx = {}
        self.idx2word = {}
        self.idx = 0

    def add_word(self, word):
        """Assign the next free index to ``word`` unless it is already known."""
        if word not in self.word2idx:
            self.word2idx[word] = self.idx
            self.idx2word[self.idx] = word
            self.idx += 1

    def add_captions(self):
        """Tokenize all training captions and add the frequent tokens.

        Captions are lower-cased and tokenized with NLTK; a token is added
        when its total count is at least ``self.vocab_threshold``.
        """
        coco = COCO(self.annotations_file)
        counter = Counter()
        ann_ids = coco.anns.keys()
        total = len(ann_ids)
        for i, ann_id in enumerate(ann_ids):
            caption = str(coco.anns[ann_id]['caption'])
            tokens = nltk.tokenize.word_tokenize(caption.lower())
            counter.update(tokens)

            # Periodic progress report.
            if i % 100000 == 0:
                print("[%d/%d] Tokenizing captions.." % (i, total))

        for word, cnt in counter.items():
            if cnt >= self.vocab_threshold:
                self.add_word(word)

    def __call__(self, word):
        """Return the index of ``word``, or the ``unk_word`` index if unknown."""
        if word not in self.word2idx:
            return self.word2idx[self.unk_word]
        return self.word2idx[word]

    def __len__(self):
        """Return the number of words in the vocabulary."""
        return len(self.word2idx)
    

In [20]:
import torch
from collections import Counter
# Scratch experiment: what does Counter do when handed a 2-D tensor?
# Iterating the tensor yields its rows, so every row becomes its own key
# (see the recorded output: three keys, each counted once).
asd = torch.randn(3, 3)
counter = Counter(asd)
asd, counter

(tensor([[ 1.4613,  0.7748, -0.9540],
         [-1.9958,  0.8001, -0.2460],
         [ 1.4827, -0.7581,  0.8209]]),
 Counter({tensor([ 1.4613,  0.7748, -0.9540]): 1,
          tensor([-1.9958,  0.8001, -0.2460]): 1,
          tensor([ 1.4827, -0.7581,  0.8209]): 1}))