## Running the metrics for proof generation

In [None]:
import sys
# Add the parent directory to the module search path.
# NOTE(review): presumably so that the `npgen` package imported further down
# resolves when the notebook is run from its own folder — confirm.
sys.path.append("..")

In [None]:
import json
import transformers
import torch
import pickle
from tqdm.notebook import tqdm
import pandas as pd
import glob
from pprint import pprint
import transformers
import os
import datasets
# Quieten library logging so cell output stays readable.
# 50 is the numeric value of logging.CRITICAL in the standard logging module.
datasets.logging.set_verbosity(50)
transformers.logging.set_verbosity_error()

%load_ext autoreload
%autoreload 2
%pylab inline

### Data

In [None]:
# Load the three JSON resources that evaluate() passes to
# full_generation_metrics. Each file is opened in a `with` block so its
# handle is closed as soon as parsing finishes instead of being left to the
# garbage collector.
with open('../data/base/proofwiki.json') as fin:
    ds_base = json.load(fin)

with open('../data/base/proofwiki__refs_ground_truth.json') as fin:
    ds_generations = json.load(fin)

with open('../data/base/proofwiki_redirects.json') as fin:
    redirects = json.load(fin)

In [None]:
# Tokenizer for the 'facebook/bart-large' checkpoint; evaluate() hands it to
# full_generation_metrics together with the datasets loaded above.
tokenizer = transformers.AutoTokenizer.from_pretrained('facebook/bart-large')

### Utilities

In [None]:
def get_theorem_ids(split):
    """Return the theorem ids listed in the core evaluation set for a split.

    Reads the tab-separated file '../data/base/core_evalset.tsv', which must
    provide 'theorem_id' and 'split' columns.

    Args:
        split: Value to match against each row's 'split' column. The special
            value 'both' selects every row regardless of its split.

    Returns:
        List of integer theorem ids, in file order.
    """
    import csv

    selected = []
    with open('../data/base/core_evalset.tsv') as handle:
        for record in csv.DictReader(handle, delimiter='\t'):
            # 'both' short-circuits, so the row's own split is not consulted.
            if split == 'both' or record['split'] == split:
                selected.append(int(record['theorem_id']))
    return selected

def load_gens(name2gen_files, split=None):
    """Load generation files and optionally restrict them to one split.

    Each file is first parsed as JSON; if that fails it is read as a pickle.
    A file that cannot be read either way is reported with a "skipping"
    message and left out of the result, so one bad path does not abort the
    whole load (previously this case raised ``KeyError`` whenever ``split``
    was given).

    Args:
        name2gen_files: Mapping from run name to the path of its generations
            file (JSON or pickle).
        split: If not ``None``, forwarded to ``get_theorem_ids``; the
            ``'full_generations'`` list of every loaded object is then
            reduced to the entries whose ``x['metadata'][0]`` is one of the
            returned ids.

    Returns:
        Dict mapping each successfully loaded run name to its loaded object.
    """
    # Resolve the split once instead of re-reading the evaluation set for
    # every file; a set keeps the per-generation membership test O(1).
    keep_ids = None
    if split is not None and name2gen_files:
        keep_ids = set(get_theorem_ids(split))

    name2gens = {}
    for name, path in name2gen_files.items():
        try:
            with open(path) as fin:
                gens = json.load(fin)
        except (OSError, ValueError):
            # Not readable as JSON text (missing file, binary content, or
            # malformed JSON): fall back to pickle.
            try:
                # NOTE: unpickling can execute arbitrary code; only point
                # this at generation files from a trusted source.
                with open(path, 'rb') as fin:
                    gens = pickle.load(fin)
            except Exception:
                print("skipping %s" % name)
                continue

        if keep_ids is not None:
            gens['full_generations'] = [
                x for x in gens['full_generations']
                if x['metadata'][0] in keep_ids
            ]
        name2gens[name] = gens
    return name2gens

In [None]:
from npgen.evaluation_proofgen import full_generation_metrics

def evaluate(name2gens):
    """Compute proof-generation metrics for every loaded run.

    Args:
        name2gens: Mapping from a run key of the form ``"<name>___<run>"`` to
            a loaded generations object with a ``'full_generations'`` entry.
            ``<name>`` may embed ``key=value`` spans separated by ``__``;
            each such pair is copied into that run's metrics row.

    Returns:
        A ``pandas.DataFrame`` with one row per run: the metrics returned by
        ``full_generation_metrics`` plus the ``name`` and ``run`` columns.

    Raises:
        ValueError: If a run key does not split into exactly two parts on
            ``'___'``.
    """
    rows = []
    for key, gens in tqdm(name2gens.items(), total=len(name2gens)):
        # The second return value (the metrics object) is not used by this
        # notebook, so it is discarded rather than accumulated.
        metrics, _ = full_generation_metrics(
            gens['full_generations'],
            ds_base,
            ds_generations,
            redirects,
            tokenizer,
            return_metrics_object=True
        )
        name, run = key.split('___')
        for span in name.split('__'):
            if '=' in span:
                # Split on the first '=' only so values may contain '='.
                param, value = span.split('=', 1)
                metrics[param] = value
        metrics['name'] = name
        metrics['run'] = run
        rows.append(metrics)

    return pd.DataFrame(rows)

In [None]:
from copy import deepcopy

default_cols = ['ppl', 'gleu', 'ref_f1', 'kf1', 'f1', 'corpus_ref_halluc']

def display(df, cols=default_cols, sort=None):
    """Summarise metrics per model name, averaged over its runs.

    Args:
        df: Metrics frame with at least the 'run' and 'name' columns plus the
            requested metric columns.
        cols: Metric columns to show. ``None`` selects a fixed set ('ppl',
            'f1', 'gleu', 'ref_precision', 'ref_recall', 'ref_f1',
            'corpus_ref_halluc').
        sort: Optional column name (or list of names) to sort the summary by,
            using ``DataFrame.sort_values`` defaults.

    Returns:
        A new DataFrame indexed by 'name' holding the mean of each numeric
        column, rounded to two decimals. ``df`` itself is not modified.
    """
    if cols is None:
        selected = ['run', 'ppl', 'name', 'f1', 'gleu', 'ref_precision',
                    'ref_recall', 'ref_f1', 'corpus_ref_halluc']
    else:
        selected = ['run', 'name'] + list(cols)

    # Selecting a list of columns yields a new frame, so no explicit copy of
    # the caller's data is needed.
    summary = df[selected].groupby('name').mean(numeric_only=True)
    # numeric_only=True: non-numeric columns (such as a string-valued 'run')
    # are excluded from the average. pandas >= 2.0 raises on them instead of
    # silently dropping them, so the exclusion has to be explicit.
    summary = summary.round(4).round(2)
    if sort is not None:
        summary = summary.sort_values(sort)

    return summary

### GPT-3

In [None]:
# Restrict evaluation to generations whose theorem id is in the 'test' split
# (load_gens filters via get_theorem_ids).
split = 'test'

# Run key -> generations file. Keys use the "<name>___<run>" form that
# evaluate() splits on '___'.
name2gen_files = {    
    'gpt3___1': '../other/naturalprover_generations/gpt3.pkl',
    'naturalprover-retrieve___1': '../other/naturalprover_generations/naturalprover_retrieved.pkl',
    'naturalprover___1': '../other/naturalprover_generations/naturalprover.pkl',    
    'naturalprover++___1': '../other/naturalprover_generations/naturalprover_plusplus.pkl',
}
name2gen = load_gens(name2gen_files, split=split)
df = evaluate(name2gen)
# Per-name summary of the selected metrics, sorted by 'ref_f1'.
display(df, cols=['gleu', 'f1', 'kf1', 'ref_precision', 'ref_recall', 'ref_f1', 'corpus_ref_halluc'], sort='ref_f1')

### GPT-J/2

In [None]:
# Restrict evaluation to generations whose theorem id is in the 'valid' split
# (load_gens filters via get_theorem_ids).
split='valid'
# Run key -> generations file. Keys use the "<name>___<run>" form that
# evaluate() splits on '___'.
# NOTE(review): 'gpt3curie-greedy' points at naturalprover.pkl — presumably
# those are the GPT-3 Curie greedy generations; confirm the file is the
# intended one.
name2gen_files = {    
    'gpt3curie-greedy___1': '../other/naturalprover_generations/naturalprover.pkl',
    'gptj6b-greedy___1': '../other/naturalprover_generations/naturalprover_gptj6b.json',
    'gpt2-greedy___1': '../other/naturalprover_generations/naturalprover_gpt2.json',
}
name2gen = load_gens(name2gen_files, split=split)
df = evaluate(name2gen)
# Per-name summary of the selected metrics, sorted by 'ref_f1'.
display(df, cols=['gleu', 'ref_f1', 'corpus_ref_halluc'], sort='ref_f1')