In [1]:
import tensorflow as tf
import numpy as np
from transformer.hparams import Hparams

  from ._conv import register_converters as _register_converters


In [2]:
# data_load utils:

# 1. Load the vocabulary file (produced by Google SentencePiece) and build token<->index lookup tables
def load_vocab(vocab_fpath):
    """Build token<->index lookup tables from a vocabulary file.

    The first whitespace-separated field of every line is used as the token;
    a token's index is its zero-based line number in the file.

    Args:
        vocab_fpath: path of the vocabulary file, one entry per line.

    Returns:
        A pair ``(token2idx, idx2token)``: a dict mapping token -> index and
        a dict mapping index -> token.

    Raises:
        IndexError: if a line is empty or contains only whitespace.
    """
    # The context manager closes the handle deterministically instead of
    # leaving it open until garbage collection.
    with open(vocab_fpath, 'r') as vocab_file:
        vocab = [line.split()[0] for line in vocab_file.read().splitlines()]
    token2idx = {token: idx for idx, token in enumerate(vocab)}
    idx2token = dict(enumerate(vocab))
    return token2idx, idx2token

# Load parallel sentence pairs, dropping any pair in which either side exceeds its length limit

def load_data(fpath1,fpath2,maxlen1,maxlen2):
    """Read two line-aligned text files and keep the pairs that fit the length limits.

    The files are read in lockstep; iteration stops at the end of the shorter
    file. A pair is dropped when either side is too long.

    Args:
        fpath1: path of the source-side file, one sentence per line.
        fpath2: path of the target-side file, one sentence per line.
        maxlen1: a source sentence is kept only if its whitespace token
            count plus one does not exceed this value.
        maxlen2: same limit for the target sentence.

    Returns:
        A pair ``(sents1, sents2)`` of equally long lists holding the
        stripped source and target sentences that passed both checks.
    """
    sents1, sents2 = [], []
    with open(fpath1, 'r') as f1, open(fpath2, 'r') as f2:
        for sent1, sent2 in zip(f1, f2):
            # The "+ 1" presumably reserves room for the "</s>" marker that
            # sent2num_encode appends to every sentence.
            if len(sent1.split()) + 1 > maxlen1:
                continue
            # Fixed: this check used the misspelled name `eln`, which raised
            # NameError as soon as a source sentence passed its own check.
            if len(sent2.split()) + 1 > maxlen2:
                continue
            sents1.append(sent1.strip())
            sents2.append(sent2.strip())
    return sents1, sents2

def calc_num_batches(total_num,batch_size):
    """Return how many batches of size `batch_size` are needed to cover `total_num` items.

    This is a ceiling division: a final, partially filled batch counts as
    one batch.
    """
    full_batches, leftover = divmod(total_num, batch_size)
    if leftover:
        return full_batches + 1
    return full_batches

def sent2num_encode(inp_sent,type_inp,t2idx_dict):
    """Convert a UTF-8 encoded sentence into a list of vocabulary indices.

    Args:
        inp_sent: the sentence as a bytes object; it is decoded as UTF-8 and
            split on whitespace.
        type_inp: "x" appends only the "</s>" marker; any other value wraps
            the tokens in "<s>" ... "</s>".
        t2idx_dict: token -> index mapping; it must contain "<unk>", whose
            index is used for every token missing from the mapping.

    Returns:
        A list with one index per token, markers included.
    """
    words = inp_sent.decode("utf-8").split()
    leading = [] if type_inp == "x" else ["<s>"]
    pieces = leading + words + ["</s>"]
    unk_idx = t2idx_dict["<unk>"]
    return [t2idx_dict.get(piece, unk_idx) for piece in pieces]


def gen_fn(sents1,sents2,vocab_fpath):
    """Generate one encoded training example per (source, target) sentence pair.

    Args:
        sents1: iterable of source sentences as bytes objects (they are
            decoded as UTF-8 by sent2num_encode).
        sents2: iterable of target sentences, aligned with `sents1`.
        vocab_fpath: path of the vocabulary file passed to load_vocab.

    Yields:
        ``((x, x_seq_len, sent1), (decoder_inp, y, y_seq_len, sent2))`` where
        `x` is the encoded source ending in "</s>", `decoder_inp` is the
        encoded target without its final "</s>", `y` is the encoded target
        without its leading "<s>", and the `*_seq_len` values are
        ``len(x)`` and ``len(y)``.
    """
    token2idx, _ = load_vocab(vocab_fpath)
    for sent1, sent2 in zip(sents1, sents2):
        # Fixed: encode the current pair. The original passed the whole
        # `sents1` / `sents2` collections to the encoder.
        x = sent2num_encode(sent1, "x", token2idx)
        y = sent2num_encode(sent2, "y", token2idx)

        # Shift by one: the decoder sees "<s> w1 ... wn" and must predict
        # "w1 ... wn </s>".
        decoder_inp, y = y[:-1], y[1:]
        x_seq_len, y_seq_len = len(x), len(y)
        yield (x, x_seq_len, sent1), (decoder_inp, y, y_seq_len, sent2)


def data_feed_fn(sents1,sents2,vocab_fpath,batch_size,shuffle=False):
    """Wrap gen_fn in a repeating, padded and prefetched tf.data.Dataset.

    Each element has the structure produced by gen_fn:
        xs = (x[int32], x_seq_len[int32], sent1[string])
        ys = (decoder_inp[int32], y[int32], y_seq_len[int32], sent2[string])

    Args:
        sents1: source sentences, forwarded to gen_fn.
        sents2: target sentences, forwarded to gen_fn.
        vocab_fpath: vocabulary file path, forwarded to gen_fn.
        batch_size: number of examples per padded batch.
        shuffle: when true, shuffle with a buffer of ``128 * batch_size``.

    Returns:
        The dataset after repeat, padded_batch and prefetch(1).
    """
    xs_shapes = ([None], (), ())
    ys_shapes = ([None], [None], (), ())
    element_shapes = (xs_shapes, ys_shapes)

    xs_types = (tf.int32, tf.int32, tf.string)
    ys_types = (tf.int32, tf.int32, tf.int32, tf.string)
    element_types = (xs_types, ys_types)

    # Padding value per component: 0 for the integer ones, '' for the strings.
    xs_pad = (0, 0, '')
    ys_pad = (0, 0, 0, '')
    pad_values = (xs_pad, ys_pad)

    examples = tf.data.Dataset.from_generator(
        gen_fn,
        output_types=element_types,
        output_shapes=element_shapes,
        args=(sents1, sents2, vocab_fpath),
    )
    if shuffle:
        examples = examples.shuffle(128 * batch_size)

    # repeat() with no count makes the dataset cycle indefinitely.
    examples = examples.repeat()
    batched = examples.padded_batch(batch_size, element_shapes, pad_values)
    return batched.prefetch(1)

def get_batch(fpath1,fpath2,maxlen1,maxlen2,vocab_fpath,batch_size,shuffle=False):
    """Load the length-filtered sentence pairs and build the batched dataset.

    Args:
        fpath1: path of the source-side file.
        fpath2: path of the target-side file.
        maxlen1: length limit applied to source sentences by load_data.
        maxlen2: length limit applied to target sentences by load_data.
        vocab_fpath: vocabulary file path.
        batch_size: number of examples per batch.
        shuffle: forwarded to data_feed_fn.

    Returns:
        ``(batches, num_batches, num_samples)``: the dataset, the number of
        batches needed to cover the kept pairs once, and the number of kept
        pairs.
    """
    src_sents, tgt_sents = load_data(fpath1, fpath2, maxlen1, maxlen2)
    num_samples = len(src_sents)
    dataset = data_feed_fn(src_sents, tgt_sents, vocab_fpath, batch_size, shuffle)
    return dataset, calc_num_batches(num_samples, batch_size), num_samples
    





In [7]:
# Useful module functions..
#1. Token_embedding layer
def get_token_embeddings(vocab_sz,d_embed,zero_pad=True):
    """Create the token embedding matrix.

    Args:
        vocab_sz: number of rows of the matrix.
        d_embed: number of columns of the matrix.
        zero_pad: when true, row 0 of the returned tensor is a constant
            all-zero row in place of the trainable one.

    Returns:
        A float32 tensor of shape (vocab_sz, d_embed), backed by the variable
        "weight_mat" in the "shared_weight_matrix" scope.
    """
    with tf.variable_scope("shared_weight_matrix"):
        table = tf.get_variable(
            "weight_mat",
            shape=(vocab_sz, d_embed),
            dtype=tf.float32,
            initializer=tf.contrib.layers.xavier_initializer(),
        )
        if not zero_pad:
            return table
        # Swap row 0 for zeros and keep every other row as is.
        pad_row = tf.zeros(shape=[1, d_embed])
        remaining_rows = table[1:, :]
        return tf.concat((pad_row, remaining_rows), 0)


#2. positional encoding as given in the paper 

def positional_encoding(inputs,maxlen,masking=True,scope="positional_encoding"):
    """Compute sinusoidal positional encodings for `inputs`.

    A table of `maxlen` rows is precomputed with NumPy: column j of row
    `pos` holds ``pos / 10000 ** ((j - j % 2) / E)``, passed through sin for
    even j and cos for odd j. The rows for positions 0..T-1 are then looked
    up for every batch entry.

    Args:
        inputs: tensor whose first two dimensions are read dynamically as
            (N, T) and whose last dimension E must be statically known.
            Presumably the (N, T, E) float embeddings - TODO confirm.
        maxlen: number of positions to precompute. T is expected not to
            exceed it, since larger positions would fall outside the table.
        masking: when true, every element where `inputs` equals 0 takes the
            value of `inputs` (i.e. 0) instead of the encoding.
        scope: name of the variable scope the ops are created in.

    Returns:
        A float32 tensor of shape (N, T, E).
    """
    # The embedding size is static; batch size N and sequence length T are
    # only known at run time, hence tf.shape.
    E = inputs.get_shape().as_list()[-1]
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1]

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Position indices 0..T-1, repeated for every batch entry: (N, T).
        positional_indices = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1])

        # Angle table: a pair of columns (2i, 2i+1) shares one frequency.
        position_enc = np.array([
            [pos / np.power(10000, (i - i % 2) / E) for i in range(E)]
            for pos in range(maxlen)])

        # sin on even embedding dimensions (2i), cos on odd ones (2i+1).
        # Fixed: the original assigned into the undefined name `position`.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])

        # Convert to a tensor so it can serve as a lookup table: (maxlen, E).
        position_enc = tf.convert_to_tensor(position_enc, tf.float32)

        # Fixed: the original looked up the undefined name `position_ind`.
        outputs = tf.nn.embedding_lookup(position_enc, positional_indices)

        # Where the input is exactly 0, keep that 0 instead of the encoding.
        if masking:
            # Fixed: `tff.equal` was a typo for `tf.equal`.
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

    return tf.to_float(outputs)
    
    
    




In [None]:
# Now comes the Transformer class

class Transformer:
    
    def __init__(self,hp):
        """Store the hyper-parameters, load the vocabulary and create the embedding matrix.

        Args:
            hp: hyper-parameter object; this method reads its `vocab`
                (vocabulary file path), `vocab_size` and `d_model` attributes.
        """
        self.hp = hp
        vocab_maps = load_vocab(hp.vocab)
        self.token2idx, self.idx2token = vocab_maps
        self.embeddings = get_token_embeddings(
            hp.vocab_size, hp.d_model, zero_pad=True
        )
        
    
    def encode(self,xs, training=True):
        with tf.variable_scope("encoder",reuse=tf.AUTO_REUSE):
            x, seqlens, sents1 = xs
            
            # Now the embedding part
            enc = tf.nn.embedding_lookup(self.embeddings,x)
            # scale the embedding with sqrt of model_dim
            enc *= self.hp.d_model ** 0.5
            
            # Now add positional encoding using the formulas given in the paper.
            enc += positional_encoding(enc,self.hp.maxlen1)
            
            
            
            