In [1]:
from decomp import UDSCorpus
import pandas as pd

In [82]:
class data_prep():
    def __init__(self):
        self.data_list = []
    
    def check_item(self, sent_id, graph):
        tokenized = graph.sentence
        for (arg0, arg1), edge in graph.semantics_edges().items():
            if 'protoroles' in edge:
                arg0 = arg0.split('-')
                arg1 = arg1.split('-')
                if arg0[-2] == 'pred':
                    pred = int(arg0[-1]) - 1
                    arg = int(arg1[-1]) - 1
                else:
                    pred = int(arg1[-1]) - 1
                    arg = int(arg0[-1]) - 1
                
                if not self.check_forces(pred, arg, tokenized, edge['protoroles']):
                    if not self.check_agent(pred, arg, tokenized, edge['protoroles']):
                        if not self.check_patient(pred, arg, tokenized, edge['protoroles']):
                            if not self.check_instrument(pred, arg, tokenized, edge['protoroles']):
                                if not self.check_manner(pred, arg, tokenized, edge['protoroles']):
                                    self.data_list.append((pred, arg, tokenized, 'NONE'))
                
    def check_agent(self, pred_id, arg_id, tokenized, graph):
        tmp = {'value': 0}
        if (graph.get('volition', tmp)['value'] > 0 or graph.get('instigation', tmp)['value'] > 0) and graph.get('existed_before',tmp)['value'] > 0:
            self.data_list.append((pred_id, arg_id, tokenized, 'AGENT'))
            return True
        return False
    
    def check_patient(self, pred_id, arg_id, tokenized, graph):
        tmp = {'value': 0}
        if (graph.get('change_of_state', tmp)['value'] > 0 or graph.get('change_of_state_continuous', tmp)['value'] > 0) and graph.get('instigated',tmp)['value'] <= 0 and graph.get('volition',tmp)['value'] <= 0:
            self.data_list.append((pred_id, arg_id, tokenized, 'PATIENT'))
            return True
        return False
    
    def check_instrument(self, pred_id, arg_id, tokenized, graph):
        tmp = {'value': 0}
        if graph.get('was_used', tmp)['value'] > 0 and graph.get('volition', tmp)['value'] <= 0 and graph.get('sentient', tmp)['value'] <= 0 and graph.get('awareness', tmp)['value'] <= 0:
            self.data_list.append((pred_id, arg_id, tokenized, 'INSTRUMENT'))
            return True
        return False
    
    def check_manner(self, pred_id, arg_id, tokenized, graph):
        tmp = {'value': 0}
        if graph.get('manner', tmp)['value'] > 0:
            self.data_list.append((pred_id, arg_id, tokenized, 'MANNER'))
#             print(f"{tokenized.split()[pred_id]}|{tokenized.split()[arg_id]}| {tokenized}")
            return True
        return False
    
    def check_forces(self, pred_id, arg_id, tokenized, graph):
        tmp = {'value': 0}
        if graph.get('instigation', tmp)['value'] > 0  and graph.get('sentient', tmp)['value'] <= 0 and graph.get('volition', tmp)['value'] <= 0 and graph.get('awareness', tmp)['value'] <= 0 and graph.get('existed_during',tmp)['value'] > 0:
            self.data_list.append((pred_id, arg_id, tokenized, 'FORCES'))
#             print(f"{tokenized.split()[pred_id]}|{tokenized.split()[arg_id]}| {tokenized}")
            return True
        return False
    
    def get_df(self):
        temp_list = []
        for role in self.data_list:
            temp_list.append({"pred_ind": role[0], "arg_ind": role[1], "tokenized": role[2], "label": role[3]})
        return pd.DataFrame(temp_list)
        

In [17]:
uds_train = UDSCorpus(split='train')

In [83]:
dp = data_prep()
for sent_id, r in uds_train.items():
    dp.check_item(sent_id, r)

In [84]:
dp.get_df().to_csv('dataset_train.tsv', sep='\t')

# DEV

In [77]:
uds_dev = UDSCorpus(split='dev')

In [85]:
dp_dev = data_prep()
for sent_id, r in uds_dev.items():
    dp_dev.check_item(sent_id, r)
dp_dev.get_df().groupby('label').size()

label
AGENT         312
FORCES         95
INSTRUMENT     81
MANNER         38
NONE          151
PATIENT        74
dtype: int64

In [86]:
dp_dev.get_df().to_csv('dataset_dev.tsv', sep='\t')

# Test

In [63]:
uds_test = UDSCorpus(split='test')

In [87]:
dp_test = data_prep()
for sent_id, r in uds_test.items():
    dp_test.check_item(sent_id, r)
dp_test.get_df().groupby('label').size()

label
AGENT         297
FORCES         94
INSTRUMENT     72
MANNER         33
NONE          113
PATIENT        61
dtype: int64

In [88]:
dp_test.get_df().to_csv('dataset_test.tsv', sep='\t')