## 创建词典

In [20]:
from collections import OrderedDict
import pickle
import logging

class Vocabulary:
    """Bidirectional word <-> integer-id vocabulary built from token counts.

    Ids 0 and 1 are always reserved for the ``[PAD]`` and ``[UNK]`` tokens.
    Every other word receives an id starting at 2, assigned in descending
    order of frequency (ties keep first-seen order).
    """

    # Reserved tokens; the position of a token in this tuple is its id.
    _SPECIAL_TOKENS = ('[PAD]', '[UNK]')

    def __init__(self):
        self.word2id = OrderedDict()
        self.id2word = OrderedDict()
        self.logger = logging.getLogger('Vocabulary')
        # Limit passed to the last build_dictionary() call, or the number of
        # loaded entries after load_dictionary(). Use
        # get_vocabulary_length() for the actual vocabulary size.
        self.max_vocab_length = None

    def get_word2id(self):
        """Return the word -> id mapping."""
        return self.word2id

    def get_id2word(self):
        """Return the id -> word mapping."""
        return self.id2word

    def get_vocabulary_length(self):
        """Return the number of entries in the vocabulary.

        The count includes the reserved tokens and reflects the real size of
        the mapping, so it is correct even when no limit was given or when the
        data contained fewer distinct words than the limit allowed.
        """
        return len(self.word2id)

    def _rebuild_id2word(self):
        """Recreate ``id2word`` as the exact inverse of ``word2id``."""
        # Clear first so ids from a previous build/load cannot linger.
        self.id2word.clear()
        for word, index in self.word2id.items():
            self.id2word[index] = word

    def build_dictionary(self, data, max_vocab_length=None):
        """Build the vocabulary from an iterable of tokens.

        Any previously built or loaded content is discarded.

        Args:
            data: Iterable of hashable tokens (typically strings).
            max_vocab_length: Maximum total number of entries, reserved
                tokens included. ``None`` or ``0`` means no limit. A limit
                smaller than the number of reserved tokens yields a
                vocabulary holding only the reserved tokens.
        """
        self.max_vocab_length = max_vocab_length
        reserved = len(self._SPECIAL_TOKENS)

        word_count = {}
        for item in data:
            # Reserved tokens appearing in the data must not take a regular
            # slot, otherwise they would overwrite ids 0/1 below.
            if item in self._SPECIAL_TOKENS:
                continue
            word_count[item] = word_count.get(item, 0) + 1

        # Most frequent first; sorted() is stable, so ties keep the order in
        # which the words were first seen.
        sorted_words = sorted(word_count, key=word_count.get, reverse=True)
        if max_vocab_length:
            # Clamp at zero: a negative slice bound would drop words from the
            # END of the list instead of keeping none.
            sorted_words = sorted_words[:max(max_vocab_length - reserved, 0)]

        # Start from a clean mapping so repeated builds do not accumulate.
        self.word2id.clear()
        for index, token in enumerate(self._SPECIAL_TOKENS):
            self.word2id[token] = index
        for index, word in enumerate(sorted_words, start=reserved):
            self.word2id[word] = index

        self._rebuild_id2word()

        self.logger.info('build dictionary SUCCESS')

    def save_dictionary(self, dict_file):
        """Pickle the word -> id mapping to the file path ``dict_file``."""
        with open(dict_file, 'wb') as f:
            pickle.dump(self.word2id, f)

        self.logger.info('save dictionary SUCCESS')

    def load_dictionary(self, dict_file):
        """Load a word -> id mapping written by ``save_dictionary``.

        Replaces the current vocabulary and rebuilds ``id2word`` from the
        loaded mapping.

        Args:
            dict_file: Path of the pickle file to read.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load dictionary files that come from a trusted source.
        with open(dict_file, 'rb') as f:
            self.word2id = pickle.load(f)

        self.max_vocab_length = len(self.word2id)

        self._rebuild_id2word()

        self.logger.info('load dictionary SUCCESS')
        

In [21]:
# Toy corpus: 'hello' and 'world' each occur twice, every other token once.
data = 'hello world happy hello i you world'.split()

# Build a vocabulary capped at 3 entries and pickle its word -> id mapping.
vocab = Vocabulary()
vocab.build_dictionary(data, max_vocab_length=3)
vocab.save_dictionary('test.pkl')

In [22]:
# Print both lookup tables: id -> word first, then word -> id.
for table in (vocab.id2word, vocab.word2id):
    print(table)

OrderedDict([(0, '[PAD]'), (1, '[UNK]'), (2, 'hello')])
OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('hello', 2)])


In [23]:
# Round trip: restore the pickled mapping into a fresh instance and print its
# size followed by both lookup tables.
vocab2 = Vocabulary()
vocab2.load_dictionary('test.pkl')
for getter in (
    vocab2.get_vocabulary_length,
    vocab2.get_word2id,
    vocab2.get_id2word,
):
    print(getter())

3
OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('hello', 2)])
OrderedDict([(0, '[PAD]'), (1, '[UNK]'), (2, 'hello')])
