In [14]:
import os
import pickle as pkl
import pprint as pp
from collections import OrderedDict

import matplotlib
# The backend must be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt



def print_stats(stats_fname):
    """Print the AUROC and AUPRC entries of a pickled stats mapping.

    The file at ``stats_fname`` is unpickled. Every entry whose key ends in
    the seven characters ``'auroc_c'`` is pretty-printed under an "AUROCS:"
    heading, then every entry whose key ends in ``'auprc_c'`` is
    pretty-printed under an "AUPRCS:" heading. Entries keep the iteration
    order of the loaded mapping. Nothing is returned.
    """
    with open(stats_fname, 'rb') as stats_file:
        loaded = pkl.load(stats_file)

    sections = (("\nAUROCS:", 'auroc_c'), ("\nAUPRCS:", 'auprc_c'))
    for heading, suffix in sections:
        print(heading)
        selected = OrderedDict()
        for key in loaded:
            # Both suffixes are exactly seven characters long.
            if key[-7:] == suffix:
                selected[key] = loaded[key]
        pp.pprint(selected)

def get_params(result_path):
    """Load a pickled results mapping and return selected run settings.

    Returns a list of single-entry dicts, ``[{name: value}, ...]``, one per
    reported setting and in the fixed order listed below. A ``KeyError``
    propagates if the loaded mapping lacks any of the reported names.
    """
    reported = (
        'model_name', 'num_layers', 'num_heads', 'exclusion_interval',
        'init_lr', "weight_decay", 'pool_name',
        'no_random_sample_eval_trajectories', 'pad_size', 'epochs',
        'train', 'dev', 'test', 'cuda',
        'train_batch_size', 'eval_batch_size',
        'max_batches_per_train_epoch', 'max_batches_per_dev_epoch',
        'max_events_length', 'max_eval_indices',
        'eval_auroc', 'eval_auprc', 'eval_c_index',
    )

    with open(result_path, 'rb') as handle:
        loaded = pkl.load(handle)

    configs = []
    for name in reported:
        configs.append({name: loaded[name]})
    return configs

def plot_train_metric(epoch_stats_path, metric_keys, max_epochs=20,
                      title="Training process for transformer model"):
    """Plot per-epoch metric curves from a pickled stats mapping.

    Args:
        epoch_stats_path: Path to a pickle file holding a mapping from metric
            name to a sequence with one value per epoch.
        metric_keys: Names of the series to draw; each becomes one line and
            one legend entry. A missing name raises ``KeyError``.
        max_epochs: Largest epoch shown on the x axis. Ticks are placed at
            every integer from 0 to ``max_epochs`` and the axis spans
            ``[0, max_epochs + 1]``. The default of 20 reproduces the
            previously hard-coded range.
        title: Figure title; the default is the previously hard-coded text.

    Returns:
        None. The curves are drawn on a new pyplot figure.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(epoch_stats_path, 'rb') as f:
        epoch_stats = pkl.load(f)

    plt.figure(figsize=(10, 7))
    lines = []
    labels = []
    for metric_key in metric_keys:
        values = epoch_stats[metric_key]
        # Epochs are numbered from 1 on the x axis.
        epochs = list(range(1, len(values) + 1))
        line, = plt.plot(epochs, values, lw=1)
        lines.append(line)
        labels.append(metric_key)

    # Built-in range replaces np.arange: numpy is not imported by the
    # notebook's import cell, so the old call raised NameError.
    plt.xticks(list(range(0, max_epochs + 1)))
    plt.xlim([0, max_epochs + 1])
    plt.xlabel('Epoch', size=20)
    plt.ylabel('Score', size=20)
    plt.legend(lines, labels, loc='best', prop=dict(size=12))
    plt.title(title, size=16)

In [12]:
def get_training_summary(model, log_dir, test=False, metrics=None):
    """Print a run's parameters and evaluation stats, and optionally plot metrics.

    The results file is ``<log_dir>/<model>.results``. Its selected
    parameters are pretty-printed, followed by the stats stored next to it
    with the ``.dev_stats`` suffix and, when ``test`` is true, the
    ``.test_stats`` suffix.

    Args:
        model: Model name; ``model + '.results'`` is the results file name.
        log_dir: Directory containing the results and stats files.
        test: When true, the test-split stats are printed as well.
        metrics: Optional list of metric names to plot with
            ``plot_train_metric``. When omitted or empty, no figure is drawn.

    Returns:
        None.
    """
    result_path = os.path.join(log_dir, model + '.results')

    pp.pprint(get_params(result_path))
    print_stats(result_path + '.dev_stats')
    if test:
        print_stats(result_path + '.test_stats')

    if metrics:
        # The previous code called an undefined `plot_train_metrics` with an
        # undefined global `metrics`, which raised NameError.
        # NOTE(review): assumes the per-epoch series live in the results
        # pickle itself, as the old call passed result_path -- confirm.
        plot_train_metric(result_path, metrics)