In [10]:
import re
import json
import os
import sys
from ast import literal_eval
from datasets import load_metric
import pandas as pd
import itertools
# using https://github.com/MantisAI/nervaluate
from nervaluate import Evaluator

def get_BIO(text_w_pairs):
    tokens = []
    ce_tags = []
    next_tag = tag = 'O'
    for tok in text_w_pairs.split(' '):

        # Replace if special
        if '<ARG0>' in tok:
            tok = re.sub('<ARG0>','',tok)
            tag = 'B-C'
            next_tag = 'I-C'
        elif '</ARG0>' in tok:
            tok = re.sub('</ARG0>','',tok)
            tag = 'I-C'
            next_tag = 'O'
        elif '<ARG1>' in tok:
            tok = re.sub('<ARG1>','',tok)
            tag = 'B-E'
            next_tag = 'I-E'
        elif '</ARG1>' in tok:
            tok = re.sub('</ARG1>','',tok)
            tag = 'I-E'
            next_tag = 'O'

        tokens.append(clean_tok(tok))
        ce_tags.append(tag)
        tag = next_tag
    
    return tokens, ce_tags


def clean_tok(tok):
    # Remove all other tags: E.g. <SIG0>, <SIG1>...
    return re.sub('</*[A-Z]+\d*>','',tok) 


def read_predictions(submission_file):
    predictions = []
    with open(submission_file, "r") as reader:
        for line in reader:
            line = line.strip()
            if line:
                predictions.append(json.loads(line)['prediction'])
    return predictions


def keep_relevant_rows_and_unstack(ref_df, predictions):
    
    # Keep only causal examples
    predictions_w_true_labels = []
    eg_id_counter = []
    for i, row in ref_df.iterrows():
        if row.num_rs>0:
            p = predictions[i]
            if len(p)>row.num_rs:
                # Note if you predict more than the number of relations we have, we only keep the first few.
                # We completely ignore the subsequent predictions.
                p = p[:row.num_rs]
            elif len(p)<row.num_rs:
                # Incorporate dummy predictions if there are insufficient predictions
                p.extend([row.text]*(row.num_rs-len(p)))
            predictions_w_true_labels.extend(p)
            eg_id_counter.extend(list(range(row.num_rs)))
    ref_df = ref_df[ref_df['num_rs']>0].reset_index(drop=True)
    
    # Expand into single rows
    ref_df = ref_df.drop(['text_w_pairs'], axis=1)
    ref_df['causal_text_w_pairs'] = ref_df['causal_text_w_pairs'].apply(lambda x: literal_eval(x))
    ref_df = ref_df.explode('causal_text_w_pairs')
    ref_df = ref_df.rename(columns={'causal_text_w_pairs':'text_w_pairs'})
    ref_df['eg_id'] = eg_id_counter
    
    return ref_df.reset_index(drop=True), predictions_w_true_labels


# set save files
output_filename = os.path.join(r"D:\66 CausalMap\Panasonic-IDS\outs", 'scores.txt')
output_file = open(output_filename, 'w')

# read files
truth_file = r"D:\66 CausalMap\Panasonic-IDS\data\MIR_annotated_grouped.csv"
ref_df = pd.read_csv(truth_file, encoding="utf-8")
submission_answer_file = r"D:\66 CausalMap\Panasonic-IDS\outs\predictions.json"
pred_list = read_predictions(submission_answer_file)

# Convert
ref_df, pred_list = keep_relevant_rows_and_unstack(ref_df, pred_list)
assert(len(pred_list)==len(ref_df))
refs = [get_BIO(i) for i in ref_df['text_w_pairs']]
preds = [get_BIO(i) for i in pred_list]

In [11]:
ce_refs_all = []
ce_preds_all = []
for i in range(len(refs)):
    _, ce_ref = refs[i]
    _, ce_pred = preds[i]
    ce_refs_all.append(ce_ref)
    ce_preds_all.append(ce_pred)
    
evaluator = Evaluator(ce_refs_all, ce_preds_all, tags=['C', 'E'], loader="list")
results, results_by_tag = evaluator.evaluate()
results

{'ent_type': {'correct': 30,
  'incorrect': 9,
  'partial': 0,
  'missed': 17,
  'spurious': 3,
  'possible': 56,
  'actual': 42,
  'precision': 0.7142857142857143,
  'recall': 0.5357142857142857,
  'f1': 0.6122448979591837},
 'partial': {'correct': 10,
  'incorrect': 0,
  'partial': 29,
  'missed': 17,
  'spurious': 3,
  'possible': 56,
  'actual': 42,
  'precision': 0.5833333333333334,
  'recall': 0.4375,
  'f1': 0.5},
 'strict': {'correct': 9,
  'incorrect': 30,
  'partial': 0,
  'missed': 17,
  'spurious': 3,
  'possible': 56,
  'actual': 42,
  'precision': 0.21428571428571427,
  'recall': 0.16071428571428573,
  'f1': 0.1836734693877551},
 'exact': {'correct': 10,
  'incorrect': 29,
  'partial': 0,
  'missed': 17,
  'spurious': 3,
  'possible': 56,
  'actual': 42,
  'precision': 0.23809523809523808,
  'recall': 0.17857142857142858,
  'f1': 0.20408163265306123}}

In [12]:
def is_overlapping(x1,x2,y1,y2):
    return max(x1,y1) <= min(x2,y2)

is_overlapping(1,2,0,1)

True

In [13]:
def get_start_end_positions(list_of_tags,tag='C'):
    start = None
    end = None
    for i,t in enumerate(list_of_tags):
        if tag in t:
            if start is None:
                start=i
            end=i
    return start,end

get_start_end_positions(ce_ref,tag='E')

(0, 7)

In [14]:
c_token_overlap = []
e_token_overlap = []
token_overlap = []
exact = 0
for i in range(len(refs)):
    
    _, ce_pred = preds[i]
    if len(set(ce_pred))==1 and ce_pred[0]=='O':
        c_token_overlap.append(0)
        e_token_overlap.append(0)
        token_overlap.append(0)
        
        print('\n',i)
        print(pred_list[i])
        print(ce_pred)
        print(ce_ref)
        
        continue
    
    c_pred_start, c_pred_end = get_start_end_positions(ce_pred,tag='C')
    e_pred_start, e_pred_end = get_start_end_positions(ce_pred,tag='E')    
    _, ce_ref = refs[i]
    c_ref_start, c_ref_end = get_start_end_positions(ce_ref,tag='C')
    e_ref_start, e_ref_end = get_start_end_positions(ce_ref,tag='E')

    
    if is_overlapping(c_ref_start,c_ref_end,c_pred_start, c_pred_end):
        c_token_overlap.append(1)
    else:
        c_token_overlap.append(0)
    if is_overlapping(e_ref_start,e_ref_end,e_pred_start, e_pred_end):
        e_token_overlap.append(1)
    else:
        e_token_overlap.append(0)
    if c_token_overlap[-1]==1 and e_token_overlap[-1]==1:
        token_overlap.append(1)
    else:
        token_overlap.append(0)
        print('\n',i, c_token_overlap[-1], e_token_overlap[-1])
        print(e_pred_start, e_pred_end)
        print(e_ref_start,e_ref_end)
        print(pred_list[i])
        print(ref_df['text_w_pairs'][i])
        print(ce_pred)
        print(ce_ref)
    
    if ce_pred==ce_ref:
        exact+=1
    
print(sum(token_overlap),'/',len(token_overlap))
exact


 4 0 0
4 20
21 32
To kick-start the shift, <ARG1>the report suggests bulk procurement of electric vehicles, building standardized, swappable batteries for two- and three-wheelers </ARG1>to <ARG0>bring down their cost </ARG0>and having favorable tariff structures for charging cars.
To kick-start the shift, the report suggests <ARG0>bulk procurement of electric vehicles, building standardized, swappable batteries for two- and three-wheelers</ARG0> to <ARG1>bring down their cost and having favorable tariff structures for charging cars</ARG1>.
['O', 'O', 'O', 'O', 'B-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'B-C', 'I-C', 'I-C', 'I-C', 'I-C', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'I-C', 'O', 'B-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E', 'I-E']

 6 1 0
0 13
16 26
<ARG1>F

1

In [15]:
token_overlap

[1,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1]