# In [1]:
import os
import _pickle as pickle
import numpy as np
import random

# In [4]:
class GloveEmbeddings:
    """Word-embedding table loaded from a GloVe-style text file.

    Attributes:
        word2id: maps each token to its row index in ``vectors``.
        words: tokens in row order (``words[i]`` owns ``vectors[i]``).
        vectors: 2-D float array of shape ``(len(words), dim)`` once loaded.
        dim: embedding width, inferred from the first line of the file.
        unknown_idx / padding_idx: row ids of ``'<unk>'`` / ``'<pad>'``
            (``None`` until embeddings are loaded).
    """

    def __init__(self):
        self.word2id = {}
        self.vectors = []
        self.words = []
        self.dim = None
        self.unknown_idx = None
        self.padding_idx = None

    def load_glove(self, filename, encoding='utf-8'):
        """Load embeddings from a GloVe text file, replacing the current vocabulary.

        Each non-blank line is ``<token> <v1> ... <vdim>`` separated by
        whitespace. ``dim`` is taken from the first line (assumed to hold a
        single-field token). On every line the last ``dim`` fields form the
        vector and the fields before them are joined with a space to form the
        token, so tokens containing spaces (present in some GloVe releases)
        load instead of crashing. ``'<unk>'`` and ``'<pad>'`` are then added
        with random vectors unless the file already defines them.

        Args:
            filename: path to the text file.
            encoding: text encoding of the file (GloVe releases are UTF-8).

        Raises:
            ValueError: if the file contains no vectors or a line is malformed.
        """
        word2id, words, rows = {}, [], []
        dim = None
        with open(filename, encoding=encoding) as file:
            for lineno, raw in enumerate(file, start=1):
                fields = raw.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if dim is None:
                    dim = len(fields) - 1
                    if dim < 1:
                        raise ValueError(f"{filename}:{lineno}: line has no vector components")
                if len(fields) <= dim:
                    raise ValueError(f"{filename}:{lineno}: expected a token followed by {dim} values")
                word = ' '.join(fields[:-dim])
                words.append(word)
                # A repeated token maps to its most recent row, as before.
                word2id[word] = len(words) - 1
                # np.float was removed in NumPy 1.24; float64 is what it aliased.
                rows.append(np.array(fields[-dim:]).astype(np.float64))
        if dim is None:
            raise ValueError(f"{filename}: no embeddings found")

        self.word2id, self.words, self.dim = word2id, words, dim
        self.vectors = np.vstack(rows)  # shape (len(words), dim)
        self.unknown_idx = self.add_to_vocab('<unk>')
        self.padding_idx = self.add_to_vocab('<pad>')

    def modify_pretrained(self, vocab):
        """Ensure every token in ``vocab.values()`` has a row in the table.

        Tokens are lower-cased, as in ``add_to_vocab``. Tokens already present
        keep their existing (pretrained) vector, and new tokens get a random
        vector. All new rows are appended in one allocation rather than one
        full-matrix copy per token.

        Args:
            vocab: mapping whose values are token strings (presumably an
                id -> token dict from the task vocabulary; confirm with caller).
        """
        pending = {}  # insertion-ordered set of new, de-duplicated tokens
        for token in vocab.values():
            key = token.lower()
            if key not in self.word2id:
                pending.setdefault(key, None)
        self._append_words(list(pending))

    def add_to_vocab(self, word):
        """Add ``word`` (lower-cased) with a random vector and return its row id.

        If the lower-cased word is already present, nothing is appended and its
        existing id is returned. This keeps ``words``/``vectors`` free of
        duplicate rows and never discards a pretrained vector. Each call that
        adds a word copies the whole matrix; use ``modify_pretrained`` to add
        many words at once.
        """
        word = word.lower()
        if word in self.word2id:
            return self.word2id[word]
        return self._append_words([word])[0]

    def _append_words(self, words):
        """Append unseen ``words`` with uniform [0, 1) random vectors; return their ids."""
        if not words:
            return []
        if self.dim is None:
            raise ValueError("load embeddings before adding words to the vocabulary")
        start = len(self.words)
        for offset, word in enumerate(words):
            self.words.append(word)
            self.word2id[word] = start + offset
        fresh = np.random.random((len(words), self.dim))
        # vstack keeps the matrix 2-D; the old np.append(...) flattened it.
        self.vectors = np.vstack([self.vectors, fresh])
        return list(range(start, start + len(words)))

    def dump_all(self, filename):
        """Pickle ``[word2id, flattened vectors, words, dim]`` to ``filename``."""
        with open(filename, 'wb') as fh:
            pickle.dump([self.word2id, np.asarray(self.vectors).reshape(-1), self.words, self.dim], fh)

    def load_dump(self, filename1):
        """Restore state written by ``dump_all``.

        NOTE(security): unpickling can execute arbitrary code; only load dumps
        from a trusted source.
        """
        with open(filename1, 'rb') as fh:
            self.word2id, self.vectors, self.words, self.dim = pickle.load(fh)
        self.vectors = np.asarray(self.vectors).reshape(-1, self.dim)
        # Previously only load_glove set these; keep both load paths consistent.
        self.unknown_idx = self.word2id.get('<unk>')
        self.padding_idx = self.word2id.get('<pad>')

    def convert_to_indices(self, lines):
        '''
        Map tokenized lines to row ids.

        @param lines: lines are list of list of strings. each string is considered as a token
        @return: list of lists of ints, one per input line. Each token resolves
                 to its exact id, else its lower-cased id, else the id of '<unk>'
                 (raises KeyError if '<unk>' is needed but absent).
        '''
        lookup = self.word2id

        def _index(word):
            if word in lookup:
                return lookup[word]
            lowered = word.lower()
            if lowered in lookup:
                return lookup[lowered]
            return lookup['<unk>']

        return [[_index(word) for word in line] for line in lines]