In [10]:
import os
import re

import tensorflow.compat.v1 as tf

from statistics import mean
from collections import defaultdict

from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge

shots = [8,16,32,64,128]
supports = ['supp', 'all']

b4m, mm, rm = Bleu(4), Meteor(), Rouge()

bleu4 = lambda targets, predictions: {'bleu4': 100 * mean([b4m.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0][3] for i in range(len(targets))])}
bleu3 = lambda targets, predictions: {'bleu3': 100 * mean([b4m.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0][2] for i in range(len(targets))])}
bleu2 = lambda targets, predictions: {'bleu2': 100 * mean([b4m.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0][1] for i in range(len(targets))])}
bleu1 = lambda targets, predictions: {'bleu1': 100 * mean([b4m.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0][0] for i in range(len(targets))])}

rouge = lambda targets, predictions: {'rougeL': 100 * mean([rm.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0] for i in range(len(targets))])}

def meteor(targets, predictions):
    from pycocoevalcap.meteor.meteor import Meteor
    return {'meteor': 100 * mean([mm.compute_score({i:[targets[i]]}, {i:[predictions[i]]})[0] for i in range(len(targets))])}

def remove_tags_and_strip(string, **unused_kwargs): 
    return re.sub('<extra_id_[0-9]{1,2}>', '', string).strip()

def readlines_into_list_gf(path):
    with tf.io.gfile.GFile(path) as f:
        return [l.strip() for l in f.readlines()]

predictions = defaultdict(lambda: defaultdict(lambda: {'best_checkpoint': None, "all_checkpoint_data": None}))

import contextlib


for shot in shots:
    for support in supports:
        MODEL_DIR = f'gs://t5_fewshot_mqg/models/t5_baseline_0/3B/{shot}_SHOT/{support}/test_eval/'
        gold_fp = os.path.join(MODEL_DIR, 'fhpqg_no_rationale_targets')
        gold = readlines_into_list_gf(gold_fp)
        checkpoints = [int(p.split('_')[-2]) for p in tf.io.gfile.glob(os.path.join(MODEL_DIR, 'fhpqg_no_rationale_*_predictions'))]
        checkpoint_predictions = defaultdict(lambda: {"predictions":None, 
                                                      "bleu4": None, 
                                                      "bleu3": None, 
                                                      "bleu2": None, 
                                                      "bleu1": None, 
                                                      "rouge": None, 
                                                      "meteor": None,
                                                      "average": None}
                                            )
        for checkpoint in checkpoints:
            print(shot, support, checkpoint)
            checkpoint_predictions_fp = os.path.join(MODEL_DIR, f"fhpqg_no_rationale_{checkpoint}_predictions")
            x = readlines_into_list_gf(checkpoint_predictions_fp)
            try:
                assert len(x) == len(gold)
            except:
                continue
            checkpoint_predictions[checkpoint]["predictions"] = x
            with contextlib.redirect_stdout(None):
                checkpoint_predictions[checkpoint]["bleu4"] = bleu4(gold, checkpoint_predictions[checkpoint]["predictions"])['bleu4']
                checkpoint_predictions[checkpoint]["bleu3"] = bleu3(gold, checkpoint_predictions[checkpoint]["predictions"])['bleu3']
                checkpoint_predictions[checkpoint]["bleu2"] = bleu2(gold, checkpoint_predictions[checkpoint]["predictions"])['bleu2']
                checkpoint_predictions[checkpoint]["bleu1"] = bleu1(gold, checkpoint_predictions[checkpoint]["predictions"])['bleu1']
            checkpoint_predictions[checkpoint]["meteor"] = meteor(gold, checkpoint_predictions[checkpoint]["predictions"])['meteor']
            checkpoint_predictions[checkpoint]["rouge"] = rouge(gold, checkpoint_predictions[checkpoint]["predictions"])['rougeL']
            checkpoint_predictions[checkpoint]["average"] = mean([
                checkpoint_predictions[checkpoint]["bleu4"],
                checkpoint_predictions[checkpoint]["bleu3"],
                checkpoint_predictions[checkpoint]["bleu2"],
                checkpoint_predictions[checkpoint]["bleu1"],
                checkpoint_predictions[checkpoint]["meteor"],
                checkpoint_predictions[checkpoint]["rouge"]
            ])
        predictions[shot][support]['all_checkpoint_data'] = checkpoint_predictions

8 supp 1001500
8 supp 1003000
8 all 1001000
8 all 1003000
16 supp 1000500
16 supp 1003000
16 all 1002000
16 all 1003000
32 supp 1003000
32 all 1002500
32 all 1003000
64 supp 1001500
64 supp 1005120
64 all 1004000
64 all 1005120
128 supp 1009500
128 supp 1010240
128 all 1006000
128 all 1010240


In [14]:
import copy

backup1 = copy.deepcopy(backup2)
predictions = backup2
for shot in shots:
    for support in supports:
        checkpoint_predictions = predictions[shot][support]['all_checkpoint_data']
        avg_scores_across_all_checkpoints = [checkpoint_predictions[chkpt]['average'] for chkpt in checkpoint_predictions]
        all_checkpoints = [chkpt for chkpt in checkpoint_predictions]
        try:
            max_avg_score = max(avg_scores_across_all_checkpoints)
        except:
            print(shot, support, avg_scores_across_all_checkpoints)
            raise Exception
        for chkpt in checkpoint_predictions:
            if checkpoint_predictions[chkpt]['average'] == max_avg_score:
                predictions[shot][support]['best_checkpoint'] = chkpt
                print(f"{shot}, {support}")
                for i, c in enumerate(avg_scores_across_all_checkpoints):
                    print(f"{all_checkpoints[i]}: {c:.2f}")
                print(f"best_checkpoint {chkpt} @ {max_avg_score:.2f}")
                # for metric_name in ['bleu4', 'bleu3', 'bleu2', 'bleu1', 'meteor', 'rouge']:
                #     print(f"best_checkpoint {chkpt} @ {metric_name}, {predictions[shot][support]['all_checkpoint_data'][chkpt][metric_name]}")

8, supp
1001500: 15.33
best_checkpoint 1001500 @ 15.33
8, all
1001000: 15.16
best_checkpoint 1001000 @ 15.16
16, supp
1000500: 15.95
best_checkpoint 1000500 @ 15.95
16, all
1002000: 16.47
best_checkpoint 1002000 @ 16.47
32, supp
1003000: 18.76
best_checkpoint 1003000 @ 18.76
32, all
1002500: 17.89
best_checkpoint 1002500 @ 17.89
64, supp
1001500: 19.32
best_checkpoint 1001500 @ 19.32
64, all
1004000: 18.69
best_checkpoint 1004000 @ 18.69
128, supp
1009500: 21.48
best_checkpoint 1009500 @ 21.48
128, all
1006000: 19.37
best_checkpoint 1006000 @ 19.37


In [19]:
for shot in shots:
    for support in supports:
        checkpoint_predictions = predictions[shot][support]['all_checkpoint_data']
        chkpt = predictions[shot][support]['best_checkpoint']
        print(f'{shot} {support} top {chkpt}')
        for metric_name in ['bleu1', 'bleu2', 'bleu3', 'bleu4', 'meteor', 'rouge']:
            print(f"{metric_name}, {checkpoint_predictions[chkpt][metric_name]:.2f}")
        print()

8 supp top 1001500
bleu1, 24.40
bleu2, 13.76
bleu3, 7.49
bleu4, 4.50
meteor, 17.18
rouge, 24.63

8 all top 1001000
bleu1, 24.17
bleu2, 13.46
bleu3, 7.38
bleu4, 4.46
meteor, 16.79
rouge, 24.71

16 supp top 1000500
bleu1, 24.47
bleu2, 14.37
bleu3, 8.39
bleu4, 5.33
meteor, 17.93
rouge, 25.21

16 all top 1002000
bleu1, 25.61
bleu2, 15.04
bleu3, 8.76
bleu4, 5.47
meteor, 18.74
rouge, 25.18

32 supp top 1003000
bleu1, 28.27
bleu2, 17.63
bleu3, 10.84
bleu4, 7.03
meteor, 21.00
rouge, 27.77

32 all top 1002500
bleu1, 27.04
bleu2, 16.75
bleu3, 10.23
bleu4, 6.64
meteor, 20.31
rouge, 26.39

64 supp top 1001500
bleu1, 28.74
bleu2, 18.19
bleu3, 11.44
bleu4, 7.59
meteor, 21.78
rouge, 28.20

64 all top 1004000
bleu1, 28.06
bleu2, 17.52
bleu3, 10.89
bleu4, 7.14
meteor, 21.11
rouge, 27.42

128 supp top 1009500
bleu1, 31.42
bleu2, 20.53
bleu3, 13.34
bleu4, 8.94
meteor, 24.02
rouge, 30.62

128 all top 1006000
bleu1, 28.31
bleu2, 18.05
bleu3, 11.41
bleu4, 7.60
meteor, 22.65
rouge, 28.18



In [None]:
import json
with open('data.json', 'r') as f:
    predict_data = json.load(f)

In [None]:
gold_fp[30]

In [None]:
targets, predictions = [], []
for i in range(len(predict_data)):
    gold = predict_data[i]['gold_q']
    predq = predict_data[i].get('q', None)
    if predq != None:
        targets.append(gold)
        predictions.append(remove_tags_and_strip(predq))
    gold = predict_data[i]['gold_q']
    predq = predict_data[i].get('q_c', None)
    if predq != None:
        targets.append(gold)
        predictions.append(remove_tags_and_strip(predq))
        
b4 = bleu4(targets, predictions)
b3 = bleu3(targets, predictions)
b2 = bleu2(targets, predictions)
b1 = bleu1(targets, predictions)
r = rouge(targets, predictions)
m = meteor(targets, predictions)

metrics = {}
metrics.update(b4)
metrics.update(b3)
metrics.update(b2)
metrics.update(b1)
metrics.update(r)
metrics.update(m)

In [None]:
', '.join([f"{key}: {metrics[key]:.2f}" for key in sorted(metrics.keys())])

In [None]:
import tensorflow.compat.v1 as tf
import os

eval_checkpoints =  [int(fp.split('-')[-1][:-5]) for fp in tf.io.gfile.glob(os.path.join(MODEL_DIR, 'model.ckpt-*.meta'))]

In [None]:
eval_checkpoints

In [None]:
import time
s = time.time()
time.sleep(10)
e = time.time()

print(e-s)

In [None]:
e = sorted(eval_checkpoints, reverse=True)

In [None]:
MODEL
eval_jsons =  [int(fp.split('-')[-1][:-5]) for fp in tf.io.gfile.glob(os.path.join(MODEL_DIR, 'model.ckpt-*.meta'))]
eval_checkpoints = sorted(eval_checkpoints, reverse=True)
print(f"Found checkpoints eval_checkpoints: {eval_checkpoints}")