In [None]:
import os
import re

import tensorflow.compat.v1 as tf

from statistics import mean
from collections import defaultdict

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge

# Values substituted into the `{shot}_SHOT/{support}` model directories scanned below.
shots = [8, 16, 32, 64, 128]
supports = ['supp', 'all']

# Shared scorer instances reused by the metric helpers in this notebook
# (constructed in the same order as before: BLEU, METEOR, ROUGE).
b4m = Bleu(4)
mm = Meteor()
rm = Rouge()

def _mean_sentence_bleu(targets, predictions, order):
    """Return 100 * the mean sentence-level BLEU score of the given n-gram `order` (1-4).

    Each (targets[i], predictions[i]) pair is scored on its own with the shared
    `b4m` scorer. Index `order - 1` is taken from the first element of
    `compute_score`'s result, matching the original bleu1..bleu4 indexing. The
    result is a macro-average of per-pair scores, not corpus-level BLEU.

    Raises:
        ValueError: if `targets` and `predictions` have different lengths (pairs are
            matched by index, so a mismatch would otherwise drop or mis-pair items).
        statistics.StatisticsError: if the inputs are empty.
    """
    if len(targets) != len(predictions):
        raise ValueError(
            f"got {len(targets)} targets but {len(predictions)} predictions")
    return 100 * mean([
        b4m.compute_score({i: [targets[i]]}, {i: [predictions[i]]})[0][order - 1]
        for i in range(len(targets))
    ])


def bleu4(targets, predictions):
    """Mean sentence-level BLEU-4 (scaled by 100) as ``{'bleu4': score}``."""
    return {'bleu4': _mean_sentence_bleu(targets, predictions, 4)}


def bleu3(targets, predictions):
    """Mean sentence-level BLEU-3 (scaled by 100) as ``{'bleu3': score}``."""
    return {'bleu3': _mean_sentence_bleu(targets, predictions, 3)}


def bleu2(targets, predictions):
    """Mean sentence-level BLEU-2 (scaled by 100) as ``{'bleu2': score}``."""
    return {'bleu2': _mean_sentence_bleu(targets, predictions, 2)}


def bleu1(targets, predictions):
    """Mean sentence-level BLEU-1 (scaled by 100) as ``{'bleu1': score}``."""
    return {'bleu1': _mean_sentence_bleu(targets, predictions, 1)}

def rouge(targets, predictions):
    """Mean sentence-level ROUGE-L (scaled by 100) over index-aligned pairs, as ``{'rougeL': score}``.

    Each (targets[idx], predictions[idx]) pair is scored separately with the shared
    `rm` scorer; the first element of `compute_score`'s result is that pair's score.
    """
    per_pair_scores = []
    for idx in range(len(targets)):
        result = rm.compute_score({idx: [targets[idx]]}, {idx: [predictions[idx]]})
        per_pair_scores.append(result[0])
    return {'rougeL': 100 * mean(per_pair_scores)}

def meteor(targets, predictions):
    """Mean sentence-level METEOR (scaled by 100) over index-aligned pairs, as ``{'meteor': score}``.

    Each (targets[i], predictions[i]) pair is scored separately with the shared
    module-level `mm` scorer (a `Meteor()` instance); the first element of
    `compute_score`'s result is that pair's score.
    """
    return {'meteor': 100 * mean([mm.compute_score({i: [targets[i]]}, {i: [predictions[i]]})[0]
                                  for i in range(len(targets))])}

_EXTRA_ID_TAG = re.compile('<extra_id_[0-9]{1,2}>')


def remove_tags_and_strip(string, **unused_kwargs):
    """Delete every ``<extra_id_N>`` tag (N of one or two digits) and trim surrounding whitespace.

    Extra keyword arguments are accepted and ignored. Tags with three or more
    digits (e.g. ``<extra_id_100>``) are not matched and are left in place.
    """
    without_tags = _EXTRA_ID_TAG.sub('', string)
    return without_tags.strip()

def readlines_into_list_gf(path):
    """Read `path` through `tf.io.gfile.GFile` and return its lines with surrounding whitespace stripped.

    Works for any path GFile accepts (the notebook uses `gs://` paths). Blank
    lines are kept as empty strings.
    """
    with tf.io.gfile.GFile(path) as handle:
        raw_lines = handle.readlines()
    return [raw_line.strip() for raw_line in raw_lines]

# predictions[shot][support] -> {'best_checkpoint': ..., 'all_checkpoint_data': ...}.
# This cell fills 'all_checkpoint_data'; 'best_checkpoint' is left as None here.
predictions = defaultdict(lambda: defaultdict(lambda: {'best_checkpoint': None, "all_checkpoint_data": None}))

for shot in shots:
    for support in supports:
        # MODEL_DIR and gold_fp stay module-level on purpose: later cells read them,
        # and they keep the values from the last (shot, support) iteration.
        MODEL_DIR = f'gs://t5_fewshot_mqg/models/t5_baseline_0/3B/{shot}_SHOT/{support}/validation_eval/'
        gold_fp = os.path.join(MODEL_DIR, 'fhpqg_no_rationale_targets')
        gold = readlines_into_list_gf(gold_fp)
        # The checkpoint step is the '_'-separated field just before 'predictions' in the
        # file name. Sort numerically so per-checkpoint order (and any printout that
        # iterates it) does not depend on the order glob happens to return.
        checkpoints = sorted(
            int(os.path.basename(p).split('_')[-2])
            for p in tf.io.gfile.glob(os.path.join(MODEL_DIR, 'fhpqg_no_rationale_*_predictions'))
        )
        checkpoint_predictions = defaultdict(lambda: {"predictions": None,
                                                      "bleu4": None,
                                                      "bleu3": None,
                                                      "bleu2": None,
                                                      "bleu1": None,
                                                      "rouge": None,
                                                      "meteor": None,
                                                      "average": None})
        for checkpoint in checkpoints:
            checkpoint_predictions_fp = os.path.join(MODEL_DIR, f"fhpqg_no_rationale_{checkpoint}_predictions")
            preds = readlines_into_list_gf(checkpoint_predictions_fp)
            # The metric helpers pair targets and predictions by line index.
            assert len(preds) == len(gold), (
                f"{checkpoint_predictions_fp}: {len(preds)} predictions vs {len(gold)} targets")
            entry = checkpoint_predictions[checkpoint]
            entry["predictions"] = preds
            entry["bleu4"] = bleu4(gold, preds)['bleu4']
            entry["bleu3"] = bleu3(gold, preds)['bleu3']
            entry["bleu2"] = bleu2(gold, preds)['bleu2']
            entry["bleu1"] = bleu1(gold, preds)['bleu1']
            entry["meteor"] = meteor(gold, preds)['meteor']
            entry["rouge"] = rouge(gold, preds)['rougeL']
            # Unweighted mean of the six metrics; used to rank checkpoints.
            entry["average"] = mean([entry[k] for k in ("bleu4", "bleu3", "bleu2", "bleu1", "meteor", "rouge")])
        predictions[shot][support]['all_checkpoint_data'] = checkpoint_predictions

In [21]:
# For each (shot, support) run, pick the checkpoint with the highest 'average' score,
# record it as 'best_checkpoint', and print every checkpoint's average followed by the winner.
for shot in shots:
    for support in supports:
        checkpoint_predictions = predictions[shot][support]['all_checkpoint_data']
        if not checkpoint_predictions:
            # None (run never populated) or no prediction files: max() would raise.
            print(f"{shot}, {support}: no checkpoints found")
            continue
        all_checkpoints = list(checkpoint_predictions)
        avg_scores_across_all_checkpoints = [checkpoint_predictions[chkpt]['average'] for chkpt in all_checkpoints]
        max_avg_score = max(avg_scores_across_all_checkpoints)
        # list.index returns the first match: ties resolve to the earliest-listed
        # checkpoint, and the table below is printed exactly once per run.
        best_chkpt = all_checkpoints[avg_scores_across_all_checkpoints.index(max_avg_score)]
        predictions[shot][support]['best_checkpoint'] = best_chkpt
        print(f"{shot}, {support}")
        for chkpt, avg_score in zip(all_checkpoints, avg_scores_across_all_checkpoints):
            print(f"{chkpt}: {avg_score:.2f}")
        print(f"best_checkpoint {best_chkpt} @ {max_avg_score:.2f}")

8, supp
1000000: 3.75
1000500: 9.43
1001000: 10.70
1001500: 12.37
1002000: 10.69
1002500: 10.35
1003000: 10.86
best_checkpoint 1001500 @ 12.37
8, all
1000000: 2.97
1000500: 8.25
1001000: 11.27
1001500: 10.86
1002000: 9.48
1002500: 10.86
1003000: 10.22
best_checkpoint 1001000 @ 11.27
16, supp
1000000: 6.27
1000500: 16.88
1001000: 12.48
1001500: 9.90
1002000: 10.31
1002500: 10.36
1003000: 10.19
best_checkpoint 1000500 @ 16.88
16, all
1000000: 4.66
1000500: 9.04
1001000: 9.99
1001500: 10.72
1002000: 11.56
1002500: 9.60
1003000: 10.41
best_checkpoint 1002000 @ 11.56
32, supp
1000000: 6.33
1000500: 13.83
1001000: 15.24
1001500: 16.43
1002000: 16.47
1002500: 17.14
1003000: 17.50
best_checkpoint 1003000 @ 17.50
32, all
1000000: 5.20
1000500: 12.99
1001000: 14.72
1001500: 14.78
1002000: 16.19
1002500: 17.28
1003000: 16.84
best_checkpoint 1002500 @ 17.28
64, supp
1001000: 21.03
1001500: 21.20
1002000: 20.37
1002500: 21.14
1003000: 20.29
1003500: 19.91
1004000: 20.00
1004500: 20.51
1005000: 20.6

In [None]:
import json

# Load the per-example generation records scored below. JSON text is UTF-8, so
# decode explicitly rather than relying on the platform's locale default encoding.
with open('data.json', 'r', encoding='utf-8') as f:
    predict_data = json.load(f)

In [None]:
# NOTE(review): gold_fp is the targets-file path string set in the metrics loop, so this
# displays a single character of that path. It was possibly meant to be gold[30]
# (the 31st gold target); confirm intent.
gold_fp[30]

In [None]:
# Score every generated question in predict_data against its record's 'gold_q'.
# A record may carry generations under 'q' and 'q_c'. What 'q_c' denotes is not
# visible here; presumably it is a second decoding variant (confirm). Both are scored
# against the same gold question, and a variant that is missing or None is skipped.
# Local names deliberately avoid `predictions` and `gold`: those globals hold the
# per-checkpoint results and gold targets from the metrics loop, and overwriting them
# breaks re-running the best-checkpoint cell.
targets, question_predictions = [], []
for record in predict_data:
    gold_question = record['gold_q']
    for pred_key in ('q', 'q_c'):
        predq = record.get(pred_key)
        if predq is not None:
            targets.append(gold_question)
            question_predictions.append(remove_tags_and_strip(predq))

# statistics.mean over an empty list would raise inside every metric helper.
assert targets, "no generated questions ('q' / 'q_c') found in predict_data"

metrics = {}
for metric_fn in (bleu4, bleu3, bleu2, bleu1, rouge, meteor):
    metrics.update(metric_fn(targets, question_predictions))

In [None]:
# One-line summary of every metric, alphabetical by metric name, two decimals each.
', '.join(f"{key}: {value:.2f}" for key, value in sorted(metrics.items()))

In [None]:
import os
import tensorflow.compat.v1 as tf

# NOTE(review): MODEL_DIR is whatever the last iteration of the metrics loop assigned
# (the final shot/support pair); confirm that is the directory intended here.
# Collect the integer step from every `model.ckpt-<step>.meta` file under MODEL_DIR,
# in the order glob returns them; [:-5] drops the '.meta' suffix.
_ckpt_meta_pattern = os.path.join(MODEL_DIR, 'model.ckpt-*.meta')
eval_checkpoints = [int(meta_path.split('-')[-1][:-5])
                    for meta_path in tf.io.gfile.glob(_ckpt_meta_pattern)]

In [None]:
# Display the discovered checkpoint steps (not sorted; glob order).
eval_checkpoints

In [None]:
import time

# Scratch timing check of time.sleep. perf_counter is monotonic and high-resolution,
# unlike time.time(), so the measured interval is not skewed by system clock changes.
# NOTE(review): this blocks Restart & Run All for ~10 s; consider removing the cell.
start = time.perf_counter()
time.sleep(10)
end = time.perf_counter()

print(end - start)

In [None]:
# Copy of the checkpoint steps ordered newest (largest) first; eval_checkpoints itself is untouched.
e = list(eval_checkpoints)
e.sort(reverse=True)

In [None]:
# Discover checkpoint steps from the `model.ckpt-<step>.meta` files in MODEL_DIR
# (left over from the last iteration of the metrics loop) and list them newest-first.
# Computed here directly so the cell does not depend on an earlier cell having
# populated eval_checkpoints.
eval_checkpoints = sorted(
    (int(fp.split('-')[-1][:-len('.meta')])
     for fp in tf.io.gfile.glob(os.path.join(MODEL_DIR, 'model.ckpt-*.meta'))),
    reverse=True,
)
print(f"Found checkpoints eval_checkpoints: {eval_checkpoints}")