In [None]:
import numpy as np
import pandas as pd
import json
import glob
import os

# conduct error analysis: collect "near misses" where the gold object label is a
# strict substring of the joined prediction.
data = []
path = 'what/exp1_zh/init_left_iter_none/P1001.jsonl'
# fields copied into each row; the prediction ('pred') is appended separately below
keys = ['prompt', 'sub_label', 'obj_label', 'pred']
with open(path, "r", encoding="utf8") as jsonl_file:
    for line in jsonl_file:
        if not line.strip():
            continue  # tolerate blank lines in the .jsonl file
        dic = json.loads(line)
        res = [dic[i] for i in keys[:-1]]
        # prediction at index num_mask - 1, joined from its pieces into one string
        res.append(''.join(dic['pred'][dic['num_mask'] - 1]))
        # res[-2] is obj_label, res[-1] is the joined prediction
        if res[-2] in res[-1] and res[-2] != res[-1]:
            data.append(res)

In [155]:
# calculate single/multi token accuracies
def counting(json_list):
    """Count exact-match predictions, split by single- vs multi-token objects.

    Args:
        json_list: iterable of JSON strings (e.g. the lines of a .jsonl file).
            Each record must contain 'tokenized_obj_label_inflection',
            'num_mask' and 'pred'. Blank lines are skipped.

    Returns:
        Tuple ``(multi_correct, multi_n, single_correct, single_n)``: the number
        of correct predictions and the number of records whose object has more
        than one token, followed by the same two counts for single-token objects.
    """
    multi_correct = multi_n = 0
    single_correct = single_n = 0
    for json_str in json_list:
        if not json_str.strip():
            # tolerate blank lines (e.g. a stray empty line in a .jsonl file)
            continue
        result = json.loads(json_str)
        obj = result['tokenized_obj_label_inflection']
        # the prediction that is scored sits at index num_mask - 1
        pred = result['pred'][result['num_mask'] - 1]
        correct = int(pred == obj)
        if len(obj) > 1:
            multi_n += 1
            multi_correct += correct
        else:
            single_n += 1
            single_correct += correct
    return multi_correct, multi_n, single_correct, single_n

def cal_acc(folder_name):
    """Aggregate multi- and single-token accuracy over the .jsonl files in a folder.

    Every file in ``folder_name`` whose name ends with 'jsonl' is scored with
    ``counting`` and the counts are summed across files.

    Args:
        folder_name: directory containing the per-relation prediction files.

    Returns:
        ``(multi_acc, single_acc)`` as fractions in [0, 1]. A category with no
        records yields ``float('nan')`` instead of raising ZeroDivisionError.
    """
    mc, mn, sc, sn = 0, 0, 0, 0
    for f_name in os.listdir(folder_name):
        if not f_name.endswith('jsonl'):
            continue
        # Read as UTF-8 explicitly (as the error-analysis cell does): the exp1_zh
        # outputs presumably contain Chinese text, and the platform default
        # encoding (e.g. on Windows) may not be UTF-8.
        with open(os.path.join(folder_name, f_name), 'r', encoding='utf8') as json_file:
            json_list = list(json_file)
        a, b, c, d = counting(json_list)
        mc += a
        mn += b
        sc += c
        sn += d
    multi_acc = mc / mn if mn else float('nan')
    single_acc = sc / sn if sn else float('nan')
    return multi_acc, single_acc

In [None]:
# extract accuracy results from .out files
def extract_result(path):
    """Build one results row from a ``.out`` log file.

    The path is expected to end in ``<exp_group>/<name>.out`` (either ``/`` or
    ``\\`` separators). The per-relation ``.jsonl`` predictions are read from
    the sibling folder ``<name>``, i.e. the same path without its extension.
    The run name is split on ``_``; parts 1 and 3 are taken as the init and
    refine settings (e.g. 'init_left_iter_none' -> 'left', 'none').

    Returns:
        ``[lang, init, refine, acc_per_fact, acc_per_relation, single, multi]``
        with the accuracies scaled by 100, or ``None`` when the path has fewer
        than two components, the group is not exp1_en / exp1_zh, the run name
        has more than four ``_``-separated parts, or the log file is empty.
    """
    # Accept both Windows and POSIX separators; glob returns native ones.
    parts = path.replace('\\', '/').split('/')
    if len(parts) < 2:
        return None
    exp_group, file_name = parts[-2], parts[-1]
    if exp_group not in ['exp1_en', 'exp1_zh']:
        return None
    name = os.path.splitext(file_name)[0]
    components = name.split('_')
    if len(components) > 4:
        return None
    # folder holding the .jsonl predictions: the .out path without its extension
    folder_name = os.path.splitext(path)[0]

    data = [exp_group[-2:]] + [components[i] for i in [1, 3]]

    with open(path, "r", encoding="utf8") as log_file:
        lines = log_file.readlines()
    if not lines:
        return None  # empty log: no accuracy line to parse
    fields = lines[-1].split()
    # Last log line: token 3 ends with the per-fact accuracy and token 7 starts
    # with the per-relation accuracy (presumably 6-character decimals).
    data.append(float(fields[3][-6:]) * 100)
    data.append(float(fields[7][:6]) * 100)
    multi_acc, single_acc = cal_acc(folder_name)
    data += [single_acc * 100, multi_acc * 100]
    return data

files = glob.glob('what/**/*.out')
# one entry per .out file; None marks files extract_result chose to skip
res = [extract_result(out_file) for out_file in files]

columns = ['lang', 'init', 'refine', 'acc per fact', 'acc per relation', 'single', 'multi']
summary = (
    pd.DataFrame([row for row in res if row is not None], columns=columns)
    .drop(columns=['acc per relation'])
    .groupby(by=['init', 'refine', 'lang'])
    .mean()
    .round(2)
    .reset_index()
    .astype('str')
)

# Put the English and Chinese results side by side for each (init, refine) pair.
en_part = summary[summary['lang'] == 'en'].reset_index(drop=True).drop('lang', axis=1)
zh_part = summary[summary['lang'] == 'zh'].reset_index(drop=True).drop('lang', axis=1)
df = pd.merge(en_part, zh_part, on=['init', 'refine']).values


In [None]:
# generate latex source text
def print_table(df, cline=r'\cline{4-6}'):
    r"""Print the rows of ``df`` as LaTeX table-body lines.

    Cells are joined with ``&`` and each row ends with ``\\``. When a row has
    the same first column as the previous row, that cell is printed as ``~``;
    if the second column also repeats, it is blanked too and ``cline`` is
    printed before the row. An ``\hline`` is printed whenever the first column
    changes, and once more after the last row.

    Args:
        df: iterable of rows; each row is an iterable of strings with at least
            two cells (non-string cells make the ``&``-join raise TypeError).
        cline: rule printed between rows that repeat both of the first two
            columns. Defaults to the original ``\cline{4-6}``.
    """
    sep = '&'
    # unique sentinel so the first row always opens with \hline
    prev_t = prev_f = object()

    for row in df:
        cells = list(row)
        now_t, now_f = cells[0], cells[1]
        if prev_t != now_t:
            print(r'\hline')
        else:
            cells[0] = '~'
            if prev_f == now_f:
                cells[1] = '~'
                print(cline)
        prev_t, prev_f = now_t, now_f
        print(sep.join(cells) + r'\\')
    print(r'\hline')

# Print the LaTeX rows for the merged en/zh table (`df` is the array built by the results cell).
print_table(df)