In [1]:
import json

In [5]:
def load_metric_log(experiment):
    """Load the JSON metric log saved for one experiment.

    Args:
        experiment: Experiment name. The file that is read is
            ``raw_output/<experiment>_metric_log.json``, resolved relative
            to the current working directory.

    Returns:
        The parsed JSON content of that file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; being explicit avoids depending on the platform's
    # default locale encoding when the log contains non-ASCII characters.
    with open('raw_output/' + experiment + '_metric_log.json', encoding='utf-8') as f:
        return json.load(f)

In [7]:
# Load the metric logs of the three GPT-2 experiments (CR, EG and joint CR+EG),
# in that order, via load_metric_log.
cr_metric_log, eg_metric_log, cr_eg_metric_log = (
    load_metric_log(experiment_name)
    for experiment_name in ('gpt2_cr', 'gpt2_eg', 'gpt2_cr_eg')
)

In [20]:
def select_epoch(metric_log, best=True, epoch=None, multitask=False, metric=None):
    """Pick a single epoch's entry out of a metric log.

    Args:
        metric_log: Mapping from ``'epoch_<i>'`` keys to a dict of metric
            name -> value for that epoch.
        best: If True (default), return the epoch with the best value of
            ``metric``. If False, return the epoch numbered ``epoch``.
        epoch: Zero-based epoch number; required (as an int) when ``best``
            is False, ignored otherwise.
        multitask: Only used when ``metric`` is not given; selects the
            default metric ``'val_total_loss'`` instead of ``'val_loss'``.
        metric: Name of the metric to optimise. Metrics whose name contains
            ``'loss'`` or ``'perplexity'`` are minimised; every other metric
            is maximised.

    Returns:
        A one-item dict ``{epoch_key: metrics_of_that_epoch}``.

    Raises:
        AssertionError: If ``epoch`` is not an int in
            ``[0, len(metric_log))`` when ``best`` is False, or if ``metric``
            is missing from the ``'epoch_0'`` entry.
        KeyError: If the requested epoch key (or ``'epoch_0'``) is absent,
            or a later epoch lacks ``metric``.
    """
    if not best:
        assert isinstance(epoch, int), "Please provide an int epoch number, when not selecting the best epoch"
        assert 0 <= epoch < len(metric_log), "Epoch number is too small or too large"
        requested_key = f'epoch_{epoch}'
        return {requested_key: metric_log[requested_key]}

    if not metric:
        metric = 'val_total_loss' if multitask else 'val_loss'

    # Lower is better for losses and for perplexity; everything else
    # (accuracy, BLEU, ROUGE, ...) is treated as higher-is-better.
    minimize = any(marker in metric for marker in ('loss', 'perplexity'))

    # Epoch 0 is the baseline; strict comparisons below mean ties keep the
    # earliest epoch found.
    best_epoch = 'epoch_0'
    assert metric in metric_log[best_epoch], "Metric is invalid. Consider checking if model is multitask"
    best_value = metric_log[best_epoch][metric]

    for epoch_key, epoch_metrics in metric_log.items():
        value = epoch_metrics[metric]
        improved = value < best_value if minimize else value > best_value
        if improved:
            best_epoch, best_value = epoch_key, value
    return {best_epoch: metric_log[best_epoch]}

In [19]:
# Requires select_epoch and the three *_metric_log variables to be defined first.

# CR: lowest-validation-loss epoch, plus the fixed epoch index 2.
cr_best_epoch = select_epoch(cr_metric_log)
cr_3epochs = select_epoch(cr_metric_log, best=False, epoch=2)

# EG: lowest-validation-loss epoch, plus the fixed epoch index 9.
eg_best_epoch = select_epoch(eg_metric_log)
eg_10epochs = select_epoch(eg_metric_log, best=False, epoch=9)

# CR-EG (multitask): lowest total-validation-loss epoch, plus the fixed epoch index 4.
cr_eg_best_epoch = select_epoch(cr_eg_metric_log, multitask=True)
cr_eg_5epochs = select_epoch(cr_eg_metric_log, best=False, multitask=True, epoch=4)

# Highest-accuracy epochs for the two models that report accuracy.
cr_best_accuracy_epoch = select_epoch(cr_metric_log, metric='accuracy')
cr_eg_best_accuracy_epoch = select_epoch(cr_eg_metric_log, metric='accuracy')

In [31]:
# Selected epochs reported by paper: each label on its own line, followed by the selection.
print('cr_3epochs', cr_3epochs, sep='\n')
print('cr_eg_5epochs', cr_eg_5epochs, sep='\n')
print('eg_10epochs', eg_10epochs, sep='\n')

cr_3epochs
{'epoch_2': {'train_loss': 1.0333056031173888, 'val_loss': 2.0487872369964846, 'accuracy': 0.6655722326454033}}
cr_eg_5epochs
{'epoch_4': {'train_total_loss': 2.24081388208516, 'train_cr_loss': 1.8229743378937628, 'train_eg_loss': 6.001370024467077, 'val_total_loss': 4.805128892899458, 'val_cr_loss': 4.00525179817135, 'val_eg_loss': 12.004023217345948, 'accuracy': 0.6993402450518379, 'bleu1': 0.5609781410714333, 'bleu2': 0.35440634303043567, 'bleu3': 0.25694576650325485, 'bleu4': 0.2153551271775811, 'avg_bleu': 0.3469213444456762, 'rouge1': 0.3443884383744124, 'rouge2': 0.11654123344884001, 'rougel': 0.33495320820229396, 'perplexity': 6.79328727722168}}
eg_10epochs
{'epoch_9': {'train_loss': 4.8014943513867685, 'val_loss': 6.334864847208277, 'bleu1': 0.49573143807849906, 'bleu2': 0.3035366110698238, 'bleu3': 0.20408104394865365, 'bleu4': 0.1629515549214478, 'avg_bleu': 0.2915751620046061, 'rouge1': 0.3266294556128739, 'rouge2': 0.10278297041864803, 'rougel': 0.31260752201666

In [30]:
# CR: best epochs (by validation loss, then by accuracy), label line first.
print('cr_best_epoch', cr_best_epoch, sep='\n')
print('cr_best_accuracy_epoch', cr_best_accuracy_epoch, sep='\n')

cr_best_epoch
{'epoch_5': {'train_loss': 0.9357343524558833, 'val_loss': 2.0053413079782456, 'accuracy': 0.6819887429643527}}
cr_best_accuracy_epoch
{'epoch_9': {'train_loss': 0.779406982922739, 'val_loss': 2.1329857450936123, 'accuracy': 0.701688555347092}}


In [29]:
# CR-EG: best epochs (by total validation loss, then by accuracy), label line first.
print('cr_eg_best_epoch', cr_eg_best_epoch, sep='\n')
print('cr_eg_best_accuracy_epoch', cr_eg_best_accuracy_epoch, sep='\n')

cr_eg_best_epoch
{'epoch_4': {'train_total_loss': 2.24081388208516, 'train_cr_loss': 1.8229743378937628, 'train_eg_loss': 6.001370024467077, 'val_total_loss': 4.805128892899458, 'val_cr_loss': 4.00525179817135, 'val_eg_loss': 12.004023217345948, 'accuracy': 0.6993402450518379, 'bleu1': 0.5609781410714333, 'bleu2': 0.35440634303043567, 'bleu3': 0.25694576650325485, 'bleu4': 0.2153551271775811, 'avg_bleu': 0.3469213444456762, 'rouge1': 0.3443884383744124, 'rouge2': 0.11654123344884001, 'rougel': 0.33495320820229396, 'perplexity': 6.79328727722168}}
cr_eg_best_accuracy_epoch
{'epoch_8': {'train_total_loss': 1.8103833076726779, 'train_cr_loss': 1.3862930266196578, 'train_eg_loss': 5.627196096482768, 'val_total_loss': 5.341133986521621, 'val_cr_loss': 4.609960372385048, 'val_eg_loss': 11.92169716962658, 'accuracy': 0.7148916116870877, 'bleu1': 0.5581311125261607, 'bleu2': 0.351170538670631, 'bleu3': 0.2561837579816091, 'bleu4': 0.21563342179337744, 'avg_bleu': 0.3452797077429446, 'rouge1': 

In [32]:
# EG: best epoch — prints the eg_best_epoch selection (no label line is printed here).
print(eg_best_epoch)

{'epoch_6': {'train_loss': 5.448050856015922, 'val_loss': 5.9345028022440465, 'bleu1': 0.537882484100006, 'bleu2': 0.32500600749795944, 'bleu3': 0.22265372228024968, 'bleu4': 0.18041089653768486, 'avg_bleu': 0.316488277603975, 'rouge1': 0.32588110017029265, 'rouge2': 0.09823059827301911, 'rougel': 0.3136544500025351, 'perplexity': 6.818094253540039}}
