In [103]:
import os
from os import path
import glob
import pandas as pd
import json
from collections import defaultdict

In [104]:
perf_dir = "/share/data/speech/shtoshni/research/litbank_coref/models/perf/"
slurm_id = "6522299"

files = sorted(glob.glob(path.join(perf_dir, slurm_id + "*")), key=lambda x: int(path.splitext(path.basename(x))[0].split('_')[1]))

print(len(files))

30


In [105]:
model_dict_list = []
for file in files:
    model_dict_list.append(json.loads(open(file).read()))
    
# print(model_dict_list)

In [106]:
def determine_varying_attributes(model_dict_list, ignore_attribs=['train', 'test', 'test_1', 'dev', 'dev_1',  'pretrained_mention_model',
                                                                  'conll_data_dir','slurm_id', 'best_model_dir', 'data_dir']):
    attrib_to_vals = defaultdict(set)
    for model_dict in model_dict_list:
        for attrib, val in model_dict.items():
            if attrib in ignore_attribs:
                continue
            else:
                attrib_to_vals[attrib].add(val)
    
    
    varying_attribs = []
    for attrib, vals in attrib_to_vals.items():
        if len(vals) > 1:
            varying_attribs.append(attrib)
            
    return varying_attribs

In [107]:
varying_attribs = determine_varying_attributes(model_dict_list)
# perf_attribs = ['MUC' , 'Bcub', 'CEAFE']
perf_attribs = []
varying_attribs.extend(['mem_type', 'max_ents'])
print(varying_attribs)

['model_dir', 'cross_val_split', 'label_smoothing_wt', 'mem_type', 'max_ents']


### Load all dev and test f-scores

In [108]:
perf_df = pd.DataFrame(columns=(varying_attribs + ['devf', 'fs'] + perf_attribs))

for model_dict in model_dict_list:
    perf_dict = {}
    for attrib in varying_attribs:
        perf_dict[attrib] = model_dict[attrib]
        
    
    for perf_attrib in perf_attribs:
        if perf_attrib in model_dict['test']:
            attrib_dict = model_dict['test'][perf_attrib]
            perf_dict[perf_attrib] = (attrib_dict['recall'], attrib_dict['precision'], attrib_dict['fscore'])
        else:
            perf_dict[perf_attrib] = '-'
            
    
    perf_dict['devf'] = model_dict['dev']['fscore']
    perf_dict['fs'] = model_dict['test']['fscore']
    
    perf_df = perf_df.append(perf_dict, ignore_index=True)

### Filter by maximum dev-scores among hyperparams

In [109]:
perf_df = perf_df.rename(columns={"label_smoothing_wt": "ls_wt", "sample_invalid": "samp", "max_training_segments": "segs"})
idx = perf_df.groupby(['cross_val_split'])['devf'].idxmax()

dev_max_df = perf_df.iloc[idx]
dev_max_df

Unnamed: 0,model_dir,cross_val_split,ls_wt,mem_type,max_ents,devf,fs
0,/share/data/speech/shtoshni/research/litbank_c...,0,0.1,unbounded,,77.6,74.8
1,/share/data/speech/shtoshni/research/litbank_c...,1,0.1,unbounded,,76.1,78.6
12,/share/data/speech/shtoshni/research/litbank_c...,2,0.01,unbounded,,79.6,77.8
3,/share/data/speech/shtoshni/research/litbank_c...,3,0.1,unbounded,,77.5,77.4
4,/share/data/speech/shtoshni/research/litbank_c...,4,0.1,unbounded,,77.4,76.7
5,/share/data/speech/shtoshni/research/litbank_c...,5,0.1,unbounded,,77.1,74.1
6,/share/data/speech/shtoshni/research/litbank_c...,6,0.1,unbounded,,75.7,79.0
7,/share/data/speech/shtoshni/research/litbank_c...,7,0.1,unbounded,,79.0,76.8
8,/share/data/speech/shtoshni/research/litbank_c...,8,0.1,unbounded,,77.7,78.3
9,/share/data/speech/shtoshni/research/litbank_c...,9,0.1,unbounded,,78.3,77.0


### Get varying memory type and memory size configurations

In [110]:
# z = dev_max_df.groupby(['mem_type','max_ents']).size()
# print(z)

# multindex = z.axes[0]
# mem_types = list(multindex.get_level_values(0))
# max_ents = list(multindex.get_level_values(1))

# print(mem_types, num_cells)

In [111]:
mem_types = ['unbounded']
max_ents = [None]

### Get location of all conll output files

In [112]:
SPLIT = 'dev'
# SPLIT = 'test'


model_config_to_conll_files = []
    
conll_files = []
json_files = []
model_config = ('unbounded', 20)
for cross_val_split in range(10):
#     print(dev_max_df)
#     print(dev_max_df.loc[dev_max_df['cross_val_split'] == cross_val_split])
    model_dir = dev_max_df.loc[dev_max_df['cross_val_split'] == cross_val_split]['model_dir'].values[0]
#     print(model_dir)
    conll_file = path.join(model_dir, f'{SPLIT}.conll')
    conll_files.append(conll_file)
    json_files.append(path.join(model_dir, f'{SPLIT}.log.jsonl'))

model_config_to_conll_files.append((model_config, conll_files, json_files))

### Concat all Cross Val CoNLLs and JSONLs

In [113]:
output_dir = "../models/litbank_preds/"
if not path.exists(output_dir):
    os.makedirs(output_dir)


model_config_output_file_list = []    
for model_config, conll_files, jsonl_files in model_config_to_conll_files:
    conll_output_file = path.join(output_dir, f'{model_config[0]}_{model_config[1]}_{SPLIT}.conll')
    jsonl_output_file = path.join(output_dir, f'{model_config[0]}_{model_config[1]}_{SPLIT}.jsonl')
    
    model_config_output_file_list.append((model_config, conll_output_file))
    
    with open(conll_output_file, "w") as output_w:
        for conll_file in conll_files:
            with open(conll_file) as g:
                for line in g:
                    output_w.write(line)
                    
    with open(jsonl_output_file, "w") as output_w:
        for jsonl_file in jsonl_files:
            with open(jsonl_file) as g:
                for line in g:
                    output_w.write(line)
                    
        
                    

#### Setup coref evaluation script path and Gold CoNLL 

In [114]:
import sys
import subprocess
import re

gold_conll = f"/home/shtoshni/Research/litbank_coref/data/litbank/all.{SPLIT}.conll"
scorer_path = "/home/shtoshni/Research/litbank_coref/lrec2020-coref/reference-coreference-scorers/scorer.pl"

def get_coref_score(metric, path_to_scorer, gold=None, preds=None):
    output=subprocess.check_output(["perl", path_to_scorer, metric, preds, gold]).decode("utf-8")
    output=output.split("\n")[-3]
    matcher=re.search("Coreference: Recall: \(.*?\) (.*?)%	Precision: \(.*?\) (.*?)%	F1: (.*?)%", output)
    if matcher is not None:
        recall=float(matcher.group(1))
        precision=float(matcher.group(2))
        f1=float(matcher.group(3))
    return recall, precision, f1

In [115]:
metrics = ['MUC', 'Bcub', 'CEAFE']
for model_config, conll_file in model_config_output_file_list:
    print(f"\\{model_config[0]} & {model_config[1]}", end="")
    fscore_list = []
    for metric in metrics:
        recall, precision, fscore = get_coref_score(metric.lower(), scorer_path, gold_conll, conll_file)
        print(f" & {recall:.1f} & {precision:.1f} & {fscore:.1f} ", end="")
        fscore_list.append(fscore)
    
    print(f"& {sum(fscore_list)/len(fscore_list): .1f}")


\unbounded & 20 & 90.2 & 87.2 & 88.7  & 80.1 & 75.2 & 77.6  & 67.6 & 65.8 & 66.7 &  77.6
