In [1]:
import json

In [5]:
def load_metric_log(experiment):
    """Load the metric log saved for one experiment.

    Reads ``raw_output/<experiment>_metric_log.json`` (relative to the current
    working directory) and returns the parsed JSON content.

    Args:
        experiment: Experiment name used as the file-name prefix.

    Returns:
        The decoded JSON document.

    Raises:
        FileNotFoundError: If the metric log file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode explicitly as UTF-8: without it open() falls back to the
    # platform locale encoding, which can differ between machines.
    with open('raw_output/' + experiment + '_metric_log.json', encoding='utf-8') as f:
        return json.load(f)

In [7]:
# Load one metric log per experiment run, in a fixed order.
cr_metric_log, eg_metric_log, cr_eg_metric_log = map(
    load_metric_log, ('gpt2_cr', 'gpt2_eg', 'gpt2_cr_eg')
)

In [20]:
def select_epoch(metric_log, best=True, epoch=None, multitask=False, metric=None):
    """Pick a single epoch's metrics out of a metric log.

    Args:
        metric_log: Mapping of ``'epoch_<n>'`` keys to dicts of metric values.
        best: When True, return the epoch that optimises ``metric``; when
            False, return the epoch numbered ``epoch``.
        epoch: Epoch number to return when ``best`` is False. Must be an int
            in ``[0, len(metric_log))``.
        multitask: Only used to choose the default metric: ``'val_total_loss'``
            when True, ``'val_loss'`` otherwise.
        metric: Name of the metric to optimise. Metrics whose name contains
            ``'loss'`` or ``'perplexity'`` are minimised; all others are
            maximised.

    Returns:
        A one-entry dict ``{epoch_key: metrics_of_that_epoch}``.

    Raises:
        AssertionError: If ``epoch`` is not a valid int epoch number (when
            ``best`` is False) or ``metric`` is missing from ``'epoch_0'``.
    """
    if not best:
        assert isinstance(epoch, int), "Please provide an int epoch number, when not selecting the best epoch"
        assert 0 <= epoch < len(metric_log), "Epoch number is too small or too large"
        epoch_key = f'epoch_{epoch}'
        return {epoch_key: metric_log[epoch_key]}

    if not metric:
        metric = 'val_total_loss' if multitask else 'val_loss'

    # Losses and perplexity are lower-is-better; previously only names
    # containing 'loss' were minimised, so 'perplexity' picked the worst epoch.
    minimize = any(marker in metric for marker in ('loss', 'perplexity'))

    best_epoch = 'epoch_0'
    assert metric in metric_log[best_epoch], "Metric is invalid. Consider checking if model is multitask"
    best_value = metric_log[best_epoch][metric]

    # Strict comparisons keep the current best on ties, starting at 'epoch_0'.
    for epoch_key, epoch_metrics in metric_log.items():
        value = epoch_metrics[metric]
        improved = value < best_value if minimize else value > best_value
        if improved:
            best_epoch, best_value = epoch_key, value
    return {best_epoch: metric_log[best_epoch]}

In [19]:
# Best epoch per experiment using select_epoch's default (validation loss) metric.
cr_best_epoch = select_epoch(cr_metric_log)
eg_best_epoch = select_epoch(eg_metric_log)
cr_eg_best_epoch = select_epoch(cr_eg_metric_log, multitask=True)

# Fixed epochs, addressed by zero-based number (epoch=2 is the third epoch).
cr_3epochs = select_epoch(cr_metric_log, best=False, epoch=2)
eg_10epochs = select_epoch(eg_metric_log, best=False, epoch=9)
cr_eg_5epochs = select_epoch(cr_eg_metric_log, best=False, multitask=True, epoch=4)

# Best epoch per experiment when ranking by accuracy instead of loss.
cr_best_accuracy_epoch = select_epoch(cr_metric_log, metric='accuracy')
cr_eg_best_accuracy_epoch = select_epoch(cr_eg_metric_log, metric='accuracy')

In [59]:
def view(metrics):
    """Print one line per epoch listing its non-training metrics.

    Args:
        metrics: Mapping of epoch key to a dict of metric name -> numeric value.
            Metrics whose name contains ``'train'`` are omitted; the rest are
            formatted to four decimal places.
    """
    for epoch_key, epoch_log in metrics.items():
        formatted = []
        for name, score in epoch_log.items():
            if 'train' in name:
                continue
            formatted.append(f'{name}: {score:.4f}')
        print(epoch_key, formatted)

In [61]:
# Metrics at the fixed epochs that the paper reports, one section per experiment.
for heading, selection in (
    ('CR:', cr_3epochs),
    ('CR-EG:', cr_eg_5epochs),
    ('EG:', eg_10epochs),
):
    print(heading)
    view(selection)

CR:
epoch_2 ['val_loss: 2.0488', 'accuracy: 0.6656']
CR-EG:
epoch_4 ['val_total_loss: 4.8051', 'val_cr_loss: 4.0053', 'val_eg_loss: 12.0040', 'accuracy: 0.6993', 'bleu1: 0.5610', 'bleu2: 0.3544', 'bleu3: 0.2569', 'bleu4: 0.2154', 'avg_bleu: 0.3469', 'rouge1: 0.3444', 'rouge2: 0.1165', 'rougel: 0.3350', 'perplexity: 6.7933']
EG:
epoch_9 ['val_loss: 6.3349', 'bleu1: 0.4957', 'bleu2: 0.3035', 'bleu3: 0.2041', 'bleu4: 0.1630', 'avg_bleu: 0.2916', 'rouge1: 0.3266', 'rouge2: 0.1028', 'rougel: 0.3126', 'perplexity: 7.2673']


In [68]:
# CR: best epoch by validation loss, then best epoch by accuracy.
for heading, selection in (
    ('Best loss:', cr_best_epoch),
    ('Best accuracy:', cr_best_accuracy_epoch),
):
    print(heading)
    view(selection)

Best loss:
epoch_5 ['val_loss: 2.0053', 'accuracy: 0.6820']
Best accuracy:
epoch_9 ['val_loss: 2.1330', 'accuracy: 0.7017']


In [67]:
# CR-EG: best epoch by total validation loss, then best epoch by accuracy.
for heading, selection in (
    ('Best loss:', cr_eg_best_epoch),
    ('Best accuracy:', cr_eg_best_accuracy_epoch),
):
    print(heading)
    view(selection)

Best loss:
epoch_4 ['val_total_loss: 4.8051', 'val_cr_loss: 4.0053', 'val_eg_loss: 12.0040', 'accuracy: 0.6993', 'bleu1: 0.5610', 'bleu2: 0.3544', 'bleu3: 0.2569', 'bleu4: 0.2154', 'avg_bleu: 0.3469', 'rouge1: 0.3444', 'rouge2: 0.1165', 'rougel: 0.3350', 'perplexity: 6.7933']
Best accuracy:
epoch_8 ['val_total_loss: 5.3411', 'val_cr_loss: 4.6100', 'val_eg_loss: 11.9217', 'accuracy: 0.7149', 'bleu1: 0.5581', 'bleu2: 0.3512', 'bleu3: 0.2562', 'bleu4: 0.2156', 'avg_bleu: 0.3453', 'rouge1: 0.3462', 'rouge2: 0.1132', 'rougel: 0.3357', 'perplexity: 6.6798']


In [69]:
# EG: best epoch, selected with select_epoch's default metric ('val_loss', minimised).
view(eg_best_epoch)

epoch_6 ['val_loss: 5.9345', 'bleu1: 0.5379', 'bleu2: 0.3250', 'bleu3: 0.2227', 'bleu4: 0.1804', 'avg_bleu: 0.3165', 'rouge1: 0.3259', 'rouge2: 0.0982', 'rougel: 0.3137', 'perplexity: 6.8181']
