## CREST Statistics

In [1]:
import pandas as pd
import numpy as np

# Load the CREST dataset from the Excel export (path is relative to this notebook).
df = pd.read_excel("../data/causal/crest.xlsx")

# Rows without a split value are treated as training data (split code 0).
# Series.isna() is used instead of np.isnan so this also works if the column
# is read as object dtype, where np.isnan raises TypeError.
df.loc[df['split'].isna(), 'split'] = 0

# Sanity check: every row must now carry a split value.
assert df['split'].notna().all(), "some rows still have a missing split value"

In [3]:
# Preview the first rows to check the loaded columns and the filled-in split values.
df.head()

Unnamed: 0.1,Unnamed: 0,original_id,span1,span2,signal,context,idx,label,source,ann_file,split
0,0,1,['tumor shrinkage'],['radiation therapy'],[],The period of tumor shrinkage after radiation ...,"{'span1': [[14, 29]], 'span2': [[36, 53]], 'si...",2,1,,0.0
1,1,2,['Habitat degradation'],['stream channels'],[],Habitat degradation from within stream channel...,"{'span1': [[0, 19]], 'span2': [[32, 47]], 'sig...",0,1,,0.0
2,2,3,['discomfort'],['traveling'],[],Earplugs relieve the discomfort from traveling...,"{'span1': [[21, 31]], 'span2': [[37, 46]], 'si...",2,1,,0.0
3,3,4,['daily terror'],['antipersonnel land mines'],[],We continue to see progress toward a world fre...,"{'span1': [[55, 67]], 'span2': [[71, 95]], 'si...",2,1,,0.0
4,4,5,['segment'],['anecdotes'],[],The Global Warming segment starts off with two...,"{'span1': [[19, 26]], 'span2': [[53, 62]], 'si...",0,1,,0.0


In [2]:
# Number of rows per split; the codes 0, 1 and 2 are reported as train, dev and test.
n_train, n_dev, n_test = (int((df['split'] == code).sum()) for code in (0, 1, 2))
print('train: {}, dev: {}, test: {}'.format(n_train, n_dev, n_test))

train: 13622, dev: 1586, test: 4371


## Average Span Length
We want to find the average number of tokens in the span arguments of a causal relation in CREST, computed separately for each source.

In [20]:
import ast

def _span_token_count(span_literal):
    """Return the number of whitespace-separated tokens in a serialized span.

    Parameters
    ----------
    span_literal : str
        String form of a Python list of strings, as stored in the ``span1`` /
        ``span2`` columns (e.g. ``"['tumor shrinkage']"``).

    Returns
    -------
    int
        Token count of the list elements joined by single spaces; 0 when the
        span is empty or contains only whitespace.
    """
    text = ' '.join(ast.literal_eval(span_literal))
    # split() without an argument collapses runs of whitespace, so consecutive
    # spaces are not counted as extra (empty) tokens the way split(' ') would.
    return len(text.split())


# Average number of tokens per non-empty span argument, reported per source.
sources = [1, 2, 3, 4, 5, 6, 7, 8]
for source in sources:
    df_source = df.loc[df['source'] == source]
    sum_length = 0
    n_count = 0
    for column in ('span1', 'span2'):
        for span_literal in df_source[column]:
            n_tokens = _span_token_count(span_literal)
            if n_tokens:  # empty spans do not contribute to the average
                sum_length += n_tokens
                n_count += 1
    # A source with no non-empty spans has no defined average; skip it
    # explicitly instead of hiding the ZeroDivisionError behind a bare except.
    if n_count == 0:
        continue
    print("source {}: {}".format(source, sum_length / n_count))

source 1: 1.3068181818181819
source 2: 1.0458150601847531
source 3: 1.0
source 4: 1.001572327044025
source 5: 1.16545245398773
source 6: 1.1015187849720225
source 7: 8.927641099855283
source 8: 5.6195


In [3]:
# Split off the test rows (split code 2); dev_df keeps every remaining row,
# i.e. it is the non-test portion of the data rather than only split code 1.
is_test = df['split'] == 2
test_df = df.loc[is_test]
dev_df = df.drop(index=test_df.index)

# Row counts of the non-test data per source and per label.
source_groups = dev_df.groupby("source").size().reset_index(name='count')
label_groups = dev_df.groupby("label").size().reset_index(name='count')

In [4]:
# Row counts per source, as computed by the groupby on dev_df.
source_groups

Unnamed: 0,source,count
0,1,140
1,2,8000
2,3,71
3,4,318
4,5,2608
5,6,2342
6,7,729
7,8,1000


In [5]:
# Row counts per label, as computed by the groupby on dev_df.
label_groups

Unnamed: 0,label,count
0,0,9852
1,1,2695
2,2,2661
