In [1]:
import numpy as np
import os
import tensorflow as tf
import random
import re
from time import sleep
import math
import datetime
import time
import pickle
import layers
import functions
import decoder
import encoder
from dataset import PubMed_Dataset
import data_batcher
import models
import summarizer
import data_utils
from sys import stdout
from vocab import Vocab_Lookup

In [2]:
# All artefacts for this experiment live under ./PubMed relative to the
# directory the notebook kernel was started in.
meta_dir = os.path.join(os.getcwd(), 'PubMed')

# Derive the four working sub-directories from the common root in one go.
log_dir, weights_dir, params_dir, data_dir = (
    os.path.join(meta_dir, subdir)
    for subdir in ('logs', 'weights', 'params', 'data_cache')
)

In [3]:
# Load the pickled vocabulary lookup object.
# NOTE: unpickling executes arbitrary code embedded in the file -- only load
# pickles produced by this project / a trusted source.
# The context manager guarantees the file handle is closed after loading.
with open(os.path.join(meta_dir, "vocab_lookup_30000.pickle"), "rb") as vocab_file:
    vocab_lookup = pickle.load(vocab_file)

In [4]:
# Bucket the cached partition files into train / val / test lists based on a
# substring of the file name. The first matching substring wins, in the order
# 'train', 'val', 'test'; files matching none of them are ignored.
train_files = []
val_files = []
test_files = []
# os.listdir returns entries in arbitrary order; sorting makes the order in
# which partitions are visited deterministic across runs and machines.
for filename in sorted(os.listdir(data_dir)):
    path = os.path.join(data_dir, filename)
    if 'train' in filename:
        train_files.append(path)
    elif 'val' in filename:
        val_files.append(path)
    elif 'test' in filename:
        test_files.append(path)

In [5]:
def data_partition_loader(partition_files):
    """Lazily yield the unpickled contents of each partition file, in order.

    Args:
        partition_files: sequence of paths to pickle files.

    Yields:
        The object stored in each file, one file per ``next()`` call.

    The generator finishes normally (``StopIteration``) once every file has
    been yielded, so it can be used in a ``for`` loop or with ``next()``.
    Each file is closed as soon as it has been read.

    NOTE: unpickling executes arbitrary code embedded in the file -- only feed
    this loader pickles from a trusted source.
    """
    for partition_file in partition_files:
        with open(partition_file, 'rb') as handle:
            partition = pickle.load(handle)
        yield partition

In [6]:
# One lazy partition loader per data split.
train_partition_loader, val_partition_loader, test_partition_loader = [
    data_partition_loader(split_files)
    for split_files in (train_files, val_files, test_files)
]

# Pull the first partition of every split so the batchers have data to start with.
train_data = next(train_partition_loader)
val_data = next(val_partition_loader)
test_data = next(test_partition_loader)

In [7]:
batch_size = 128

# Full-size batchers for the three splits, built over the currently loaded partitions.
train_batcher, val_batcher, test_batcher = [
    data_batcher.Data_Batcher(split_data, batch_size)
    for split_data in (train_data, val_data, test_data)
]

# Batch size 1 over the validation partition: used to print one sample summary at a time.
deploy_batcher = data_batcher.Data_Batcher(val_data, 1)

In [8]:
# --- Model architecture -----------------------------------------------------
# d_pad_len and s_pad_len are passed positionally (2nd and 3rd arguments) to
# models.Seq2Seq_Basic_Attn; presumably the padded document and summary
# lengths in tokens -- TODO confirm against the model definition.
d_pad_len = 150
s_pad_len = 20
# Forwarded to the model as embedding_dim / hidden_size / n_layers.
embd_dim = 100
hidden_size = 512
n_layers = 2
# Taken from the pickled vocabulary lookup rather than hard-coded.
vocab_size = vocab_lookup.num_words
# Passed to model.train_step on every training iteration.
dropout_keep_prob = 0.8
# Forwarded to the model as the keyword arguments of the same names.
bidirectional = False
shared_embeddings = True
# Teacher-forcing schedule used by the training loop: when the iteration
# counter equals teacher_forcing_steps[k], the active ratio becomes
# teacher_forcing_ratios[k]. Each step uses teacher forcing when
# random.random() is below the active ratio.
teacher_forcing_ratios = [1.0] 
teacher_forcing_steps = [1]

# --- Training loop cadence (all in iterations) ------------------------------
# display_interval: how often the training loss/accuracy line is logged.
# val_interval:     how often the validation pass runs.
# deploy_interval:  how often a sample prediction is decoded and logged.
# save_interval:    periodic checkpoint; it is only checked inside the
#                   validation branch, so it takes effect on iterations that
#                   are also a multiple of val_interval.
# n_iters:          total number of training iterations.
display_interval = 100
val_interval = 1000
deploy_interval = 1000
save_interval = 10000
n_iters = 200000 

# Learning rate handed to summarizer.Text_Summarization.
lr = 0.001
# Device index used both for the TF device string and CUDA_VISIBLE_DEVICES.
DEVICE = 0
# Selects '/gpu:<DEVICE>' (True) or '/cpu:<DEVICE>' (False) for graph placement.
USE_CUDA = True
# When True the training cell writes no log file, params pickle or checkpoints.
DEBUG_MODE = False

# Optional pretrained word vectors (currently disabled; embeddings start untrained).
# w2v = pickle.load(open("w2v_CNN-Dailymail_100.pickle", "rb"))
pretrained_embeddings = None #functions.create_embeddings(vocab_lookup, w2v)    

In [None]:
# Pick the TF device string; when running on GPU also pin the process to the
# chosen card so TensorFlow only sees that device.
if USE_CUDA:
    device_name = '/gpu:{}'.format(DEVICE)
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(DEVICE)
else:
    device_name = '/cpu:{}'.format(DEVICE)

# Start from an empty default graph so re-running this cell does not stack
# a second copy of the model on top of the first.
tf.reset_default_graph()
with tf.device(device_name):
    net = models.Seq2Seq_Basic_Attn(
        vocab_size,
        d_pad_len,
        s_pad_len,
        embedding_dim=embd_dim,
        hidden_size=hidden_size,
        n_layers=n_layers,
        bidirectional=bidirectional,
        pretrained_embeddings=pretrained_embeddings,
        trainable_embeddings=True,
        shared_embeddings=shared_embeddings,
        weight_tying=True,
        rnn_cell=tf.contrib.rnn.GRUCell,
    )
    model = summarizer.Text_Summarization(net, lr=lr, mode='train')

    # Report the parameter count, then list every trainable variable.
    functions.count_params(tf.trainable_variables())
    for var in tf.trainable_variables():
        print(var)

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


# Trainable Parameters: 8427092
<tf.Variable 'embeddings_layer/embeddings:0' shape=(30000, 100) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_0/gru_cell/gates/kernel:0' shape=(612, 1024) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_0/gru_cell/gates/bias:0' shape=(1024,) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_0/gru_cell/candidate/kernel:0' shape=(612, 512) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_0/gru_cell/candidate/bias:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_1/gru_cell/gates/kernel:0' shape=(1024, 1024) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_1/gru_cell/gates/bias:0' shape=(1024,) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_1/gru_cell/candidate/kernel:0' shape=(1024, 512) dtype=float32_ref>
<tf.Variable 'encoder/multi_rnn_cell/cell_1/gru_cell/candidate/bias:0' shape=(512,) dtype=float32_ref>
<tf.Variable 'attention/attn_W:0' shape=(512, 5

In [None]:
# Snapshot of the model's plain (non-TensorFlow) public attributes; written to
# the log header and pickled so the run configuration can be recovered later.
# NOTE(review): the original filter also tested `callable(key)`, which is always
# False for string keys and therefore never filtered anything. If the intent
# was to drop callable *values*, add `and not callable(value)` -- left out here
# so the set of saved parameters does not change.
params = {
    key: value
    for key, value in net.__dict__.items()
    if not key.startswith('_') and 'tensorflow' not in str(type(value))
}
model_name = net.__class__.__name__

if not DEBUG_MODE:
    # Make sure the output directories exist before anything is written.
    for out_dir in (log_dir, params_dir):
        os.makedirs(out_dir, exist_ok=True)

    # One timestamp ties together the log, params and checkpoint artefacts of this run.
    timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')

    log_file = os.path.join(log_dir, '{}_train_log_{}.txt'.format(model_name, timestamp))
    # Build the header from the live settings instead of hard-coded numbers.
    log_description = '{} lr, {} batch size, params: {}\n'.format(lr, batch_size, params)
    # Create (or truncate) the log file, then write the run description.
    with open(log_file, 'w'):
        pass
    functions.write_to_log(log_description, log_file)

    params_filename = '{}_params_{}.pickle'.format(model_name, timestamp)
    with open(os.path.join(params_dir, params_filename), 'wb') as handle:
        pickle.dump(params, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # makedirs also creates weights_dir itself when it is missing.
    checkpoint_dir = os.path.join(weights_dir, '{}_checkpoints_{}'.format(model_name, timestamp))
    os.makedirs(checkpoint_dir, exist_ok=True)

def _next_partition(loader, files):
    """Fetch the next data partition, restarting from the first file when exhausted.

    Returns (loader, data, wrapped): the loader to keep using, the loaded
    partition, and whether the file list had to be restarted.

    Only loader exhaustion triggers a restart: StopIteration from a finished
    generator, or IndexError from a loader that indexes past the end of its
    file list. Any other error (missing file, corrupt pickle, Ctrl-C)
    propagates instead of being silently swallowed.
    """
    try:
        return loader, next(loader), False
    except (StopIteration, IndexError):
        loader = data_partition_loader(files)
        return loader, next(loader), True


def _unpack_examples(examples):
    """Split a batch of examples into (source_ids, target_ids, source_lens, target_lens) lists."""
    return ([example.source_ids for example in examples],
            [example.target_ids for example in examples],
            [example.source_len for example in examples],
            [example.target_len for example in examples])


def _report(message):
    """Append `message` to the run log (skipped in debug mode) and print it."""
    if not DEBUG_MODE:
        functions.write_to_log(message, log_file)
    print(message)


# Upper bound on the number of validation examples scored per validation pass.
max_val_examples = 10000

epoch = 0
# Default until the schedule sets a ratio; avoids a NameError when
# teacher_forcing_steps does not contain the first iteration.
teacher_forcing_ratio = 1.0

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if not DEBUG_MODE:
        saver = tf.train.Saver(max_to_keep=100)

    best_val_loss = float('inf')
    for itr in range(1, n_iters + 1):
        # Switch to the scheduled teacher-forcing ratio when its step is reached.
        if itr in teacher_forcing_steps:
            teacher_forcing_ratio = teacher_forcing_ratios[teacher_forcing_steps.index(itr)]

        # ---- Training step ----
        examples, ep = train_batcher.next_batch()
        if ep == 1:
            # The batcher flagged the end of the current partition: move on to
            # the next one. Wrapping around the file list counts as one epoch.
            train_partition_loader, train_data, wrapped = _next_partition(
                train_partition_loader, train_files)
            if wrapped:
                epoch += 1
            train_batcher = data_batcher.Data_Batcher(train_data, batch_size)
            examples, ep = train_batcher.next_batch()

        inputs, targets, input_lens, target_lens = _unpack_examples(examples)

        teacher_forcing = random.random() < teacher_forcing_ratio
        # The targets double as decoder inputs during training.
        train_loss, train_acc, grad_norm = model.train_step(
            sess, inputs, targets, input_lens, target_lens,
            targets, target_lens, dropout_keep_prob=dropout_keep_prob,
            teacher_forcing=teacher_forcing)

        if itr % display_interval == 0 or itr == 1:
            log_string = ('[%d, %5d] loss: %.3f, accuracy: %.3f, grad_norm: %.3f'
                          % (epoch, itr, train_loss, train_acc, grad_norm))
            _report(log_string)

        # ---- Validation pass ----
        if itr % val_interval == 0:
            val_loss, val_acc = 0.0, 0.0
            # Batch count is fixed from the batcher in use when the pass starts.
            n_val_batches = len(val_batcher.data) // val_batcher.batch_size
            for i in range(n_val_batches):
                examples, ep = val_batcher.next_batch()
                if ep == 1:
                    val_partition_loader, val_data, _ = _next_partition(
                        val_partition_loader, val_files)
                    val_batcher = data_batcher.Data_Batcher(val_data, batch_size)
                    examples, ep = val_batcher.next_batch()

                inputs, targets, input_lens, target_lens = _unpack_examples(examples)
                # Validation feeds all-zero decoder inputs/lengths instead of the targets.
                dummy_dec_inputs = np.zeros_like(targets, dtype=int)
                dummy_dec_lens = np.zeros_like(target_lens, dtype=int)

                val_batch_loss, val_batch_acc = model.val_step(
                    sess, inputs, dummy_dec_inputs, input_lens, dummy_dec_lens,
                    targets, target_lens)
                # Incremental running mean over the batches seen so far.
                val_loss += (val_batch_loss - val_loss) / (i + 1)
                val_acc += (val_batch_acc - val_acc) / (i + 1)

                if (i + 1) * val_batcher.batch_size >= max_val_examples:
                    break

            log_string = ('Validation - loss: %.3f, accuracy: %.3f' % (val_loss, val_acc))
            _report(log_string)

            # ---- Checkpointing ----
            # Save on a new best validation loss, or on the periodic schedule.
            if not DEBUG_MODE:
                improved = val_loss < best_val_loss
                if improved:
                    best_val_loss = val_loss
                if improved or itr % save_interval == 0:
                    weights_prefix = '{}_weights_epoch_{}_itr_{}'.format(model_name, epoch, itr)
                    weights_path = os.path.join(checkpoint_dir, weights_prefix)
                    log_msg = "Weights saved in file: {}\n".format(weights_path)
                    print(log_msg)
                    saver.save(sess, weights_path)
                    functions.write_to_log(log_msg, log_file)

        # ---- Sample prediction ----
        if itr % deploy_interval == 0:
            examples, _ = deploy_batcher.next_batch()
            example = examples[0]
            inputs = [example.source_ids]
            input_lens = [example.source_len]
            dummy_dec_inputs = [np.zeros_like(example.target_ids, dtype=int)]

            predictions = model.deploy(sess, inputs, input_lens, dummy_dec_inputs)
            generated_words = [vocab_lookup.convert_id2word(prediction) for prediction in predictions[0]]

            log_string = ('DOCUMENT:\n{}\nMODEL:\n{}\nGROUND TRUTH:\n{}'
                          .format(example.source_text,
                                  ' '.join(generated_words),
                                  example.target_text))
            _report(log_string)