## CREST Statistics

In [1]:
import pandas as pd
import numpy as np

df = pd.read_excel("../data/causal/crest.xlsx")

# Split codes as used by the split-size summary in this notebook:
# 0 = train, 1 = dev, 2 = test.
# Rows without a split are assigned to train. `isna()` is used instead of
# `np.isnan` because it also works when the column is not a float dtype.
df.loc[df['split'].isna(), 'split'] = 0

# Sanity check: no row is left without a split value.
assert df['split'].notna().all(), "some rows still have no split assigned"

In [2]:
# Preview the first rows that come from source 7.
df.loc[df['source'] == 7].head()

Unnamed: 0.1,Unnamed: 0,original_id,span1,span2,signal,context,idx,label,source,ann_file,split
16850,16850,E1,['that has arisen'],['the past few years'],['over'],"And second, we should address the issue that h...","{'span1': [[40, 55]], 'span2': [[76, 94]], 'si...",0,7,CHRG-111shrg61651.ann,0.0
16851,16851,E4,['these banks are too big to fail'],"['they have lower funding costs, they are able...",['Because'],"Because these banks are too big to fail, they ...","{'span1': [[8, 39]], 'span2': [[41, 179]], 'si...",1,7,CHRG-111shrg61651.ann,0.0
16852,16852,E5,['they make more money'],['the cycle'],['over'],"Because these banks are too big to fail, they ...","{'span1': [[111, 131]], 'span2': [[137, 146]],...",0,7,CHRG-111shrg61651.ann,0.0
16853,16853,E6,['too big'],['fail'],"['too', 'to']","Because these banks are too big to fail, they ...","{'span1': [[24, 31]], 'span2': [[35, 39]], 'si...",1,7,CHRG-111shrg61651.ann,0.0
16854,16854,E7,['you look at the European situation today'],['it is much worse than what we have in this c...,['If'],"If you look at the European situation today, f...","{'span1': [[3, 43]], 'span2': [[58, 153]], 'si...",0,7,CHRG-111shrg61651.ann,0.0


In [3]:
# Row counts per split code (0 = train, 1 = dev, 2 = test).
n_train, n_dev, n_test = (int((df['split'] == code).sum()) for code in (0, 1, 2))
print('train: {}, dev: {}, test: {}'.format(n_train, n_dev, n_test))

train: 13622, dev: 1586, test: 4371


## Span Length
We want to find the average number of tokens in the span arguments of a causal relation in CREST.

In [5]:
import ast

# NOTE(review): the per-source counts shown later in this notebook also list a
# source 8, which is not covered here -- confirm whether that is intentional.
sources = [1, 2, 3, 4, 5, 6, 7]

# Histogram over all listed sources:
# combined token count of span1 + span2 -> number of relations with that count.
span_length = {}

for source in sources:
    df_source = df[df['source'] == source]
    sum_length = 0   # total tokens over all counted spans of this source
    n_count = 0      # number of counted spans (two per relation)
    n_skipped = 0    # relations left out because one of their spans is empty

    for raw_span1, raw_span2 in zip(df_source['span1'], df_source['span2']):
        # Each cell holds the string form of a list of text fragments; join the
        # fragments and split on any whitespace so that repeated spaces do not
        # produce empty "tokens".
        span1_tokens = ' '.join(ast.literal_eval(raw_span1)).split()
        span2_tokens = ' '.join(ast.literal_eval(raw_span2)).split()

        # Skip relations with an empty span explicitly. Previously such a row
        # silently reused the length computed for the preceding row (or raised
        # NameError if it was the very first row).
        if not span1_tokens or not span2_tokens:
            n_skipped += 1
            continue

        len_span = len(span1_tokens) + len(span2_tokens)
        span_length[len_span] = span_length.get(len_span, 0) + 1

        sum_length += len_span
        n_count += 2

    # Bookkeeping check: every relation was either counted or skipped.
    assert n_count + 2 * n_skipped == len(df_source) * 2

    if n_count == 0:
        # Avoids the division by zero that the former bare `except` was hiding.
        print("source {}: no relation with two non-empty spans".format(source))
        continue

    # Average number of tokens per span for this source.
    print("source {}: {}".format(source, sum_length / n_count))
    if n_skipped:
        print("  skipped {} relation(s) with an empty span".format(n_skipped))

source 1: 1.3068181818181819
source 2: 1.0458150601847531
source 3: 1.0
source 4: 1.001572327044025
source 5: 1.16545245398773
source 6: 1.1015187849720225
source 7: 8.568587105624143


In [9]:
# Rebuild the histogram with its keys (combined span lengths) in ascending order.
span_length = {total: span_length[total] for total in sorted(span_length)}

In [10]:
# Collapse the histogram into two buckets:
#   1 -> relations whose two spans total exactly 2 tokens
#   2 -> relations whose two spans total more than 2 tokens
vals = {
    1: span_length[2],
    2: sum(count for total, count in span_length.items() if total > 2),
}
vals

{1: 14911, 2: 2668}

In [3]:
# Partition the rows into the test split (code 2) and everything else.
is_test = df['split'] == 2
test_df = df.loc[is_test]
dev_df = df.loc[~is_test]  # all non-test rows, despite the name

# Row counts of the non-test portion, per source and per label.
source_groups = dev_df.groupby("source").size().reset_index(name='count')
label_groups = dev_df.groupby("label").size().reset_index(name='count')

In [4]:
# Number of non-test rows (split != 2) per source.
source_groups

Unnamed: 0,source,count
0,1,140
1,2,8000
2,3,71
3,4,318
4,5,2608
5,6,2342
6,7,729
7,8,1000


In [5]:
# Number of non-test rows (split != 2) per label value.
label_groups

Unnamed: 0,label,count
0,0,9852
1,1,2695
2,2,2661
