In [4]:
import sys
import numpy as np
import pandas as pd
import random
import collections
import re
from nltk.tokenize import word_tokenize

In [5]:
class VocabSet:
    """Build a token-to-index vocabulary from a corpus of sentences.

    Parameters
    ----------
    tokenize_fn : callable
        Function mapping a string to an iterable of string tokens.
    th : int, default 3
        Frequency threshold; a token is kept only when it occurs strictly
        more than ``th`` times in the corpus.
    """

    # Reserved key under which the unknown-token index is stored.
    UNKNOWN_TOKEN = 'utk'

    def __init__(self, tokenize_fn, th=3):
        self.th = th
        self.tokenize = tokenize_fn

    def tokens(self, data):
        """Return a ``{token: index}`` mapping for the sentences in *data*.

        Every sentence is lower-cased and tokenized. Tokens occurring more
        than ``self.th`` times and containing no ASCII digit are kept,
        shuffled with the module-level ``random`` generator (seed it for
        reproducible indices) and numbered from 1. The reserved key
        ``'utk'`` is always appended last, so it receives the highest index.

        Parameters
        ----------
        data : iterable of str
            Sentences to count tokens from.

        Returns
        -------
        dict
            Mapping from token to a unique index in ``1..len(mapping)``.
        """
        counts = collections.Counter()
        for sen in data:
            counts.update(self.tokenize(sen.lower()))

        has_digit = re.compile(r'[0-9]').search
        # A corpus token equal to the reserved key is skipped here: keeping
        # it would duplicate the key below, leaving a gap in the indices and
        # making len(mapping) smaller than the largest index.
        valid_tok = [
            tok for tok, count in counts.items()
            if count > self.th
            and tok != self.UNKNOWN_TOKEN
            and has_digit(tok) is None
        ]

        random.shuffle(valid_tok)
        valid_tok.append(self.UNKNOWN_TOKEN)

        return dict(zip(valid_tok, range(1, len(valid_tok) + 1)))
        

In [6]:
class Encoder:
    """Turn sentences into lists of integer token indices.

    Parameters
    ----------
    data : iterable of str
        Sentences encoded by :meth:`encode`.
    tokenize_fn : callable
        Function mapping a string to an iterable of tokens.
    token_dict : dict
        Mapping from token to integer index. It must contain the key
        ``'utk'``, whose index is used for tokens missing from the mapping.

    Attributes
    ----------
    sos, eos : int
        Indices of the start and end markers added around each sentence.
    v_size : int
        ``eos + 1``: one more than the largest index this encoder emits.
    """

    def __init__(self, data, tokenize_fn, token_dict):
        self.data = data
        self.tokenize = tokenize_fn
        self.token_dict = token_dict

        # Place the special indices above both the number of entries and the
        # largest index in use. For a mapping numbered 1..n this is n, as
        # before; if the mapping has gaps, len() alone could yield a start
        # index that is already assigned to a real token.
        top = max(len(token_dict), max(token_dict.values(), default=0))

        self.sos = top + 1  # Start Token
        self.eos = top + 2  # End Token
        self.v_size = top + 3  # 0 , Start Token , End Token

    def get_size(self):
        """Return the vocabulary size ``v_size``."""
        return self.v_size

    def check(self, tok):
        """Return True when *tok* is a key of the token mapping."""
        return tok in self.token_dict

    def encode_sen(self, sen):
        """Encode one sentence as ``[sos, <token indices...>, eos]``.

        The sentence is tokenized as given (no lower-casing happens here);
        tokens absent from the mapping are replaced by the ``'utk'`` index.

        Raises
        ------
        KeyError
            If an unknown token is met and the mapping has no ``'utk'`` key.
        """
        body = [
            # The 'utk' lookup stays inside the conditional so that it is
            # only evaluated when an unknown token is actually encountered.
            self.token_dict[tok] if self.check(tok) else self.token_dict['utk']
            for tok in self.tokenize(sen)
        ]
        return [self.sos] + body + [self.eos]

    def encode(self):
        """Encode every sentence in ``self.data`` after lower-casing it.

        Returns
        -------
        list of list of int
            One encoded sentence per input sentence, in input order.
        """
        return [self.encode_sen(sen.lower()) for sen in self.data]