In [91]:
import json
from os import path
from collections import defaultdict
import os

In [92]:
def get_token_to_pred(seq_words, bio_tags):
    """Pair each word with its tag, ignoring 'X' placeholder tags.

    The first and last entries of ``bio_tags`` are discarded, as is every
    remaining entry equal to ``'X'``. The surviving tags are matched to
    ``seq_words`` by position.

    NOTE(review): presumably the discarded entries belong to sequence-boundary
    tokens and sub-word continuations -- confirm against the tagger output.

    Args:
        seq_words: indexable sequence of words.
        bio_tags: sequence of tag strings, including the two boundary entries.

    Returns:
        A list of ``(word, tag)`` tuples, one per word.

    Raises:
        IndexError: if more tags survive filtering than there are words.
        AssertionError: if fewer tags survive filtering than there are words.
    """
    kept_tags = [tag for tag in bio_tags[1:-1] if tag != 'X']
    word_tag_pair = [(seq_words[pos], tag) for pos, tag in enumerate(kept_tags)]
    assert len(word_tag_pair) == len(seq_words)

    return word_tag_pair

In [93]:
def get_tag_to_token_indices(word_tag_pair):
    """Group token positions by tag, leaving out tokens tagged 'O'.

    Args:
        word_tag_pair: iterable of ``(word, tag)`` 2-tuples.

    Returns:
        A ``defaultdict(list)`` mapping each tag other than ``'O'`` to the
        ascending list of positions at which it occurs in ``word_tag_pair``.
    """
    positions_by_tag = defaultdict(list)
    for position, (_, label) in enumerate(word_tag_pair):
        if label != 'O':
            # Labels are used verbatim as keys (e.g. 'N-V' is kept as 'N-V').
            positions_by_tag[label].append(position)

    return positions_by_tag

In [100]:
def process_split(input_dir, output_dir, split='train'):
    """Convert one split's JSONL file into a single JSON dictionary.

    Reads ``<input_dir>/<split>.jsonl``, where every non-blank line is a JSON
    object with the keys ``"seq_words"``, ``"BIO"`` and ``"pred_sense"``.
    An instance is kept only if ``pred_sense[0]`` appears among the token
    indices of a tag that is ``'V'`` or contains ``'-V'``.
    NOTE(review): ``pred_sense[0]`` is presumably the predicate's token index
    -- confirm against the producer of these files.

    Kept instances are written to ``<output_dir>/<split>.jsonl`` as one JSON
    object keyed by ``"<space-joined words>\\t<pred_sense[0]>"`` with the value
    ``{"tag_dict": {tag: [token indices]}}``.

    Prints the number of output entries, then
    ``kept_count total_count kept_ratio``; for the ``'train'`` split it also
    prints per-tag instance counts sorted by descending frequency.

    Args:
        input_dir: directory containing ``<split>.jsonl``.
        output_dir: existing directory that receives the converted file.
        split: split name, used as the file stem.

    Returns:
        None.
    """
    split_file = f'{split}.jsonl'
    input_file = path.join(input_dir, split_file)
    output_file = path.join(output_dir, split_file)

    tag_stats = defaultdict(int)
    coverage = 0
    total_sent = 0

    output_dict = {}
    with open(input_file) as f:
        # Iterate the file lazily rather than loading every line up front.
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines carry no instance and would make json.loads fail.
                continue

            instance = json.loads(line)
            word_tag_pair = get_token_to_pred(instance["seq_words"], instance["BIO"])
            tag_to_token_indices = get_tag_to_token_indices(word_tag_pair)

            pred_sense = instance['pred_sense']
            total_sent += 1

            # Keep the instance only when a verb-like tag covers pred_sense[0].
            # An empty tag dictionary yields False here, so no separate
            # emptiness check is needed.
            found_pred = any(
                (tag == 'V' or '-V' in tag) and pred_sense[0] in token_indices
                for tag, token_indices in tag_to_token_indices.items()
            )
            if not found_pred:
                continue

            coverage += 1
            for tag in tag_to_token_indices:
                tag_stats[tag] += 1

            # JSON doesn't accept tuple as key, hence the funny encoding.
            # Instances sharing the same words and pred_sense[0] map to the
            # same key, so a later one overwrites an earlier one.
            key = " ".join(instance["seq_words"]) + "\t" + str(pred_sense[0])
            output_dict[key] = {'tag_dict': tag_to_token_indices}

    with open(output_file, 'w') as output_f:
        print(len(output_dict))
        output_f.write(json.dumps(output_dict))

    # Guard against a split with no instances, which would otherwise raise
    # ZeroDivisionError after the output file has already been written.
    coverage_ratio = round(coverage / total_sent, 3) if total_sent else 0.0
    print(coverage, total_sent, coverage_ratio)
    if split == 'train':
        print(sorted(tag_stats.items(), key=lambda x: x[1], reverse=True))

In [101]:
# NOTE: these are machine-specific absolute paths; adjust them before running
# the notebook in a different environment.
input_dir = "/home/shtoshni/Research/events/proc_data/conll09/bertsrl_output/"
output_dir = "/home/shtoshni/Research/events/proc_data/conll09/bertsrl_output_proc/"

splits = ['train', 'dev', 'test']
# exist_ok=True creates the directory only when it is missing, without a
# separate check-then-create step that can race with other processes.
os.makedirs(output_dir, exist_ok=True)

# Convert every split; each call writes <output_dir>/<split>.jsonl.
for split in splits:
    process_split(input_dir, output_dir, split)

3828
4846 4880 0.993
[('A1', 3310), ('V', 2832), ('A0', 2433), ('N-V', 1809), ('A2', 1008), ('AM-TMP', 735), ('AM-LOC', 369), ('AM-MOD', 298), ('AM-ADV', 295), ('AM-MNR', 281), ('A3', 174), ('AM-NEG', 118), ('A0-N-V', 113), ('A1-N-V', 107), ('R-A1', 88), ('R-A0', 87), ('AM-PNC', 86), ('AM-DIR', 82), ('AM-CAU', 74), ('AM-DIS', 73), ('R-AM-TMP', 44), ('A4', 35), ('A2-N-V', 25), ('R-AM-CAU', 21), ('R-AM-LOC', 15), ('AM-EXT', 13), ('R-AM-MNR', 8), ('R-A2', 2), ('C-A1', 2)]
937
1043 1048 0.995
5769
5842 5853 0.998
