### Read the rules

In [2]:
import json

with open('/storage/rvacareanu/data/softrules/rules/fsre_dataset/TACRED/enhanced_syntax.jsonl') as fin:
    data = []
    for line in fin:
        data.append(json.loads(line))


### Read the ids of the rules used for matching (i.e. only support sentences)

In [3]:
ids = set()
for k in [1, 5]:
    for seed in [0, 1, 2, 3, 4]:
        with open(f'/storage/rvacareanu/data/softrules/fsre_dataset/TACRED/dev_episodes/5_way_{k}_shots_10K_episodes_3q_seed_16029{seed}.json') as fin:
            episodes = json.load(fin)[0]
            for ep in episodes:
                for ss_for_relation in ep['meta_train']:
                    for ss in ss_for_relation:
                        if ss['relation'] == 'org:parents':
                            ids.add(ss['id'])


In [4]:
len(ids)

96

In [5]:
data_for_relation = [{**x, 'skip': True} for x in data if x['id'] in ids]
data_rest         = [{**x, 'skip': False} for x in data if x['id'] not in ids]

In [6]:
print(len(data_for_relation))
print(len(data_rest))

96
118265


In [7]:
# Sanity check to see that there are no duplicates here
# (on TACRED; other datasets might not have a 100% unique `id` field)
from collections import Counter
print(sorted(Counter([(x['id'], x['line_to_hash']) for x in data_for_relation]).items(), key=lambda x: -x[1])[:3])
print(sorted(Counter([(x['id'], x['line_to_hash']) for x in data_for_relation]).items(), key=lambda x: -x[1])[:3])

[(('e7798f70605ecb7a381c', '054118511d1d553b9ffca2e3717f8fad'), 1), (('e7798f7060a537709eb2', '0a81f8d4d29fe276f2a6f824c8a24389'), 1), (('e7798f7060417ee4589b', '0bc421dedbb2e614a611d0959bb6100b'), 1)]
[(('e7798f70605ecb7a381c', '054118511d1d553b9ffca2e3717f8fad'), 1), (('e7798f7060a537709eb2', '0a81f8d4d29fe276f2a6f824c8a24389'), 1), (('e7798f7060417ee4589b', '0bc421dedbb2e614a611d0959bb6100b'), 1)]


In [8]:
id_to_line_to_hash = {x['id']: x['line_to_hash'] for x in data_for_relation}
line_to_hash_to_id = {x['line_to_hash']: x['id'] for x in data_for_relation}

id_to_rule           = {x['id']: x for x in data_for_relation}
line_to_hash_to_rule = {x['line_to_hash']: x for x in data_for_relation}

### Intervention Data

##### Intervention 1

In [9]:
# intervention_data = []
# with open('../../../intervention_data/annotator1/org_parents_prepared.jsonl') as fin:
#     for line in fin:
#         loaded_line = json.loads(line)
#         if loaded_line['id'] != 'GLOBAL':
#             line_to_hash = id_to_line_to_hash[loaded_line['id']]
#         else:
#             line_to_hash = "GLOBAL"
#         intervention_data.append({'line_to_hash': line_to_hash, **loaded_line, 'skip': False})
    
# id_to_intervention_data = {x['id']: x for x in intervention_data}

# with open('../../../intervention_data/annotator1/enhanced_syntax.jsonl', 'w+') as fout:
#     for line in data_rest + data_for_relation + intervention_data:
#         _=fout.write(json.dumps(line))
#         _=fout.write('\n')
    


##### Intervention 2

In [10]:
intervention_data = []
with open('../../../intervention_data/annotator2/org_parents_prepared.jsonl') as fin:
    for line in fin:
        loaded_line = json.loads(line)
        if loaded_line['id'] != 'GLOBAL':
            line_to_hash = id_to_line_to_hash[loaded_line['id']]
        else:
            line_to_hash = "GLOBAL"
        intervention_data.append({'line_to_hash': line_to_hash, **loaded_line, 'skip': False})
    
id_to_intervention_data = {x['id']: x for x in intervention_data}

with open('../../../intervention_data/annotator2/enhanced_syntax_v3.jsonl', 'w+') as fout:
    for line in data_rest + data_for_relation + intervention_data:
        _=fout.write(json.dumps(line))
        _=fout.write('\n')
    

