In [None]:
import tensorflow as tf
import os
import random
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.corpus import stopwords
from string import punctuation
import pickle

from src import datasets
from src import data_utils
from src import models
from src import functions

In [None]:
# Experiment configuration; every later cell reads its settings from this dict.
# Key order is kept stable because the whole dict is written to the training log.
params = {
    # --- data / preprocessing ---
    'dataset': 'newsgroups',              # the loader cell handles 'newsgroups', 'reuters8', 'reuters52'
    'val_split': 0.1,                     # forwarded to the dataset constructor
    'vocab_size': 30000,                  # shared by the sequence vectorizer and the model
    'max_len': 100,                       # vectorizer max_sequence_len; also passed to the model

    # --- model architecture (forwarded to models.DiscreteHierarchicalClassifier) ---
    'embedding_dim': 100,
    'hidden_size': 128,
    'straight_through': False,
    'dropout_keep_prob': 0.5,             # fed while training; evaluation always feeds 1.0
    'use_attn': True,
    'n_layers': 1,
    'hierarchies': [6, 6, 6],             # presumably one entry per hierarchy level -- confirm in src.models

    # --- optimisation ---
    'batch_size': 32,
    'n_epochs': 200,                      # training stops once train_batcher.epoch reaches this
    'learning_rate': 0.001,
    'gumbel_temperature': 1.0,            # passed as `x` to gumbel_temperature_schedule
    'discrete_loss_weight': 1.0,          # may be the string 'random': the loop then draws 0/1 per step
    'continuous_loss_weight': 1.0,        # may be the string 'random': complement of the draw above
    'gumbel_loss_weight': 0.0,
    'orthogonality_loss_weight': 0.0,
    'mse_loss_weight': 0.0,
    'trainable_embeddings': True,
    'rnn_cell': tf.contrib.rnn.GRUCell,

    # --- runtime ---
    'device': 0,                          # index written to CUDA_VISIBLE_DEVICES
    'use_cuda': True,
    'debug_mode': False,                  # True skips the results dir, log file and checkpoints
    'display_interval': 200,              # iterations between train-metric log lines
    'val_interval': 1000,                 # iterations between validation passes
    'deploy_interval': 5000,              # iterations between hierarchy/keyword reports

    # --- warm start ---
    # Example: 'train_results/RNNClassifier_01102019_160346/weights/RNNClassifier_weights_epoch_65_itr_21000'
    'weights_file': None,                 # checkpoint restored through pretrain_loader when not None
    'finetune': False,                    # True trains with model.finetune_optimizer
}

In [None]:
# Load the corpus selected in params. Each loader receives the validation split
# fraction; the Reuters loaders additionally take the data directory and a
# variant number (8 or 52).
if params['dataset'] == 'newsgroups':
    data = datasets.NewsGroup(params['val_split'])
elif params['dataset'] == 'reuters8':
    data = datasets.Reuters('data/reuters', 8, params['val_split'])
elif params['dataset'] == 'reuters52':
    data = datasets.Reuters('data/reuters', 52, params['val_split'])
else:
    # Fail here instead of leaving `data` undefined (or stale from a previous
    # run of this cell), which would only surface as a confusing error later.
    raise ValueError(
        "Unknown dataset {!r}; expected one of 'newsgroups', 'reuters8', 'reuters52'".format(params['dataset'])
    )

In [None]:
# TF-IDF model fitted on the train + validation texts. It is handed to
# functions.get_important_words when the training loop prints per-cluster keywords.
# English stop words and single punctuation characters are excluded.
tfidf_ignored_tokens = stopwords.words('english') + list(punctuation)
tfidf = TfidfVectorizer(stop_words=tfidf_ignored_tokens)
tfidf.fit(data.train_texts + data.val_texts)

In [None]:
# Text -> id-sequence vectorizer. The vocabulary is fitted on the training texts
# only; vocabulary size and maximum sequence length come from params.
vectorizer_options = dict(
    vocab=None,
    lowercase=True,
    keep_digits=True,
    keep_punctuations=False,
    punctuations_to_keep=['.', '?', '!'],
    vocab_size=params['vocab_size'],
    max_sequence_len=params['max_len'],
    pad_token='PAD',
    unk_token='UNK',
    eos_token=None,
    go_token=None,
)
vectorizer = data_utils.SequenceVectorizer(**vectorizer_options)
vectorizer.fit(data.train_texts)

In [None]:
class Example:
    """One labelled document together with its vectorized form.

    Attributes:
        text: the raw text exactly as passed in.
        label: the label exactly as passed in.
        ids: first value returned by ``vectorizer.transform(text)``; later fed
            to ``model.inputs``.
        real_len: second value returned by ``vectorizer.transform(text)``; later
            fed to ``model.input_lens``.
    """

    def __init__(self, text, label, vectorizer):
        # transform() must return exactly two values: (ids, real_len).
        token_ids, n_tokens = vectorizer.transform(text)
        self.text, self.label = text, label
        self.ids = token_ids
        self.real_len = n_tokens

In [None]:
def _make_examples(texts, labels):
    """Pair each text with its label and vectorize it with the fitted ``vectorizer``."""
    return [Example(text, label, vectorizer) for text, label in zip(texts, labels)]


# Vectorize every split once up front so batching only has to slice lists.
train_examples = _make_examples(data.train_texts, data.train_labels)
val_examples = _make_examples(data.val_texts, data.val_labels)
test_examples = _make_examples(data.test_texts, data.test_labels)

In [None]:
# One batcher per split, all with the configured batch size. They are built in
# train -> val -> test order.
train_batcher, val_batcher, test_batcher = (
    data_utils.DataBatcher(split_examples, params['batch_size'])
    for split_examples in (train_examples, val_examples, test_examples)
)

### Model

In [None]:
# Restrict this process to the configured physical GPU, using PCI bus ordering
# so the index matches what tools such as nvidia-smi report.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(params['device'])
# Once CUDA_VISIBLE_DEVICES lists a single GPU, that GPU is re-indexed as device 0
# inside this process, and TensorFlow exposes one CPU device by default. Index 0 is
# therefore the only valid logical index on both paths; using params['device'] here
# would name a non-existent device whenever it is not 0 and silently rely on soft
# placement.
device_name = '/gpu:0' if params['use_cuda'] else '/cpu:0'

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default(), tf.device(device_name):
    # Soft placement lets ops without a kernel for the requested device fall back
    # to another one; allow_growth stops TF from grabbing all GPU memory up front.
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    model = models.DiscreteHierarchicalClassifier(
        len(data.classes),
        params['vocab_size'],
        params['max_len'],
        learning_rate=params['learning_rate'],
        embedding_dim=params['embedding_dim'],
        hidden_size=params['hidden_size'],
        straight_through=params['straight_through'],
        use_attn=params['use_attn'],
        n_layers=params['n_layers'],
        hierarchies=params['hierarchies'],
        pretrained_embeddings=None,
        trainable_embeddings=params['trainable_embeddings'],
        rnn_cell=params['rnn_cell'],
        class_weights=data.class_weights
    )

    init = tf.global_variables_initializer()
    # pretrain_loader restores only model.pretrain_vars (used for warm starts);
    # saver checkpoints all trainable variables and keeps the 10 newest files.
    pretrain_loader = tf.train.Saver(var_list=model.pretrain_vars)
    saver = tf.train.Saver(var_list=tf.trainable_variables(), max_to_keep=10)

    # Report the model size and list every trainable variable.
    functions.count_params(tf.trainable_variables())
    for var in tf.trainable_variables():
        print(var)

In [None]:
# Outside debug mode, create the results directory for this run, start a fresh
# log file with the full parameter dump, and persist params + vectorizer.
if not params['debug_mode']:
    results_dir = functions.prepare_results_dir(model.__class__.__name__)
    weights_dir = os.path.join(results_dir, 'weights')
    if not os.path.exists(weights_dir):
        os.makedirs(weights_dir)

    log_file = os.path.join(results_dir, 'train_log.txt')
    log_description = "\n".join(["{} : {}".format(k, v) for k, v in params.items()])
    # Create/truncate the log file before the first write_to_log call.
    with open(log_file, 'w'):
        pass
    functions.write_to_log(log_description, log_file)

    # Use context managers so the pickle files are flushed and closed
    # deterministically rather than whenever the handle is garbage-collected.
    with open(os.path.join(results_dir, 'params.pickle'), 'wb') as params_file:
        pickle.dump(params, params_file)
    with open(os.path.join(results_dir, 'vectorizer.pickle'), 'wb') as vectorizer_file:
        pickle.dump(vectorizer, vectorizer_file)

In [None]:
def gumbel_temperature_schedule(x, itr, warmup_steps=10000, anneal_steps=10000):
    """Return the Gumbel-softmax temperature to use at training iteration ``itr``.

    The temperature is held at 1.0 for the first ``warmup_steps`` iterations,
    then decreases linearly towards ``x`` over the next ``anneal_steps``
    iterations, and is clamped at ``x`` from there on.

    Args:
        x: target (floor) temperature. The schedule is a decay only when
            ``x <= 1.0``; with ``x == 1.0`` it is constant.
        itr: current training iteration.
        warmup_steps: number of initial iterations held at 1.0.
        anneal_steps: number of iterations over which 1.0 is interpolated to ``x``.

    Returns:
        The temperature as a float.

    Raises:
        ValueError: if ``anneal_steps`` is not positive.
    """
    if anneal_steps <= 0:
        raise ValueError("anneal_steps must be positive, got {}".format(anneal_steps))
    # Iterations elapsed since the end of the warm-up phase (0 while warming up).
    steps_past_warmup = max(itr, warmup_steps) - warmup_steps
    return max(x, 1.0 - (1.0 - x) * steps_past_warmup / anneal_steps)

In [None]:
def evaluate(model, data_batcher, itr=None):
    """Run one pass over ``data_batcher`` and return batch-averaged metrics.

    Relies on the notebook globals ``sess``, ``params``, ``data``, ``functions``
    and ``gumbel_temperature_schedule``.

    Args:
        model: object exposing the placeholders fed and the tensors fetched below.
        data_batcher: batcher providing ``data``, ``batch_size`` and ``next_batch()``.
        itr: training iteration used to choose the Gumbel temperature. When
            omitted, the current value of ``model.global_step`` is read from
            ``sess``, so the function no longer depends on a global ``itr`` that
            only exists after the training cell has started.

    Returns:
        ``(loss, discrete_loss, continuous_loss, gumbel_loss, orth_loss, mse_loss,
        discrete_acc, continuous_acc, hier_bcubed)``. The first eight are
        unweighted means of the per-batch values; ``hier_bcubed`` is whatever
        ``functions.evaluate_hierarchical_bcubed`` returns for the collected
        predicted/actual hierarchical labels.
    """
    if itr is None:
        # The training loop sets itr = global_step + 1 before each update, so after
        # the update the stored global step is expected to equal that itr
        # (assumes the optimizer increments model.global_step -- confirm in src.models).
        itr = tf.train.global_step(sess, model.global_step)

    # 'random' is a training-time switch; evaluation weights both losses at 1.0.
    discrete_loss_weight = params['discrete_loss_weight']
    if discrete_loss_weight == 'random':
        discrete_loss_weight = 1.0
    continuous_loss_weight = params['continuous_loss_weight']
    if continuous_loss_weight == 'random':
        continuous_loss_weight = 1.0

    # Constant for the whole pass, so compute it once.
    gumbel_temperature = gumbel_temperature_schedule(params['gumbel_temperature'], itr)

    loss, discrete_loss, continuous_loss, gumbel_loss, orth_loss, mse_loss, discrete_acc, continuous_acc = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
    pred_hier_labels, actual_hier_labels = [], []
    n_batches = int(np.ceil(len(data_batcher.data)/data_batcher.batch_size))
    for i in range(n_batches):
        batch = data_batcher.next_batch()
        inputs = [example.ids for example in batch]
        input_lens = [example.real_len for example in batch]
        input_targets = [example.label for example in batch]

        feed_dict = {
            model.inputs : inputs,
            model.input_lens : input_lens,
            model.targets : input_targets,
            model.dropout_keep_prob : 1.0,  # no dropout at evaluation time
            model.gumbel_temperature : gumbel_temperature,
            model.discrete_loss_weight : discrete_loss_weight,
            model.continuous_loss_weight : continuous_loss_weight,
            model.gumbel_loss_weight : params['gumbel_loss_weight'],
            model.orthogonality_loss_weight : params['orthogonality_loss_weight'],
            model.mse_loss_weight : params['mse_loss_weight']
        }
        (batch_softmaxes, batch_loss, batch_discrete_loss, batch_continuous_loss, batch_gumbel_loss, batch_orth_loss,
         batch_mse_loss, batch_discrete_acc, batch_continuous_acc) = sess.run(
            [model.gumbel_onehots, model.loss, model.discrete_loss, model.continuous_loss, model.gumbel_loss,
             model.orthogonality_loss, model.mse_loss, model.discrete_accuracy, model.continuous_accuracy], feed_dict=feed_dict
        )

        # Argmax each fetched array along axis 1, then zip across arrays to get one
        # tuple of predicted codes per example (presumably one array per hierarchy level).
        pred_hier_labels += list(zip(*[np.argmax(x, axis=1) for x in batch_softmaxes]))
        # Gold labels: the label name split on '.' into its components.
        actual_hier_labels += [data.label_dict[l].split('.') for l in input_targets]

        # Incremental mean over batches: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k.
        loss += (batch_loss - loss)/(i+1)
        discrete_loss += (batch_discrete_loss - discrete_loss)/(i+1)
        continuous_loss += (batch_continuous_loss - continuous_loss)/(i+1)
        gumbel_loss += (batch_gumbel_loss - gumbel_loss)/(i+1)
        orth_loss += (batch_orth_loss - orth_loss)/(i+1)
        mse_loss += (batch_mse_loss - mse_loss)/(i+1)
        discrete_acc += (batch_discrete_acc - discrete_acc)/(i+1)
        continuous_acc += (batch_continuous_acc - continuous_acc)/(i+1)
    hier_bcubed = functions.evaluate_hierarchical_bcubed(pred_hier_labels, actual_hier_labels)
    return loss, discrete_loss, continuous_loss, gumbel_loss, orth_loss, mse_loss, discrete_acc, continuous_acc, hier_bcubed

In [None]:
def _log(message):
    """Print ``message`` and, unless in debug mode, append it to ``log_file``."""
    print(message)
    if not params['debug_mode']:
        functions.write_to_log(message, log_file)


def _format_eval_report(split_name, results):
    """Render the 9-tuple returned by ``evaluate`` as a log message for ``split_name``."""
    (loss, discrete_loss, continuous_loss, gumbel_loss, orth_loss, mse_loss,
     discrete_acc, continuous_acc, hier_bcubed) = results
    report = (
        "{} - loss: {:.4f}, discrete loss: {:.4f}, continuous loss: {:.4f}, gumbel loss: {:.4f}, "
        "orth loss: {:.4f}, mse loss: {:.4f}, discrete acc: {:.4f}, continuous acc: {:.4f}"
    ).format(split_name, loss, discrete_loss, continuous_loss, gumbel_loss, orth_loss,
             mse_loss, discrete_acc, continuous_acc)
    report += "\nhierarchical bcubed:\n"
    for k, v in hier_bcubed.items():
        report += '[{}] F1: {:.4f}, R: {:.4f}, P: {:.4f}\n'.format(k, v['F1'], v['R'], v['P'])
    return report


# Initialise all variables, then optionally warm-start from a checkpoint.
sess.run(init)
if params['weights_file'] is not None:
    pretrain_loader.restore(sess, params['weights_file'])
    _log("\nTraining variables:\n{}\n".format(model.finetune_vars))

best_loss, best_acc = 10e6, 0.0
discrete_loss_weight, continuous_loss_weight = params['discrete_loss_weight'], params['continuous_loss_weight']
# The choice of train op is fixed for the whole run.
train_op = model.finetune_optimizer if params['finetune'] else model.optimizer

while train_batcher.epoch < params['n_epochs']:
    # 1-based index of the update about to be applied.
    itr = tf.train.global_step(sess, model.global_step) + 1

    train_batch = train_batcher.next_batch()
    inputs = [example.ids for example in train_batch]
    input_lens = [example.real_len for example in train_batch]
    input_targets = [example.label for example in train_batch]

    # In 'random' mode exactly one of the two losses is active on each step.
    noise_switch = random.randint(0, 1)
    if params['discrete_loss_weight'] == 'random':
        discrete_loss_weight = noise_switch
    if params['continuous_loss_weight'] == 'random':
        continuous_loss_weight = 1-noise_switch

    feed_dict = {
        model.inputs : inputs,
        model.input_lens : input_lens,
        model.targets : input_targets,
        model.dropout_keep_prob : params['dropout_keep_prob'],
        model.gumbel_temperature : gumbel_temperature_schedule(params['gumbel_temperature'], itr),
        model.discrete_loss_weight : discrete_loss_weight,
        model.continuous_loss_weight : continuous_loss_weight,
        model.gumbel_loss_weight : params['gumbel_loss_weight'],
        model.orthogonality_loss_weight : params['orthogonality_loss_weight'],
        model.mse_loss_weight : params['mse_loss_weight']
    }

    sess.run(train_op, feed_dict=feed_dict)

    # ---- periodic training metrics -------------------------------------------------
    if itr % params['display_interval'] == 0 or itr == 1:
        # Second forward pass on the same batch (same feed, dropout still on),
        # executed after the parameter update above.
        (train_loss, train_discrete_loss, train_continuous_loss, train_gumbel_loss, train_orth_loss,
         train_mse_loss, train_discrete_acc, train_continuous_acc, grad_norm) = sess.run(
            [model.loss, model.discrete_loss, model.continuous_loss, model.gumbel_loss, model.orthogonality_loss, model.mse_loss,
             model.discrete_accuracy, model.continuous_accuracy, model.gradient_norm], feed_dict=feed_dict
        )

        # All floats use {:.4f}; the earlier {:4f} set a minimum width, not a precision.
        _log((
            "[{}, {:5d}] train loss: (batch: {:.4f}, discrete: {:.4f}, continuous: {:.4f}, gumbel: {:.4f}, "
            "orth: {:.4f}, mse: {:.4f}), discrete acc: {:.4f}, continuous acc: {:.4f}, grad_norm: {:.4f}"
        ).format(train_batcher.epoch, itr, train_loss, train_discrete_loss, train_continuous_loss,
                 train_gumbel_loss, train_orth_loss, train_mse_loss, train_discrete_acc, train_continuous_acc, grad_norm))

    # ---- periodic validation (+ test and checkpoint on improvement) -----------------
    if itr % params['val_interval'] == 0:
        val_results = evaluate(model, val_batcher)
        (val_loss, val_discrete_loss, val_continuous_loss, val_gumbel_loss, val_orth_loss, val_mse_loss,
         val_discrete_acc, val_continuous_acc, val_hier_bcubed) = val_results
        _log(_format_eval_report('Val', val_results))

        if val_loss < best_loss:
            best_loss = val_loss
        # Model selection is driven by validation continuous accuracy.
        if val_continuous_acc > best_acc:
            best_acc = val_continuous_acc
            test_results = evaluate(model, test_batcher)
            (test_loss, test_discrete_loss, test_continuous_loss, test_gumbel_loss, test_orth_loss, test_mse_loss,
             test_discrete_acc, test_continuous_acc, test_hier_bcubed) = test_results
            _log(_format_eval_report('Test', test_results))

            if not params['debug_mode']:
                weights_prefix = "{}_weights_epoch_{}_itr_{}".format(model.__class__.__name__, train_batcher.epoch, itr)
                weights_path = os.path.join(weights_dir, weights_prefix)
                # Save first so the log line is only written for a checkpoint that exists.
                saver.save(sess, weights_path)
                _log("Weights saved in file: {}\n".format(weights_path))

    # ---- periodic hierarchy / keyword report on the validation set ------------------
    if itr % params['deploy_interval'] == 0:
        softmaxes, targets, texts = [], [], []
        for _ in range(int(np.ceil(len(val_batcher.data)/val_batcher.batch_size))):
            val_batch = val_batcher.next_batch()
            input_texts = [example.text for example in val_batch]
            inputs = [example.ids for example in val_batch]
            input_lens = [example.real_len for example in val_batch]
            input_targets = [example.label for example in val_batch]

            # Only the one-hot outputs are fetched, so targets and loss weights are not fed.
            feed_dict = {
                model.inputs : inputs,
                model.input_lens : input_lens,
                model.dropout_keep_prob : 1.0,
                model.gumbel_temperature : 0.1
            }
            softmax = sess.run(model.gumbel_onehots, feed_dict=feed_dict)
            softmaxes.append(softmax)
            targets += input_targets
            texts += input_texts

        # Regroup from per-batch lists of per-level arrays into one array per level.
        softmaxes = [np.concatenate(s, axis=0) for s in list(zip(*softmaxes))]
        hierarchy_results = []
        for level, softmax in enumerate(softmaxes):
            n_codes = softmax.shape[-1]
            n_classes = len(data.classes)
            log_output = '\nHierarchy #{}\n'.format(level)
            # code -> count of examples per class assigned to that code
            hierarchy_dict = {code: np.zeros(n_classes, dtype=int) for code in range(n_codes)}
            # class name -> count of its examples per code
            label_assignment_dict = {data.label_dict[cls]: np.zeros(n_codes, dtype=int) for cls in range(n_classes)}
            # code -> concatenated raw texts of the examples assigned to it
            keywords_dict = {code: '' for code in range(n_codes)}
            for c, lbl, txt in zip(np.argmax(softmax, axis=1), targets, texts):
                hierarchy_dict[c][lbl] += 1
                label_assignment_dict[data.label_dict[lbl]][c] += 1
                keywords_dict[c] += ' ' + txt

            log_output += 'Label Distributions:\n'
            for k, v in hierarchy_dict.items():
                log_output += '{}: {}\n'.format(k, list(v))
            hierarchy_results.append(hierarchy_dict)

            log_output += '\nHierarchy Assignments:\n'
            # Sort classes by their most frequent code so classes sharing a code are adjacent.
            for k, v in sorted(list(label_assignment_dict.items()), key=lambda x: np.argmax(x[1])):
                log_output += '{}: {} ==> {}\n'.format(k, list(v), np.argmax(list(v)))

            log_output += '\nKeywords:\n'
            for k, v in keywords_dict.items():
                log_output += '{}: {}\n'.format(k, functions.get_important_words(v, tfidf, max_n=5))

            _log(log_output)