In [1]:
import re
import json
import os
import sys
from ast import literal_eval
from datasets import load_metric
import pandas as pd
import itertools
# using https://github.com/MantisAI/nervaluate
from nervaluate import Evaluator

def get_BIO(text_w_pairs):
    """Convert an ARG-annotated sentence into tokens and BIO cause/effect tags.

    The text is split on single spaces. ``<ARG0>...</ARG0>`` marks the cause
    span (tags ``B-C`` / ``I-C``) and ``<ARG1>...</ARG1>`` marks the effect
    span (tags ``B-E`` / ``I-E``); every other token is tagged ``O``. All
    markup, including other tags such as ``<SIG0>``, is stripped from the
    returned tokens.

    Parameters
    ----------
    text_w_pairs : str
        Sentence with inline markup.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(tokens, ce_tags)``, two lists of equal length.
    """
    # (opening marker, closing marker, label), checked in this order per token.
    arg_markers = (('<ARG0>', '</ARG0>', 'C'), ('<ARG1>', '</ARG1>', 'E'))

    tokens = []
    ce_tags = []
    next_tag = tag = 'O'
    for tok in text_w_pairs.split(' '):

        for open_marker, close_marker, label in arg_markers:
            if open_marker in tok:
                tag = 'B-' + label
                # A span that opens and closes on the same token must not
                # leak its I- tag onto the tokens that follow it.
                next_tag = 'O' if close_marker in tok else 'I-' + label
                break
            if close_marker in tok:
                tag = 'I-' + label
                next_tag = 'O'
                break

        # clean_tok strips every markup tag, the ARG markers included.
        tokens.append(clean_tok(tok))
        ce_tags.append(tag)
        tag = next_tag

    return tokens, ce_tags


def clean_tok(tok):
    """Strip every markup tag (e.g. <ARG0>, </ARG1>, <SIG0>) from a token."""
    # Raw string: the digit escape must reach the regex engine untouched.
    return re.sub(r'</*[A-Z]+\d*>', '', tok)


def read_predictions(submission_file):
    """Load the ``prediction`` field of every record in a JSON Lines file.

    Blank (whitespace-only) lines are skipped. Each remaining line must be a
    JSON object with a ``prediction`` key.

    Parameters
    ----------
    submission_file : str or path-like
        Path to the predictions file, one JSON object per line.

    Returns
    -------
    list
        The ``prediction`` values, in file order.

    Raises
    ------
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    KeyError
        If a record has no ``prediction`` key.
    """
    predictions = []
    # Decode explicitly as UTF-8 instead of the platform default
    # (e.g. cp1252 on Windows), which would corrupt non-ASCII tokens.
    with open(submission_file, "r", encoding="utf-8") as reader:
        for line in reader:
            line = line.strip()
            if line:
                predictions.append(json.loads(line)['prediction'])
    return predictions


def keep_relevant_rows_and_unstack(ref_df, predictions):
    """Align predictions with causal reference rows, one row per relation.

    Rows with ``num_rs <= 0`` are dropped. For each remaining row the
    prediction list is cut down to, or padded up to, exactly ``num_rs``
    entries, and the reference frame is exploded so that every annotated
    relation in ``causal_text_w_pairs`` becomes its own row.

    Neither ``ref_df`` nor the lists inside ``predictions`` are modified.

    Parameters
    ----------
    ref_df : pandas.DataFrame
        Must contain the columns ``text``, ``num_rs``, ``text_w_pairs`` and
        ``causal_text_w_pairs`` (the latter holding the string repr of a list).
        Its index labels are used to look up ``predictions``, so a default
        0..n-1 index is assumed.
    predictions : list[list[str]]
        One list of predicted annotated sentences per row of ``ref_df``.

    Returns
    -------
    tuple[pandas.DataFrame, list[str]]
        The exploded reference frame (with a per-sentence ``eg_id`` counter and
        ``text_w_pairs`` holding one relation per row) and the flat list of
        predictions aligned with it row by row.
    """

    # Keep only causal examples
    predictions_w_true_labels = []
    eg_id_counter = []
    for i, row in ref_df.iterrows():
        if row.num_rs>0:
            # Plain int so it is always valid as a slice bound / repeat count.
            num_rs = int(row.num_rs)
            # Slicing copies, so the caller's list is never touched. If there
            # are more predictions than relations, only the first few are kept
            # and the rest are ignored.
            p = predictions[i][:num_rs]
            # Pad with the untagged text as dummy predictions when there are
            # too few (a non-positive repeat count adds nothing).
            p = p + [row.text]*(num_rs-len(p))
            predictions_w_true_labels.extend(p)
            eg_id_counter.extend(range(num_rs))
    ref_df = ref_df[ref_df['num_rs']>0].reset_index(drop=True)

    # Expand into single rows
    ref_df = ref_df.drop(['text_w_pairs'], axis=1)
    ref_df['causal_text_w_pairs'] = ref_df['causal_text_w_pairs'].apply(literal_eval)
    ref_df = ref_df.explode('causal_text_w_pairs')
    ref_df = ref_df.rename(columns={'causal_text_w_pairs':'text_w_pairs'})
    # Assigned positionally; requires one exploded row per counted relation.
    ref_df['eg_id'] = eg_id_counter

    return ref_df.reset_index(drop=True), predictions_w_true_labels


# Set the output file for scores.
# NOTE(review): absolute, machine-specific Windows paths; consider a
# configurable base directory so the notebook runs elsewhere.
output_filename = os.path.join(r"D:\66 CausalMap\Panasonic-IDS\outs", 'scores.txt')
# NOTE(review): opening with 'w' truncates scores.txt immediately, and the
# handle is never written to or closed in the cells visible here. Confirm
# whether a later cell uses it; otherwise wrap the writes in a `with` block.
output_file = open(output_filename, 'w')

# Read the reference annotations (CSV) and the model predictions (JSON lines).
truth_file = r"D:\66 CausalMap\Panasonic-IDS\data\MIR_annotated_grouped.csv"
ref_df = pd.read_csv(truth_file, encoding="utf-8")
submission_answer_file = r"D:\66 CausalMap\Panasonic-IDS\outs\20230315_predictions.json"
pred_list = read_predictions(submission_answer_file)

# Convert: keep causal rows only, one row per relation, then turn both the
# references and the predictions into (tokens, BIO tags) pairs.
# Note that `ref_df` and `pred_list` are rebound here, so re-running this
# part without re-running the reads above would feed in already-converted data.
ref_df, pred_list = keep_relevant_rows_and_unstack(ref_df, pred_list)
assert(len(pred_list)==len(ref_df))
refs = [get_BIO(i) for i in ref_df['text_w_pairs']]
preds = [get_BIO(i) for i in pred_list]

In [2]:
# Collect only the BIO tag sequences (dropping the tokens) for every
# reference/prediction pair, in the same order.
ce_refs_all = []
ce_preds_all = []
for i in range(len(refs)):
    _, ce_ref = refs[i]
    _, ce_pred = preds[i]
    ce_refs_all.append(ce_ref)
    ce_preds_all.append(ce_pred)
# After the loop `ce_ref` / `ce_pred` stay bound to the LAST example; a later
# cell calls get_start_end_positions(ce_ref, ...) and relies on that.
    
# Span-level evaluation of the cause ('C') and effect ('E') tags.
# NOTE(review): newer nervaluate releases appear to return more than two
# values from evaluate(); confirm the installed version matches this unpacking.
evaluator = Evaluator(ce_refs_all, ce_preds_all, tags=['C', 'E'], loader="list")
results, results_by_tag = evaluator.evaluate()
results

{'ent_type': {'correct': 68,
  'incorrect': 9,
  'partial': 0,
  'missed': 21,
  'spurious': 3,
  'possible': 98,
  'actual': 80,
  'precision': 0.85,
  'recall': 0.6938775510204082,
  'f1': 0.7640449438202247},
 'partial': {'correct': 28,
  'incorrect': 0,
  'partial': 49,
  'missed': 21,
  'spurious': 3,
  'possible': 98,
  'actual': 80,
  'precision': 0.65625,
  'recall': 0.5357142857142857,
  'f1': 0.5898876404494383},
 'strict': {'correct': 28,
  'incorrect': 49,
  'partial': 0,
  'missed': 21,
  'spurious': 3,
  'possible': 98,
  'actual': 80,
  'precision': 0.35,
  'recall': 0.2857142857142857,
  'f1': 0.3146067415730337},
 'exact': {'correct': 28,
  'incorrect': 49,
  'partial': 0,
  'missed': 21,
  'spurious': 3,
  'possible': 98,
  'actual': 80,
  'precision': 0.35,
  'recall': 0.2857142857142857,
  'f1': 0.3146067415730337}}

In [3]:
def is_overlapping(x1,x2,y1,y2):
    """Return True when the closed ranges [x1, x2] and [y1, y2] share a position.

    The ranges overlap when the later of the two starts is not past the
    earlier of the two ends; touching end points count as overlap.

    Any endpoint given as ``None`` (a span that was not found) means there is
    nothing to overlap with, so the result is ``False`` instead of a
    ``TypeError`` from comparing ``None`` with an int.
    """
    if any(bound is None for bound in (x1, x2, y1, y2)):
        return False
    return max(x1,y1) <= min(x2,y2)

is_overlapping(1,2,0,1)

True

In [4]:
def get_start_end_positions(list_of_tags,tag='C'):
    """Find the first and last positions whose tag string contains ``tag``.

    Matching is by substring, so ``tag='C'`` matches both ``'B-C'`` and
    ``'I-C'``. Everything between the first and last match is treated as one
    range, even if non-matching tags sit in between.

    Returns
    -------
    tuple
        ``(start, end)`` as inclusive indices, or ``(None, None)`` when no
        element matches.
    """
    hits = [pos for pos, label in enumerate(list_of_tags) if tag in label]
    if not hits:
        return None, None
    return hits[0], hits[-1]

# Demo call. `ce_ref` is not defined in this cell: it is whatever an earlier
# cell's loop left bound (the last reference tag sequence it processed).
get_start_end_positions(ce_ref,tag='E')

(0, 7)

In [5]:
def get_to(ce_pred, ce_ref):
    """Score token overlap between a predicted and a reference tag sequence.

    Parameters
    ----------
    ce_pred : list[str]
        Predicted BIO tags.
    ce_ref : list[str]
        Reference BIO tags.

    Returns
    -------
    tuple[int, int, int]
        ``(c_to, e_to, to)``: 1 if the cause ranges overlap, 1 if the effect
        ranges overlap, and 1 only when both do. A prediction made up solely
        of ``'O'`` tags scores ``(0, 0, 0)``.
    """

    if len(set(ce_pred))==1 and ce_pred[0]=='O':
        return 0,0,0

    def _tag_overlap(tag):
        # 1 when the reference and predicted ranges for `tag` overlap, else 0.
        ref_start, ref_end = get_start_end_positions(ce_ref,tag=tag)
        pred_start, pred_end = get_start_end_positions(ce_pred,tag=tag)
        # A range missing on either side (positions are None) cannot overlap;
        # checking here avoids comparing None with an int.
        if ref_start is None or pred_start is None:
            return 0
        return 1 if is_overlapping(ref_start,ref_end,pred_start,pred_end) else 0

    c_to = _tag_overlap('C')
    e_to = _tag_overlap('E')
    to = 1 if (c_to==1 and e_to==1) else 0

    return c_to, e_to, to


def get_combinations(list1,list2):
    """Enumerate every way of pairing ``list2`` with distinct items of ``list1``.

    For each ordered selection of ``len(list2)`` items from ``list1``, the
    selection is zipped with ``list2`` into a list of ``(item1, item2)``
    tuples. The result is the list of all such pairings, in the order
    ``itertools.permutations`` produces them.
    """
    pairings = []
    for ordering in itertools.permutations(list1, len(list2)):
        pairings.append(list(zip(ordering, list2)))
    return pairings


def keep_best_combinations_only(row, refs, preds):
    """Pick the reference/prediction pairing with the most overlapping examples.

    ``row.id`` lists the example ids that belong to one group. Every
    assignment of those references to those predictions is scored with
    ``get_to`` and the per-example overlap flags of the best-scoring
    assignment are returned. Ties keep the earliest assignment tried.

    The returned list always has one flag per id in ``row.id`` (in that
    order, by prediction id), including when no assignment overlaps at all,
    so callers can safely extend an id list and a score list in parallel.

    Note: the number of assignments grows factorially with ``len(row.id)``.
    """
    # None means "nothing scored yet", so the first assignment is always
    # recorded, even if all of its flags are 0.
    best_to = None

    for points in get_combinations(row.id, row.id):
        token_overlap = []
        for a,b in points:
            _, ce_ref = refs[a]
            _, ce_pred = preds[b]
            _, _, to = get_to(ce_pred, ce_ref)
            token_overlap.append(to)
        if best_to is None or sum(token_overlap)>sum(best_to):
            best_to=token_overlap
    return [] if best_to is None else best_to


# When True, examples that share a sentence are scored under their best
# reference/prediction pairing instead of by position.
do_best_combi = True

if do_best_combi:
    grouped_df = ref_df.copy()
    # One-element lists, so the 'sum' aggregation below concatenates them
    # into the list of row ids belonging to each sentence.
    grouped_df['id'] = [[i] for i in grouped_df.index]
    grouped_df = grouped_df.groupby(['corpus','doc_id','sent_id'])[['eg_id','id']].agg({'eg_id':'count','id':'sum'}).reset_index()
    # Keep only sentences with more than one row.
    grouped_df = grouped_df[grouped_df['eg_id']>1]
    req_combi_ids = [item for sublist in grouped_df['id'] for item in sublist]

    # For examples that DO NOT require combination search
    regular_ids = list(set(range(len(preds)))-set(req_combi_ids))
else:
    regular_ids = list(set(range(len(preds))))


# Parallel lists: token_overlap[k] is the overlap flag for example list_of_ids[k].
token_overlap = []
exact = 0
list_of_ids = []

for i in regular_ids:
    
    _, ce_pred = preds[i]
    _, ce_ref = refs[i]
    c_to, e_to, to = get_to(ce_pred, ce_ref)
    token_overlap.append(to)
    
    # NOTE(review): `exact` is only counted in this loop, so examples handled
    # by the combination search below never contribute to it; confirm intended.
    if ce_pred==ce_ref:
        exact+=1
    
    list_of_ids.append(i)
    
if do_best_combi:
    for _, row in grouped_df.iterrows():
        to = keep_best_combinations_only(row, refs, preds)
        # The two lists stay aligned only if `to` has one flag per id in row.id.
        token_overlap.extend(to)
        list_of_ids.extend(row.id)

    
# Overlapping examples / scored examples, then the exact-match count as cell output.
print(sum(token_overlap),'/',len(token_overlap))
exact

35 / 49


8

In [6]:
# Row ids that went through the combination search; this name is only
# defined when `do_best_combi` is True.
req_combi_ids

[10, 11, 14, 15, 27, 28, 34, 35, 37, 38, 39, 40]

In [7]:
# Print the overlap flag of every example, ordered by example id.
# `i` is the position in that sorted order (it equals the example id only
# when the ids run 0..n-1 without gaps).
for i, (_, p) in enumerate(sorted(zip(list_of_ids, token_overlap))):
    print(p)
#     print(refs[i])
#     print(preds[i][1])

1
1
1
1
1
1
1
1
1
1
1
0
1
0
1
0
0
1
1
1
1
1
0
1
0
1
1
0
1
1
1
1
0
1
1
0
1
1
0
1
0
0
1
0
1
1
0
1
1
