## Output space converter

This is used for the `both` (proofwiki+stacks) joint/autoregressive model.

The model's output space is over proofwiki and stacks reference ids.

When evaluating on an individual dataset, we need to map each token id in the model's combined output space to the token id of the same reference in that individual dataset's vocabulary.

The end result is a `tok2tok.pkl` file that is used during evaluation (the intermediate reference-id mapping is also saved, as `rid2rid.pkl`).

In [None]:
import pandas as pd
import glob
import pickle
import os
import json
from pprint import pprint as pp
from collections import defaultdict
import torch
from tqdm import tqdm
from pathlib import Path

In [None]:
# Input directory (dataset JSON files and tokenized pickles are read from here)
# and output directory (the generated mapping pickles are written here).
base, outdir = './data', './other'

In [None]:
def _load_dataset(filename):
    """Return the ``'dataset'`` entry of the JSON file ``filename`` under ``base``.

    Raises:
        FileNotFoundError: if the file does not exist.
        KeyError: if the JSON document has no top-level ``'dataset'`` key.
    """
    # The context manager closes the handle as soon as parsing is done instead
    # of leaving it to the garbage collector. JSON text is UTF-8, so the
    # encoding is stated explicitly rather than taken from the platform locale.
    with open(os.path.join(base, filename), encoding='utf-8') as f:
        return json.load(f)['dataset']


# Joint dataset and the two individual datasets it is compared against.
both = _load_dataset('naturalproofs_both.json')
pw = _load_dataset('naturalproofs_proofwiki.json')
stacks = _load_dataset('naturalproofs_stacks.json')

In [None]:
bothrid2pwrid = {}

# Two-level lookup tables: the outer key is the dataset name ('both', 'pw' or
# 'stacks'); the inner mapping goes reference id -> label (rid2label) or
# label -> reference id (label2rid).
rid2label = defaultdict(lambda: defaultdict(str))
label2rid = defaultdict(lambda: defaultdict(str))

for name, dataset in (('both', both), ('pw', pw), ('stacks', stacks)):
    # Every kind of reference (theorem, definition, other) goes into the same
    # per-dataset tables.
    references = dataset['theorems'] + dataset['definitions'] + dataset['others']
    for ref in references:
        ref_label, ref_id = ref['label'], ref['id']
        rid2label[name][ref_id] = ref_label
        label2rid[name][ref_label] = ref_id

In [None]:
# rid2rid[<direction>] maps a reference id of the joint ('both') dataset to the
# id of the same reference in one individual dataset. References are matched
# across datasets by their label.
rid2rid = {
    'both2pw': {},
    'both2stacks': {},
}

for direction, single in (('both2pw', 'pw'), ('both2stacks', 'stacks')):
    for rid, label in rid2label[single].items():
        # label2rid['both'] is a defaultdict(str): indexing it with an unknown
        # label would silently return '' (and insert that label), leaving a
        # bogus '' -> rid entry in the mapping. Test membership first and fail
        # loudly so an inconsistency between the datasets is noticed.
        if label not in label2rid['both']:
            raise KeyError(
                "label {!r} ({} reference id {!r}) not found in the joint "
                "'both' dataset".format(label, single, rid)
            )
        rid2rid[direction][label2rid['both'][label]] = rid

In [None]:
def _load_pickle(filename):
    """Unpickle and return the object stored in ``filename`` under ``base``.

    NOTE: ``pickle.load`` can execute arbitrary code embedded in the file;
    only load pickles that come from a trusted source.

    Raises:
        FileNotFoundError: if the file does not exist.
    """
    # The context manager closes the handle deterministically once loaded.
    with open(os.path.join(base, filename), 'rb') as f:
        return pickle.load(f)


# Tokenized sequence data for the joint model and for each individual dataset.
# Later cells read the 'rid2tok' entry of each object.
autoreg_both = _load_pickle('sequence_both__bert-base-cased.pkl')
autoreg_pw = _load_pickle('sequence_proofwiki__bert-base-cased.pkl')
autoreg_stacks = _load_pickle('sequence_stacks__bert-base-cased.pkl')

In [None]:
# tok2tok[<direction>] maps a token id of the joint vocabulary to the token id
# of the same reference in an individual dataset's vocabulary.
tok2tok = {'both2pw': {}, 'both2stacks': {}}

# Inverse of the joint vocabulary: token id -> reference id.
both_tok2rid = {tok: rid for rid, tok in autoreg_both['rid2tok'].items()}
# Declared here but not populated in this cell.
pw_tok2rid = {}
stacks_tok2rid = {}

special_symbols = frozenset(('<pad>', '<bos>', '<eos>'))

for joint_tok, joint_rid in both_tok2rid.items():
    if joint_rid in special_symbols:
        # Special symbols are looked up under the same string key in every
        # vocabulary, so no reference-id translation is needed.
        tok2tok['both2pw'][joint_tok] = autoreg_pw['rid2tok'][joint_rid]
        tok2tok['both2stacks'][joint_tok] = autoreg_stacks['rid2tok'][joint_rid]
        continue

    # Regular reference: translate joint rid -> dataset rid -> dataset token.
    # A joint reference absent from a dataset gets no entry for that dataset.
    if joint_rid in rid2rid['both2pw']:
        single_rid = rid2rid['both2pw'][joint_rid]
        tok2tok['both2pw'][joint_tok] = autoreg_pw['rid2tok'][single_rid]

    if joint_rid in rid2rid['both2stacks']:
        single_rid = rid2rid['both2stacks'][joint_rid]
        tok2tok['both2stacks'][joint_tok] = autoreg_stacks['rid2tok'][single_rid]

In [None]:
# Create the output directory if needed so the writes below cannot fail on a
# fresh checkout.
os.makedirs(outdir, exist_ok=True)

# Persist both mappings. The context manager flushes and closes each file
# before moving on, rather than relying on garbage collection to do it.
for filename, mapping in (('rid2rid.pkl', rid2rid), ('tok2tok.pkl', tok2tok)):
    with open(os.path.join(outdir, filename), 'wb') as f:
        pickle.dump(mapping, f)