In [None]:
# What we want to do:
#1) Load two models together--> Loading happens through a session which restores variables, etc. 
#2) We wish to just load the encoder variables (pretrained) and add them to our computation for composite model
 

#What we try doing is based on the following logic:
#1) A graph defines all operations and data flow. 
#2) However, given that a graph can be invoked through a session, we will need to first create a session. 
#3) Question: can a graph only be invoked through a session?

#4) Then we load this graph and import the encoder part of it into our composite graph.
#5) We finally run a session for this composite graph. 



#6) HOWEVER, we need to make sure that the encoder is not trained any further. 
#7) Otherwise, we will continue with just two separate sessions. 


'''
The final procedure:
1) Create a session to load the pretrained graph. Store the graph. Quit the session. We do this only once. A graph has to be declared in advance.
2) Set the encoder weights to be non trainable. 
3) Create new session to train original model and use encoder as a part of the operation flow. 


'''




#For optimization and graph construction:
# https://www.kdnuggets.com/2017/05/how-not-program-tensorflow-graph.html

#General overview of graphs
# https://www.tensorflow.org/programmers_guide/graphs#visualizing_your_graph

#Making non trainable variables
# https://stackoverflow.com/questions/37326002/is-it-possible-to-make-a-trainable-variable-not-trainable?rq=1
# https://stackoverflow.com/questions/35298326/freeze-some-variables-scopes-in-tensorflow-stop-gradient-vs-passing-variables

#Loading multiple graphs:
# https://stackoverflow.com/questions/41990014/load-multiple-models-in-tensorflow

#Loading trained weights from one graph to another
# https://stackoverflow.com/questions/39068703/tensorflow-using-weights-trained-in-one-model-inside-another-different-model?rq=1

In [5]:
import os
os.chdir('../model/')
import tensorflow as tf
import pickle
import numpy as np
from tensorflow.contrib.rnn import LSTMCell, LSTMStateTuple

In [6]:
def get_word_embeddings(filename):
    """Load the pretrained word-embedding matrix from a NumPy ``.npz`` archive.

    Args:
        filename: Path to an ``.npz`` archive that stores the matrix under
            the key ``"embeddings"``.

    Returns:
        The array stored under ``"embeddings"``.

    Raises:
        IOError: If the archive cannot be opened or read; the message names
            the offending file.
    """
    try:
        with np.load(filename) as data:
            return data["embeddings"]
    except IOError as err:
        # The handler used to raise `MyIOError`, a name that is never defined
        # in this notebook, so a missing file surfaced as a NameError instead
        # of an I/O error. Re-raise a real IOError that names the file.
        raise IOError(
            "Unable to load word embeddings from {!r}: {}".format(filename, err))
        
def dataset_load(domain_tr_data_path, vocab_path):
    """Load pickled training examples and map their words to vocabulary ids.

    NOTE: a later cell in this notebook redefines ``dataset_load`` with a
    single return value; whichever definition was executed last is the one
    in effect.

    Args:
        domain_tr_data_path: Path to a pickle holding a sequence of examples;
            only element ``[0]`` of each example (an iterable of words) is used.
        vocab_path: Path to a pickle holding a word -> id mapping.

    Returns:
        A tuple ``(idd_domain_tr_data, vocab)`` where ``idd_domain_tr_data``
        is a list with one list of ids per example and ``vocab`` is the
        loaded mapping.

    Raises:
        KeyError: If an example contains a word missing from ``vocab``.
    """
    # Pickle streams are binary: 'rb' is required on Python 3 and avoids
    # newline translation problems on Windows under Python 2.
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project.
    with open(domain_tr_data_path, 'rb') as data_file:
        domain_tr_data = pickle.load(data_file)
    with open(vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Real lists (not lazy `map` objects) so callers can slice and len() the
    # result on both Python 2 and Python 3.
    sentences = [example[0] for example in domain_tr_data]
    idd_domain_tr_data = [[vocab[word] for word in sentence]
                          for sentence in sentences]
    return idd_domain_tr_data, vocab

In [7]:
# --- Experiment configuration -------------------------------------------
# Domain whose pretrained encoder is being inspected.
domain_name = 'laptop'

# Pretrained word vectors and the label used for them.
embeddings_name = 'glove200d'
embeddings_path = '../data/Embeddings/Pruned/np_glove_200d_trimmed.npz'

# Pickled inputs: word -> id vocabulary and the domain's training examples.
vocab_path = '../data/vocab_to_id.pkl'
domain_tr_data_path = '../data/Final_joint_data_absa//Domains/Laptop/Normal__normal_training_list.pickle'


In [8]:
# Load the word -> id vocabulary. Pickle streams are binary, so the file is
# opened in 'rb' mode (required on Python 3, safe on Python 2).
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(vocab_path, 'rb') as p1:
    vocab = pickle.load(p1)

# Pretrained embedding matrix, later used to initialise the embedding variable.
word_embeddings_np = get_word_embeddings(embeddings_path)

# Ids of the special padding / end-of-sequence tokens.
pad_token = '<PAD>'
eos_token = '<END>'
PAD = vocab[pad_token]
EOS = vocab[eos_token]

In [9]:
# Model dimensions. These values are also baked into the checkpoint file name
# built further down, so they must match the model that was trained.
input_embedding_size = 200
encoder_hidden_units = 50  # 100
vocab_size = len(vocab)

In [10]:
# Build the encoder part of the graph: embedding lookup followed by a
# bidirectional LSTM whose final states are concatenated into one vector.

# Token ids and per-sequence lengths. The RNN below is built with
# time_major=True, so encoder_inputs is fed as [max_time, batch_size].
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
# Embedding matrix initialised from the pretrained vectors and frozen
# (trainable=False).
embeddings = tf.Variable(word_embeddings_np, name="word_embeds",dtype=tf.float32, trainable=False)
encoder_inputs_embedded = tf.nn.embedding_lookup(embeddings, encoder_inputs)
# NOTE(review): the same LSTMCell instance is passed as both cell_fw and
# cell_bw below. Depending on the TensorFlow version this may make the two
# directions share weights -- confirm this is intended. Do not change it
# casually: the variables created here are what saver.restore() later matches
# against the checkpoint.
encoder_cell = LSTMCell(encoder_hidden_units)
((encoder_fw_outputs,
          encoder_bw_outputs),
         (encoder_fw_final_state,
          encoder_bw_final_state)) = (
            tf.nn.bidirectional_dynamic_rnn(cell_fw=encoder_cell,
                                    cell_bw=encoder_cell,
                                    inputs=encoder_inputs_embedded,
                                    sequence_length=encoder_inputs_length,
                                    dtype=tf.float32, time_major=True)
            ) #'

# Forward and backward per-step outputs joined along axis 2.
encoder_outputs = tf.concat((encoder_fw_outputs, encoder_bw_outputs),2)


# Final cell states (c) of both directions joined along axis 1.
encoder_final_state_c = tf.concat(
        (encoder_fw_final_state.c, encoder_bw_final_state.c), 1)

# Final hidden states (h) of both directions joined along axis 1.
encoder_final_state_h = tf.concat(
        (encoder_fw_final_state.h, encoder_bw_final_state.h), 1)

encoder_final_state = LSTMStateTuple(
        c=encoder_final_state_c,
        h=encoder_final_state_h
    ) #this is useful later

# Single tensor holding [c, h] side by side (axis 1); this is the tensor that
# restored_model_enc_out() evaluates.
encoder_concat_everything = tf.concat([encoder_final_state_c,encoder_final_state_h], 1)

In [13]:
def dataset_load(domain_tr_data_path, vocab_path):
    """Load pickled training examples and return them as lists of vocabulary ids.

    Args:
        domain_tr_data_path: Path to a pickle holding a sequence of examples;
            only element ``[0]`` of each example (an iterable of words) is used.
        vocab_path: Path to a pickle holding a word -> id mapping.

    Returns:
        A list with one list of ids per example.

    Raises:
        KeyError: If an example contains a word missing from the vocabulary.
    """
    # Pickle streams are binary: 'rb' is required on Python 3 and avoids
    # newline translation problems on Windows under Python 2.
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project.
    with open(domain_tr_data_path, 'rb') as data_file:
        domain_tr_data = pickle.load(data_file)
    with open(vocab_path, 'rb') as vocab_file:
        vocab = pickle.load(vocab_file)

    # Build a real list (not a lazy `map` object): get_holdout_data slices the
    # result and batch_modify calls len() on it.
    sentences = [example[0] for example in domain_tr_data]
    idd_domain_tr_data = [[vocab[word] for word in sentence]
                          for sentence in sentences]
    return idd_domain_tr_data

def get_holdout_data(idd_domain_tr_data, num=10):
    '''Return the leading `num` examples, used as a small batch to sanity-check the encoder.'''
    leading = slice(None, num)
    return idd_domain_tr_data[leading]

def batch_modify(inputs, max_sequence_length=None, pad_value=None):
    """Pad a batch of id sequences and lay it out time-major.

    Args:
        inputs:
            list of sentences (integer lists)
        max_sequence_length:
            integer specifying how large the `max_time` dimension should be.
            If None, the length of the longest sequence in the batch is used
            (0 for an empty batch).
        pad_value:
            integer id written into unused positions. If None, the
            module-level `PAD` id is used.

    Returns:
        inputs_time_major:
            int32 matrix of shape [max_time, batch_size] holding the input
            sentences, padded with `pad_value`
        sequence_lengths:
            batch-sized list of integers specifying the number of active
            time steps in each input sequence

    Raises:
        IndexError: if a sequence is longer than `max_sequence_length`.
    """
    sequence_lengths = [len(seq) for seq in inputs]
    batch_size = len(inputs)
    longest = max(sequence_lengths) if sequence_lengths else 0

    if max_sequence_length is None:
        # max() on an empty list raises, so an empty batch gets max_time == 0.
        max_sequence_length = longest
    elif longest > max_sequence_length:
        # Same exception type the element-wise copy used to hit, but with a
        # message that explains what went wrong.
        raise IndexError(
            "sequence of length {} does not fit into max_sequence_length={}".format(
                longest, max_sequence_length))

    if pad_value is None:
        pad_value = PAD

    inputs_batch_major = np.full(
        [batch_size, max_sequence_length], pad_value, dtype=np.int32)

    # Copy each sentence into the start of its row; the tail keeps the pad id.
    for row, (seq, length) in enumerate(zip(inputs, sequence_lengths)):
        if length:
            inputs_batch_major[row, :length] = seq

    # [batch_size, max_time] -> [max_time, batch_size]
    inputs_time_major = inputs_batch_major.swapaxes(0, 1)

    return inputs_time_major, sequence_lengths

def feed_enc(enc_batch):
    """Build the feed dict for the encoder placeholders from a batch of id lists.

    The batch is padded and made time-major by `batch_modify`; the padded
    matrix feeds `encoder_inputs` and the true lengths feed
    `encoder_inputs_length`.
    """
    padded_time_major, true_lengths = batch_modify(enc_batch)
    feed = {}
    feed[encoder_inputs] = padded_time_major
    feed[encoder_inputs_length] = true_lengths
    return feed

In [14]:
# Convert the domain's training sentences to id lists, then keep the first
# ten as a small batch for checking the restored encoder.
idd_data = dataset_load(domain_tr_data_path=domain_tr_data_path,
                        vocab_path=vocab_path)
holdout_data = get_holdout_data(idd_data, num=10)

In [15]:
# Checkpoint location; the file name encodes domain, embedding type,
# embedding size and encoder hidden size.
embed_type = "Glove"
model_path = '../results/seq2seq/{}_seq2seqmodel_embeds{}_{}d_{}hiddenunits.ckpt'.format(
    domain_name,
    embed_type,
    input_embedding_size,
    encoder_hidden_units,
)

In [16]:
# Session configuration that exposes zero GPU devices.
config = tf.ConfigProto(device_count={'GPU': 0})

# Handle on the default graph, i.e. the one the encoder ops were added to.
graph = tf.get_default_graph()

# Saver created with no explicit variable list; it is used below to restore
# the checkpoint at `model_path`.
saver = tf.train.Saver()
#encoder_saver = tf.train.Saver({"seq2seq_encoder": encoder_concat_everything})

In [17]:
'''Cell to test in format for experiment'''
def restored_model_enc_out(model_exists_already=True):
    """Evaluate the encoder on the holdout batch and return its final state.

    Opens a new session with the module-level `config`, runs the global
    variable initializer, optionally restores the checkpoint at `model_path`
    through the module-level `saver`, and evaluates
    `encoder_concat_everything` on `holdout_data`.

    Args:
        model_exists_already: when truthy, variables are restored from
            `model_path` after initialization; otherwise the freshly
            initialized values are used.

    Returns:
        The value of `encoder_concat_everything` for the holdout batch.
    """
    with tf.Session(config=config) as session:
        init_op = tf.global_variables_initializer()
        session.run(init_op)
        print("Initialized session")

        if model_exists_already:
            print("loading existing model")
            saver.restore(session, model_path)

        holdout_feed = feed_enc(holdout_data)
        return session.run(encoder_concat_everything, holdout_feed)

In [18]:
# Restore the trained checkpoint and encode the holdout batch.
encoder_useful_state = restored_model_enc_out(model_exists_already=True)

Initialized session
loading existing model
INFO:tensorflow:Restoring parameters from ../results/seq2seq/laptop_seq2seqmodel_embedsGlove_200d_50hiddenunits.ckpt


In [19]:
encoder_useful_state

array([[ 0.3665045 ,  0.19248164, -0.34810996, ..., -0.7770532 ,
        -0.06400465,  0.28331846],
       [ 0.35242832,  0.2059004 , -0.11128056, ..., -0.75132567,
        -0.02789734,  0.3261328 ],
       [ 0.38270092,  0.07817139, -0.59795696, ..., -0.6560076 ,
        -0.08645746,  0.39262387],
       ...,
       [ 0.39830253,  0.04827934, -0.6380171 , ..., -0.5388174 ,
        -0.07801261,  0.26077485],
       [ 0.15671724,  0.04524764, -0.34900898, ..., -0.7905662 ,
        -0.10922995,  0.37479112],
       [ 0.48515686,  0.4221653 , -0.5161304 , ..., -0.52702236,
        -0.04700104,  0.05832616]], dtype=float32)